diff --git a/medcat-den/src/medcat_den/wrappers.py b/medcat-den/src/medcat_den/wrappers.py index cfb2d65f2..c3c76981c 100644 --- a/medcat-den/src/medcat_den/wrappers.py +++ b/medcat-den/src/medcat_den/wrappers.py @@ -135,7 +135,7 @@ def load_model_pack(cls, model_pack_path: str, class WrappedTrainer(Trainer): def __init__(self, den_cnf: DenConfig, delegate: Trainer): - super().__init__(delegate.cdb, delegate.caller, delegate._pipeline) + super().__init__(delegate.cdb, delegate._pipeline) self._den_cnf = den_cnf def train_supervised_raw( diff --git a/medcat-den/tests/injection/test_medcat_injection.py b/medcat-den/tests/injection/test_medcat_injection.py index 8806f7a94..47cb09261 100644 --- a/medcat-den/tests/injection/test_medcat_injection.py +++ b/medcat-den/tests/injection/test_medcat_injection.py @@ -35,7 +35,7 @@ def test_calls_injected_method(): "medcat_den.injection.medcat_injector.injected_load_model_pack" ) as mock_load_model_pack: with medcat_injector.injected_den(): - CAT.load_model_pack("SOME MODEL") + _load_model_pack("SOME MODEL") mock_load_model_pack.assert_called_once() @@ -86,14 +86,14 @@ def den_with_nonempty_model(den_with_model: Den): def test_can_load_model(den_with_model: Den): model_hash = den_with_model.list_available_models()[0].model_id with medcat_injector.injected_den(den_getter=lambda: den_with_model): - cat = CAT.load_model_pack(model_hash) + cat = _load_model_pack(model_hash) assert cat.config.meta.hash == model_hash def test_no_prefix_cannot_load_from_disk(den_with_model: Den): with medcat_injector.injected_den(den_getter=lambda: den_with_model): with pytest.raises(ValueError): - CAT.load_model_pack(EXAMPLE_MODEL_PATH) + _load_model_pack(EXAMPLE_MODEL_PATH) def test_can_load_model_with_mappings(den_with_model: Den): @@ -102,14 +102,14 @@ def test_can_load_model_with_mappings(den_with_model: Den): name_map = {model_name: model_hash} with medcat_injector.injected_den(den_getter=lambda: den_with_model, model_name_mapper=name_map): - cat = CAT.load_model_pack(model_name) + cat = _load_model_pack(model_name) assert cat.config.meta.hash == model_hash def test_with_prefix_can_load_from_disk(den_with_model: Den): with medcat_injector.injected_den(den_getter=lambda: den_with_model, prefix="DEN:"): - cat = CAT.load_model_pack(EXAMPLE_MODEL_PATH) + cat = _load_model_pack(EXAMPLE_MODEL_PATH) assert isinstance(cat, CAT) @@ -117,7 +117,7 @@ def test_with_prefix_can_load_from_den(den_with_model: Den): model_hash = den_with_model.list_available_base_models()[0].model_id with medcat_injector.injected_den(den_getter=lambda: den_with_model, prefix="DEN:"): - cat = CAT.load_model_pack(f"DEN:{model_hash}") + cat = _load_model_pack(f"DEN:{model_hash}") assert isinstance(cat, CAT) @@ -128,7 +128,7 @@ def test_with_prefix_can_load_from_den_with_mapping(den_with_model: Den): with medcat_injector.injected_den(den_getter=lambda: den_with_model, prefix="DEN:", model_name_mapper=name_map): - cat = CAT.load_model_pack(f"DEN:{model_hash}") + cat = _load_model_pack(f"DEN:{model_hash}") assert isinstance(cat, CAT) @@ -146,6 +146,12 @@ def test_den_has_model_with_data(den_with_nonempty_model: Den): assert model.cdb.name2info +def _load_model_pack(name_or_hash: str) -> CAT: + cat = CAT.load_model_pack(name_or_hash) + cat.config.components.linking.train = False + return cat + + def _helper_can_save_model(den: Den, saver: Callable[[CAT, str], None]): message = "Made Some Good Changes" base_model_info = den.list_available_base_models()[0] @@ -155,7 +161,7 @@ def _helper_can_save_model(den: Den, saver: Callable[[CAT, str], None]): with medcat_injector.injected_den( den_getter=lambda: den, inject_save=True): - cat = CAT.load_model_pack(example_hash) + cat = _load_model_pack(example_hash) cat.trainer.train_unsupervised([ "Well, acute kidney failure never gets old, does it?", "What does a chronic epileptic fit even mean?"])