diff --git a/langextract/resolver.py b/langextract/resolver.py index 198470f8..45204d69 100644 --- a/langextract/resolver.py +++ b/langextract/resolver.py @@ -126,9 +126,9 @@ def align( source_text: str, token_offset: int, char_offset: int | None = None, - enable_fuzzy_alignment: bool = True, - fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, - accept_match_lesser: bool = True, + enable_fuzzy_alignment: bool | None = None, + fuzzy_alignment_threshold: float | None = None, + accept_match_lesser: bool | None = None, **kwargs, ) -> Iterator[data.Extraction]: """Aligns extractions with source text, setting token/char intervals and alignment status. @@ -216,6 +216,14 @@ def __init__( else: format_handler = fh.FormatHandler() + # Consume alignment parameters + self.enable_fuzzy_alignment = kwargs.pop("enable_fuzzy_alignment", True) + self.fuzzy_alignment_threshold = kwargs.pop( + "fuzzy_alignment_threshold", _FUZZY_ALIGNMENT_MIN_THRESHOLD + ) + self.accept_match_lesser = kwargs.pop("accept_match_lesser", True) + self.suppress_parse_errors = kwargs.pop("suppress_parse_errors", False) + if kwargs: raise TypeError( f"got an unexpected keyword argument '{list(kwargs.keys())[0]}'" @@ -234,7 +242,7 @@ def __init__( def resolve( self, input_text: str, - suppress_parse_errors: bool = False, + suppress_parse_errors: bool | None = None, **kwargs, ) -> Sequence[data.Extraction]: """Runs resolve function on text with YAML/JSON extraction data. 
@@ -254,6 +262,9 @@ def resolve( logging.debug("Starting resolver process for input text.") logging.debug("Input Text: %s", input_text) + if suppress_parse_errors is None: + suppress_parse_errors = getattr(self, "suppress_parse_errors", False) + try: constraint = getattr(self, "_constraint", schema.Constraint()) strict = getattr(constraint, "strict", False) @@ -282,9 +293,9 @@ def align( source_text: str, token_offset: int, char_offset: int | None = None, - enable_fuzzy_alignment: bool = True, - fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, - accept_match_lesser: bool = True, + enable_fuzzy_alignment: bool | None = None, + fuzzy_alignment_threshold: float | None = None, + accept_match_lesser: bool | None = None, tokenizer_inst: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> Iterator[data.Extraction]: @@ -302,10 +313,11 @@ def align( token_offset: The starting token index of the chunk. char_offset: The starting character index of the chunk. enable_fuzzy_alignment: Whether to enable fuzzy alignment fallback. + Defaults to instance setting or True. fuzzy_alignment_threshold: Minimum overlap ratio required for fuzzy - alignment. + alignment. Defaults to instance setting or 0.75. accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER - status). + status). Defaults to instance setting or True. tokenizer_inst: Optional tokenizer instance. **kwargs: Additional parameters. 
@@ -323,6 +335,17 @@ def align( else: extractions_group = [extractions] + if enable_fuzzy_alignment is None: + enable_fuzzy_alignment = getattr(self, "enable_fuzzy_alignment", True) + + if fuzzy_alignment_threshold is None: + fuzzy_alignment_threshold = getattr( + self, "fuzzy_alignment_threshold", _FUZZY_ALIGNMENT_MIN_THRESHOLD + ) + + if accept_match_lesser is None: + accept_match_lesser = getattr(self, "accept_match_lesser", True) + aligner = WordAligner() aligned_yaml_extractions = aligner.align_extractions( extractions_group, @@ -692,11 +715,11 @@ def align_extractions( character intervals. delim: Token used to separate multi-token extractions. enable_fuzzy_alignment: Whether to use fuzzy alignment when exact matching - fails. + fails. Defaults to instance setting or True. fuzzy_alignment_threshold: Minimum token overlap ratio for fuzzy alignment - (0-1). + (0-1). Defaults to instance setting or 0.75. accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER - status). + status). Defaults to instance setting or True. tokenizer_impl: Optional tokenizer instance. 
Returns:
diff --git a/tests/resolver_params_test.py b/tests/resolver_params_test.py
new file mode 100644
index 00000000..c3b24520
--- /dev/null
+++ b/tests/resolver_params_test.py
@@ -0,0 +1,81 @@
+"""Tests for alignment parameters configured on the Resolver instance."""
+
+import unittest
+
+from langextract import resolver as resolver_lib
+from langextract.core import data
+from langextract.core import tokenizer
+
+# Only "and" and "fever" of the three extraction tokens occur in the source,
+# so the fuzzy overlap ratio is 2/3 (~0.66): above 0.6 but below 0.8.
+_EXTRACTION_TEXT = "headache and fever"
+_SOURCE_TEXT = "Patient reports back pain and a fever."
+
+
+class ResolverParamsTest(unittest.TestCase):
+  """Checks that Resolver stores alignment settings and align() uses them."""
+
+  def _align(self, resolver, **align_kwargs):
+    """Aligns the shared fixture extraction and returns the single result."""
+    extractions = [
+        data.Extraction(
+            extraction_class="symptom",
+            extraction_text=_EXTRACTION_TEXT,
+        )
+    ]
+    aligned = list(
+        resolver.align(
+            extractions, _SOURCE_TEXT, token_offset=0, **align_kwargs
+        )
+    )
+    # Guard first so an empty result fails as an assertion, not an IndexError.
+    self.assertEqual(len(aligned), 1)
+    return aligned[0]
+
+  def test_resolver_accepts_alignment_params(self):
+    # Regression check: without the constructor consuming these keywords they
+    # fall through to the "unexpected keyword argument" TypeError.
+    resolver = resolver_lib.Resolver(
+        fuzzy_alignment_threshold=0.6,
+        enable_fuzzy_alignment=True,
+    )
+    self.assertEqual(resolver.fuzzy_alignment_threshold, 0.6)
+    self.assertTrue(resolver.enable_fuzzy_alignment)
+
+  def test_align_uses_instance_threshold(self):
+    resolver = resolver_lib.Resolver(
+        fuzzy_alignment_threshold=0.6,
+        enable_fuzzy_alignment=True,
+        accept_match_lesser=False,
+    )
+    aligned = self._align(resolver)
+    # The 2/3 overlap clears the 0.6 threshold stored on the instance.
+    self.assertEqual(
+        aligned.alignment_status, data.AlignmentStatus.MATCH_FUZZY
+    )
+
+  def test_align_uses_instance_threshold_fail(self):
+    resolver = resolver_lib.Resolver(
+        fuzzy_alignment_threshold=0.8,
+        enable_fuzzy_alignment=True,
+        accept_match_lesser=False,
+    )
+    aligned = self._align(resolver)
+    # The 2/3 overlap is below 0.8, so the extraction stays unaligned.
+    self.assertIsNone(aligned.alignment_status)
+
+  def test_align_argument_overrides_instance_threshold(self):
+    resolver = resolver_lib.Resolver(
+        fuzzy_alignment_threshold=0.8,
+        enable_fuzzy_alignment=True,
+        accept_match_lesser=False,
+    )
+    # An explicit per-call value takes precedence over the instance setting.
+    aligned = self._align(resolver, fuzzy_alignment_threshold=0.6)
+    self.assertEqual(
+        aligned.alignment_status, data.AlignmentStatus.MATCH_FUZZY
+    )
+
+
+if __name__ == '__main__':
+  unittest.main()