def _filter_ungrounded_extractions(
    extractions: "list[data.Extraction] | None",
) -> "list[data.Extraction]":
  """Filters out extractions that are not grounded in the source text.

  An extraction is considered ungrounded if its char_interval is None or
  if either start_pos or end_pos is None. This typically indicates that
  the extraction could not be located in the source document, which may
  occur when the LLM extracts content from few-shot examples rather than
  the actual input text.

  Only None marks a missing position: a start_pos or end_pos of 0 is a
  valid offset and the extraction is kept.

  Args:
    extractions: List of extractions to filter. None or an empty list is
      accepted and yields an empty result.

  Returns:
    A new list holding, in their original order, the extractions whose
    char_interval, start_pos and end_pos are all non-None. The input list
    is not modified.
  """
  # Annotations above are quoted so that defining this function does not
  # evaluate `data.Extraction` eagerly.
  if not extractions:
    return []

  # Compare against None explicitly; a truthiness test would wrongly drop
  # extractions that start at offset 0.
  return [
      extraction
      for extraction in extractions
      if extraction.char_interval is not None
      and extraction.char_interval.start_pos is not None
      and extraction.char_interval.end_pos is not None
  ]
- context_window_chars: Number of characters from the previous chunk to - include as context for the current chunk. This helps with coreference - resolution across chunk boundaries (e.g., resolving "She" to a person - mentioned in the previous chunk). Defaults to None (disabled). config: Model configuration to use for extraction. Takes precedence over model_id, api_key, and language_model_type parameters. When both model and config are provided, model takes precedence. @@ -165,6 +194,12 @@ def extract( prompt_validation_strict: When True and prompt_validation_level is ERROR, raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False. show_progress: Whether to show progress bar during extraction. Defaults to True. + require_grounding: Whether to filter out extractions that cannot be grounded + to specific character positions in the source text. When True, only + extractions with valid char_interval (non-None start_pos and end_pos) + are returned. This helps prevent returning extractions that may have been + extracted from few-shot examples rather than the actual input text. + Defaults to False for backward compatibility. 
Returns: An AnnotatedDocument with the extracted information when input is a @@ -353,14 +388,24 @@ def extract( additional_context=additional_context, debug=debug, extraction_passes=extraction_passes, - context_window_chars=context_window_chars, show_progress=show_progress, max_workers=max_workers, tokenizer=tokenizer, **alignment_kwargs, ) + + # Filter ungrounded extractions if requested + if require_grounding: + result = data.AnnotatedDocument( + document_id=result.document_id, + extractions=_filter_ungrounded_extractions(result.extractions), + text=result.text, + ) + return result else: + documents = cast(Iterable[data.Document], text_or_documents) + results = annotator.annotate_documents( if additional_context is not None: documents = ( doc.with_additional_context(additional_context) @@ -377,10 +422,22 @@ def extract( batch_length=batch_length, debug=debug, extraction_passes=extraction_passes, - context_window_chars=context_window_chars, show_progress=show_progress, max_workers=max_workers, tokenizer=tokenizer, **alignment_kwargs, ) - return list(result) + + # Filter ungrounded extractions if requested + if require_grounding: + filtered_results = [] + for doc in results: + filtered_doc = data.AnnotatedDocument( + document_id=doc.document_id, + extractions=_filter_ungrounded_extractions(doc.extractions), + text=doc.text, + ) + filtered_results.append(filtered_doc) + return filtered_results + + return list(results) diff --git a/tests/test_require_grounding.py b/tests/test_require_grounding.py new file mode 100644 index 00000000..8eedcf3b --- /dev/null +++ b/tests/test_require_grounding.py @@ -0,0 +1,152 @@ +# Copyright 2025 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for require_grounding parameter in extraction.py."""

import unittest

from langextract.core import data
from langextract.extraction import _filter_ungrounded_extractions


class FilterUngroundedExtractionsTest(unittest.TestCase):
  """Tests for _filter_ungrounded_extractions function."""

  @staticmethod
  def _extraction(text, char_interval):
    """Builds a "test"-class extraction with the given text and interval."""
    return data.Extraction(
        extraction_class="test",
        extraction_text=text,
        char_interval=char_interval,
    )

  def _assert_only_grounded_kept(self, ungrounded_interval):
    """Checks that one grounded/one ungrounded pair filters to the former.

    Args:
      ungrounded_interval: char_interval value given to the extraction that
        is expected to be dropped.
    """
    extractions = [
        self._extraction(
            "grounded text", data.CharInterval(start_pos=0, end_pos=13)
        ),
        self._extraction("ungrounded text", ungrounded_interval),
    ]
    result = _filter_ungrounded_extractions(extractions)
    self.assertEqual(len(result), 1)
    self.assertEqual(result[0].extraction_text, "grounded text")
    # The filter builds a new list; the caller's list must be left intact.
    self.assertEqual(len(extractions), 2)

  def test_empty_list(self):
    """Returns empty list for empty input."""
    result = _filter_ungrounded_extractions([])
    self.assertEqual(result, [])

  def test_none_input(self):
    """Returns empty list for None input."""
    result = _filter_ungrounded_extractions(None)
    self.assertEqual(result, [])

  def test_filters_none_char_interval(self):
    """Filters out extractions with None char_interval."""
    self._assert_only_grounded_kept(None)

  def test_filters_none_start_pos(self):
    """Filters out extractions with None start_pos."""
    self._assert_only_grounded_kept(
        data.CharInterval(start_pos=None, end_pos=15)
    )

  def test_filters_none_end_pos(self):
    """Filters out extractions with None end_pos."""
    self._assert_only_grounded_kept(
        data.CharInterval(start_pos=0, end_pos=None)
    )

  def test_keeps_all_grounded(self):
    """Keeps all extractions, in order, when all are grounded."""
    extractions = [
        self._extraction("first", data.CharInterval(start_pos=0, end_pos=5)),
        self._extraction(
            "second", data.CharInterval(start_pos=10, end_pos=16)
        ),
    ]
    result = _filter_ungrounded_extractions(extractions)
    self.assertEqual(len(result), 2)
    # A count alone would not catch reordering or duplication.
    self.assertEqual(
        [extraction.extraction_text for extraction in result],
        ["first", "second"],
    )

  def test_filters_all_ungrounded(self):
    """Returns empty list when all extractions are ungrounded."""
    extractions = [
        self._extraction("ungrounded1", None),
        self._extraction(
            "ungrounded2", data.CharInterval(start_pos=None, end_pos=None)
        ),
    ]
    result = _filter_ungrounded_extractions(extractions)
    self.assertEqual(len(result), 0)

  def test_preserves_extraction_attributes(self):
    """Preserves all attributes of grounded extractions."""
    extractions = [
        data.Extraction(
            extraction_class="medication",
            extraction_text="aspirin",
            char_interval=data.CharInterval(start_pos=10, end_pos=17),
            alignment_status=data.AlignmentStatus.MATCH_EXACT,
            extraction_index=1,
            group_index=0,
            description="A medication",
            attributes={"dosage": "100mg"},
        ),
    ]
    result = _filter_ungrounded_extractions(extractions)
    self.assertEqual(len(result), 1)
    kept = result[0]
    self.assertEqual(kept.extraction_class, "medication")
    self.assertEqual(kept.extraction_text, "aspirin")
    # The interval is the attribute the filter inspects, so verify it too.
    self.assertEqual(kept.char_interval.start_pos, 10)
    self.assertEqual(kept.char_interval.end_pos, 17)
    self.assertEqual(kept.alignment_status, data.AlignmentStatus.MATCH_EXACT)
    self.assertEqual(kept.extraction_index, 1)
    self.assertEqual(kept.group_index, 0)
    self.assertEqual(kept.description, "A medication")
    self.assertEqual(kept.attributes, {"dosage": "100mg"})


if __name__ == "__main__":
  unittest.main()