diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 99c7fcd7..a61cd7a7 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -1770,7 +1770,7 @@ def initialize_model(self, model_name=None, custom=False): if model_name is None or custom: self.get_model_path(custom=custom) if not os.path.exists(self.current_model_path): - raise ValueError("need to specify model (use dropdown)") + raise ValueError("Model file not found: need to specify model (use dropdown)") if model_name is None or not isinstance(model_name, str): self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), @@ -1867,7 +1867,7 @@ def compute_cprob(self): self.logger.error("Flows don't exist, try running model again.") return - maski = dynamics.resize_and_compute_masks( + maski = dynamics.compute_masks_and_clean( dP=dP, cellprob=cellprob, niter=niter, diff --git a/cellpose/vit_sam.py b/cellpose/vit_sam.py index 7e82bd18..7d93378a 100644 --- a/cellpose/vit_sam.py +++ b/cellpose/vit_sam.py @@ -81,9 +81,15 @@ def forward(self, x): return x1, torch.zeros((x.shape[0], 256), device=x.device) - def load_model(self, PATH, device, strict = False): + def load_model(self, PATH, device, strict = False): state_dict = torch.load(PATH, map_location = device, weights_only=True) keys = [k for k in state_dict.keys()] + + # loudly fail on attempt to load not cp4 model: + w2_data = state_dict.get('W2', None) + if w2_data == None: + raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.') + if keys[0][:7] == "module.": from collections import OrderedDict new_state_dict = OrderedDict() diff --git a/tests/test_import.py b/tests/test_import.py index 6407d116..2f4c709f 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -1,3 +1,6 @@ +import pytest + + def test_cellpose_imports_without_error(): import cellpose from cellpose import models, core @@ -28,3 +31,18 @@ def itest_model_dir(): model = models.CellposeModel(pretrained_model='cpsam') masks = model.eval(np.random.randn(256, 256))[0] assert masks.shape == (256, 256) + + +def test_load_cp3_fail(): + from cellpose.models import CellposeModel, MODEL_DIR + from cellpose import utils + + cyto3_model_path = (MODEL_DIR / 'cyto3').absolute() + + if not cyto3_model_path.exists(): + url = 'https://www.cellpose.org/models/cyto3' + utils.download_url_to_file(url, cyto3_model_path, progress=False) + + with pytest.raises(ValueError): + # using `pretrained_model=cyto3` just loads the cpsam model unless the path is given + model = CellposeModel(pretrained_model=str(cyto3_model_path)) \ No newline at end of file