Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion medcat-den/src/medcat_den/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def load_model_pack(cls, model_pack_path: str,
class WrappedTrainer(Trainer):

def __init__(self, den_cnf: DenConfig, delegate: Trainer):
super().__init__(delegate.cdb, delegate.caller, delegate._pipeline)
super().__init__(delegate.cdb, delegate._pipeline)
self._den_cnf = den_cnf

def train_supervised_raw(
Expand Down
22 changes: 14 additions & 8 deletions medcat-den/tests/injection/test_medcat_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_calls_injected_method():
"medcat_den.injection.medcat_injector.injected_load_model_pack"
) as mock_load_model_pack:
with medcat_injector.injected_den():
CAT.load_model_pack("SOME MODEL")
_load_model_pack("SOME MODEL")
mock_load_model_pack.assert_called_once()


Expand Down Expand Up @@ -86,14 +86,14 @@ def den_with_nonempty_model(den_with_model: Den):
def test_can_load_model(den_with_model: Den):
model_hash = den_with_model.list_available_models()[0].model_id
with medcat_injector.injected_den(den_getter=lambda: den_with_model):
cat = CAT.load_model_pack(model_hash)
cat = _load_model_pack(model_hash)
assert cat.config.meta.hash == model_hash


def test_no_prefix_cannot_load_from_disk(den_with_model: Den):
with medcat_injector.injected_den(den_getter=lambda: den_with_model):
with pytest.raises(ValueError):
CAT.load_model_pack(EXAMPLE_MODEL_PATH)
_load_model_pack(EXAMPLE_MODEL_PATH)


def test_can_load_model_with_mappings(den_with_model: Den):
Expand All @@ -102,22 +102,22 @@ def test_can_load_model_with_mappings(den_with_model: Den):
name_map = {model_name: model_hash}
with medcat_injector.injected_den(den_getter=lambda: den_with_model,
model_name_mapper=name_map):
cat = CAT.load_model_pack(model_name)
cat = _load_model_pack(model_name)
assert cat.config.meta.hash == model_hash


def test_with_prefix_can_load_from_disk(den_with_model: Den):
with medcat_injector.injected_den(den_getter=lambda: den_with_model,
prefix="DEN:"):
cat = CAT.load_model_pack(EXAMPLE_MODEL_PATH)
cat = _load_model_pack(EXAMPLE_MODEL_PATH)
assert isinstance(cat, CAT)


def test_with_prefix_can_load_from_den(den_with_model: Den):
model_hash = den_with_model.list_available_base_models()[0].model_id
with medcat_injector.injected_den(den_getter=lambda: den_with_model,
prefix="DEN:"):
cat = CAT.load_model_pack(f"DEN:{model_hash}")
cat = _load_model_pack(f"DEN:{model_hash}")
assert isinstance(cat, CAT)


Expand All @@ -128,7 +128,7 @@ def test_with_prefix_can_load_from_den_with_mapping(den_with_model: Den):
with medcat_injector.injected_den(den_getter=lambda: den_with_model,
prefix="DEN:",
model_name_mapper=name_map):
cat = CAT.load_model_pack(f"DEN:{model_hash}")
cat = _load_model_pack(f"DEN:{model_hash}")
assert isinstance(cat, CAT)


Expand All @@ -146,6 +146,12 @@ def test_den_has_model_with_data(den_with_nonempty_model: Den):
assert model.cdb.name2info


def _load_model_pack(name_or_hash: str) -> CAT:
cat = CAT.load_model_pack(name_or_hash)
cat.config.components.linking.train = False
return cat


def _helper_can_save_model(den: Den, saver: Callable[[CAT, str], None]):
message = "Made Some Good Changes"
base_model_info = den.list_available_base_models()[0]
Expand All @@ -155,7 +161,7 @@ def _helper_can_save_model(den: Den, saver: Callable[[CAT, str], None]):
with medcat_injector.injected_den(
den_getter=lambda: den,
inject_save=True):
cat = CAT.load_model_pack(example_hash)
cat = _load_model_pack(example_hash)
cat.trainer.train_unsupervised([
"Well, acute kidney failure never gets old, does it?",
"What does a chronic epileptic fit even mean?"])
Expand Down
Loading