Skip to content
Open
73 changes: 65 additions & 8 deletions langextract/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,39 @@
from langextract.core import tokenizer as tokenizer_lib


def _filter_ungrounded_extractions(
extractions: list[data.Extraction] | None,
) -> list[data.Extraction]:
"""Filters out extractions that are not grounded in the source text.

An extraction is considered ungrounded if its char_interval is None or
if either start_pos or end_pos is None. This typically indicates that
the extraction could not be located in the source document, which may
occur when the LLM extracts content from few-shot examples rather than
the actual input text.

Args:
extractions: List of extractions to filter.

Returns:
List of extractions with valid char_interval positions.
"""
if not extractions:
return []

grounded = []
for extraction in extractions:
if extraction.char_interval is None:
continue
if extraction.char_interval.start_pos is None:
continue
if extraction.char_interval.end_pos is None:
continue
grounded.append(extraction)

return grounded


def extract(
text_or_documents: str | Iterable[data.Document],
prompt_description: str | None = None,
Expand All @@ -53,7 +86,6 @@ def extract(
debug: bool = False,
model_url: str | None = None,
extraction_passes: int = 1,
context_window_chars: int | None = None,
config: typing.Any = None,
model: typing.Any = None,
*,
Expand All @@ -62,6 +94,7 @@ def extract(
prompt_validation_strict: bool = False,
show_progress: bool = True,
tokenizer: tokenizer_lib.Tokenizer | None = None,
require_grounding: bool = False,
) -> list[data.AnnotatedDocument] | data.AnnotatedDocument:
"""Extracts structured information from text.

Expand Down Expand Up @@ -145,10 +178,6 @@ def extract(
for overlaps). WARNING: Each additional pass reprocesses tokens,
potentially increasing API costs. For example, extraction_passes=3
reprocesses tokens 3x.
context_window_chars: Number of characters from the previous chunk to
include as context for the current chunk. This helps with coreference
resolution across chunk boundaries (e.g., resolving "She" to a person
mentioned in the previous chunk). Defaults to None (disabled).
config: Model configuration to use for extraction. Takes precedence over
model_id, api_key, and language_model_type parameters. When both model
and config are provided, model takes precedence.
Expand All @@ -165,6 +194,12 @@ def extract(
prompt_validation_strict: When True and prompt_validation_level is ERROR,
raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False.
show_progress: Whether to show progress bar during extraction. Defaults to True.
require_grounding: Whether to filter out extractions that cannot be grounded
to specific character positions in the source text. When True, only
extractions with valid char_interval (non-None start_pos and end_pos)
are returned. This helps prevent returning extractions that may have been
extracted from few-shot examples rather than the actual input text.
Defaults to False for backward compatibility.

Returns:
An AnnotatedDocument with the extracted information when input is a
Expand Down Expand Up @@ -353,14 +388,24 @@ def extract(
additional_context=additional_context,
debug=debug,
extraction_passes=extraction_passes,
context_window_chars=context_window_chars,
show_progress=show_progress,
max_workers=max_workers,
tokenizer=tokenizer,
**alignment_kwargs,
)

# Filter ungrounded extractions if requested
if require_grounding:
result = data.AnnotatedDocument(
document_id=result.document_id,
extractions=_filter_ungrounded_extractions(result.extractions),
text=result.text,
)

return result
else:
documents = cast(Iterable[data.Document], text_or_documents)
results = annotator.annotate_documents(
if additional_context is not None:
documents = (
doc.with_additional_context(additional_context)
Expand All @@ -377,10 +422,22 @@ def extract(
batch_length=batch_length,
debug=debug,
extraction_passes=extraction_passes,
context_window_chars=context_window_chars,
show_progress=show_progress,
max_workers=max_workers,
tokenizer=tokenizer,
**alignment_kwargs,
)
return list(result)

# Filter ungrounded extractions if requested
if require_grounding:
filtered_results = []
for doc in results:
filtered_doc = data.AnnotatedDocument(
document_id=doc.document_id,
extractions=_filter_ungrounded_extractions(doc.extractions),
text=doc.text,
)
filtered_results.append(filtered_doc)
return filtered_results

return list(results)
152 changes: 152 additions & 0 deletions tests/test_require_grounding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for require_grounding parameter in extraction.py."""

import unittest

from langextract.core import data
from langextract.extraction import _filter_ungrounded_extractions


class FilterUngroundedExtractionsTest(unittest.TestCase):
"""Tests for _filter_ungrounded_extractions function."""

def test_empty_list(self):
"""Returns empty list for empty input."""
result = _filter_ungrounded_extractions([])
self.assertEqual(result, [])

def test_none_input(self):
"""Returns empty list for None input."""
result = _filter_ungrounded_extractions(None)
self.assertEqual(result, [])

def test_filters_none_char_interval(self):
"""Filters out extractions with None char_interval."""
extractions = [
data.Extraction(
extraction_class="test",
extraction_text="grounded text",
char_interval=data.CharInterval(start_pos=0, end_pos=13),
),
data.Extraction(
extraction_class="test",
extraction_text="ungrounded text",
char_interval=None,
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].extraction_text, "grounded text")

def test_filters_none_start_pos(self):
"""Filters out extractions with None start_pos."""
extractions = [
data.Extraction(
extraction_class="test",
extraction_text="grounded text",
char_interval=data.CharInterval(start_pos=0, end_pos=13),
),
data.Extraction(
extraction_class="test",
extraction_text="ungrounded text",
char_interval=data.CharInterval(start_pos=None, end_pos=15),
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].extraction_text, "grounded text")

def test_filters_none_end_pos(self):
"""Filters out extractions with None end_pos."""
extractions = [
data.Extraction(
extraction_class="test",
extraction_text="grounded text",
char_interval=data.CharInterval(start_pos=0, end_pos=13),
),
data.Extraction(
extraction_class="test",
extraction_text="ungrounded text",
char_interval=data.CharInterval(start_pos=0, end_pos=None),
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].extraction_text, "grounded text")

def test_keeps_all_grounded(self):
"""Keeps all extractions when all are grounded."""
extractions = [
data.Extraction(
extraction_class="test",
extraction_text="first",
char_interval=data.CharInterval(start_pos=0, end_pos=5),
),
data.Extraction(
extraction_class="test",
extraction_text="second",
char_interval=data.CharInterval(start_pos=10, end_pos=16),
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 2)

def test_filters_all_ungrounded(self):
"""Returns empty list when all extractions are ungrounded."""
extractions = [
data.Extraction(
extraction_class="test",
extraction_text="ungrounded1",
char_interval=None,
),
data.Extraction(
extraction_class="test",
extraction_text="ungrounded2",
char_interval=data.CharInterval(start_pos=None, end_pos=None),
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 0)

def test_preserves_extraction_attributes(self):
"""Preserves all attributes of grounded extractions."""
extractions = [
data.Extraction(
extraction_class="medication",
extraction_text="aspirin",
char_interval=data.CharInterval(start_pos=10, end_pos=17),
alignment_status=data.AlignmentStatus.MATCH_EXACT,
extraction_index=1,
group_index=0,
description="A medication",
attributes={"dosage": "100mg"},
),
]
result = _filter_ungrounded_extractions(extractions)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].extraction_class, "medication")
self.assertEqual(result[0].extraction_text, "aspirin")
self.assertEqual(
result[0].alignment_status, data.AlignmentStatus.MATCH_EXACT
)
self.assertEqual(result[0].extraction_index, 1)
self.assertEqual(result[0].group_index, 0)
self.assertEqual(result[0].description, "A medication")
self.assertEqual(result[0].attributes, {"dosage": "100mg"})


if __name__ == "__main__":
unittest.main()
Loading