Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,7 @@ def initialize_model(self, model_name=None, custom=False):
if model_name is None or custom:
self.get_model_path(custom=custom)
if not os.path.exists(self.current_model_path):
raise ValueError("need to specify model (use dropdown)")
raise ValueError("Model file not found: need to specify model (use dropdown)")

if model_name is None or not isinstance(model_name, str):
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
Expand Down Expand Up @@ -1867,7 +1867,7 @@ def compute_cprob(self):
self.logger.error("Flows don't exist, try running model again.")
return

maski = dynamics.resize_and_compute_masks(
maski = dynamics.compute_masks_and_clean(
dP=dP,
cellprob=cellprob,
niter=niter,
Expand Down
8 changes: 7 additions & 1 deletion cellpose/vit_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,15 @@ def forward(self, x):

return x1, torch.zeros((x.shape[0], 256), device=x.device)

def load_model(self, PATH, device, strict = False):
def load_model(self, PATH, device, strict = False):
state_dict = torch.load(PATH, map_location = device, weights_only=True)
keys = [k for k in state_dict.keys()]

# loudly fail on attempt to load not cp4 model:
w2_data = state_dict.get('W2', None)
if w2_data == None:
raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.')

if keys[0][:7] == "module.":
from collections import OrderedDict
new_state_dict = OrderedDict()
Expand Down
18 changes: 18 additions & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_cellpose_imports_without_error():
import cellpose
from cellpose import models, core
Expand Down Expand Up @@ -28,3 +31,18 @@ def itest_model_dir():
model = models.CellposeModel(pretrained_model='cpsam')
masks = model.eval(np.random.randn(256, 256))[0]
assert masks.shape == (256, 256)


def test_load_cp3_fail():
from cellpose.models import CellposeModel, MODEL_DIR
from cellpose import utils

cyto3_model_path = (MODEL_DIR / 'cyto3').absolute()

if not cyto3_model_path.exists():
url = 'https://www.cellpose.org/models/cyto3'
utils.download_url_to_file(url, cyto3_model_path, progress=False)

with pytest.raises(ValueError):
# using `pretrained_model=cyto3` just loads the cpsam model unless the path is given
model = CellposeModel(pretrained_model=str(cyto3_model_path))