From 4b6a8e23ca45acbf461ba463863be8d8b7fa37ac Mon Sep 17 00:00:00 2001 From: yuguerten Date: Wed, 3 Dec 2025 15:01:29 +0100 Subject: [PATCH 1/6] we plot but not synchronized --- cellpose/train.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 401c0efc..706cd366 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -314,7 +314,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False, save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256, - min_train_masks=5, model_name=None, class_weights=None): + min_train_masks=5, model_name=None, class_weights=None, + loss_callback=None, return_loss_arrays=True): """ Train the network with images for segmentation. @@ -346,9 +347,11 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False. min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5. model_name (str, optional): String - name of the network. Defaults to None. + loss_callback (callable, optional): Function called after each epoch with (epoch, train_loss, test_loss). Defaults to None. + return_loss_arrays (bool, optional): Whether to return full loss arrays or just the model path. Defaults to True. Returns: - tuple: A tuple containing the path to the saved model weights, training losses, and test losses. + tuple or str: If return_loss_arrays=True, returns (path, train_losses, test_losses). If False, returns just the path to saved model weights. """ if SGD: @@ -432,7 +435,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_logger.info(f">>> saving model to {filename}") lavg, nsum = 0, 0 - train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs) + train_losses, test_losses = (np.zeros(n_epochs), np.zeros(n_epochs)) if return_loss_arrays else (None, None) for iepoch in range(n_epochs): np.random.seed(iepoch) if nimg != nimg_per_epoch: @@ -481,8 +484,13 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavg += train_loss nsum += len(imgi) # per epoch training loss - train_losses[iepoch] += train_loss - train_losses[iepoch] /= nimg_per_epoch + if return_loss_arrays: + train_losses[iepoch] += train_loss + if return_loss_arrays: + train_losses[iepoch] /= nimg_per_epoch + epoch_train_loss = (lavg / nsum) if not return_loss_arrays else (train_losses[iepoch] / nimg_per_epoch if iepoch == 0 else train_losses[iepoch]) + + # TODO: add real time tracking of the loss if iepoch == 5 or iepoch % 10 == 0: lavgt = 0. @@ -523,8 +531,14 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, test_loss *= len(imgi) lavgt += test_loss lavgt /= len(rperm) - test_losses[iepoch] = lavgt + if return_loss_arrays: + test_losses[iepoch] = lavgt lavg /= nsum + + # Call the callback function if provided + if loss_callback is not None: + loss_callback(iepoch, lavg, lavgt) + train_logger.info( f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) @@ -544,4 +558,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.dtype = original_net_dtype net.to(original_net_dtype) - return filename, train_losses, test_losses + if return_loss_arrays: + return filename, train_losses, test_losses + else: + return filename From 697468b5be57a665dfa1ff592e957d69f66b7339 Mon Sep 17 00:00:00 2001 From: yuguerten Date: Thu, 4 Dec 2025 10:27:38 +0100 Subject: [PATCH 2/6] loss synchronized --- cellpose/train.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 706cd366..75e6cc41 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -490,6 +490,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_losses[iepoch] /= nimg_per_epoch epoch_train_loss = (lavg / nsum) if not return_loss_arrays else (train_losses[iepoch] / nimg_per_epoch if iepoch == 0 else train_losses[iepoch]) + # Call the callback function after every epoch with train loss + # Test loss will be computed less frequently (every 10 epochs) + current_train_loss = lavg / nsum + # TODO: add real time tracking of the loss if iepoch == 5 or iepoch % 10 == 0: @@ -533,15 +537,21 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavgt /= len(rperm) if return_loss_arrays: test_losses[iepoch] = lavgt - lavg /= nsum - # Call the callback function if provided - if loss_callback is not None: - loss_callback(iepoch, lavg, lavgt) - train_logger.info( - f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" + f"{iepoch}, train_loss={current_train_loss:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) + + # Call the callback function if provided (with test loss) + if loss_callback is not None: + loss_callback(iepoch, current_train_loss, lavgt) + else: + # For epochs without test evaluation, call callback with only train loss + if loss_callback is not None: + loss_callback(iepoch, current_train_loss, None) + + # Reset accumulators after logging + if iepoch == 5 or iepoch % 10 == 0: lavg, nsum = 0, 0 if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): From 5c635a657acf375ea30c7051bd7e25e41b5448df Mon Sep 17 00:00:00 2001 From: yuguerten Date: Wed, 17 Dec 2025 17:12:41 +0100 Subject: [PATCH 3/6] val loss every epoch --- cellpose/train.py | 100 ++++++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 53 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 75e6cc41..e47cec78 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -490,65 +490,59 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_losses[iepoch] /= nimg_per_epoch epoch_train_loss = (lavg / nsum) if not return_loss_arrays else (train_losses[iepoch] / nimg_per_epoch if iepoch == 0 else train_losses[iepoch]) - # Call the callback function after every epoch with train loss - # Test loss will be computed less frequently (every 10 epochs) + # Compute validation loss every epoch for real-time tracking current_train_loss = lavg / nsum + lavgt = 0. - # TODO: add real time tracking of the loss - + if test_data is not None or test_files is not None: + np.random.seed(42) + if nimg_test != nimg_test_per_epoch: + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + else: + rperm = np.random.permutation(np.arange(0, nimg_test)) + for ibatch in range(0, len(rperm), batch_size): + with torch.no_grad(): + net.eval() + inds = rperm[ibatch:ibatch + batch_size] + imgs, lbls = _get_batch(inds, data=test_data, + labels=test_labels, files=test_files, + labels_files=test_labels_files, + **kwargs) + diams = np.array([diam_test[i] for i in inds]) + rsc = diams / net.diam_mean.item() if rescale else np.ones( + len(diams), "float32") + imgi, lbl = random_rotate_and_resize( + imgs, Y=lbls, rescale=rsc, scale_range=scale_range, + xy=(bsize, bsize))[:2] + X = torch.from_numpy(imgi).to(device) + lbl = torch.from_numpy(lbl).to(device) + + if X.dtype != net.dtype: + X = X.to(net.dtype) + lbl = lbl.to(net.dtype) + + y = net(X)[0] + loss = _loss_fn_seg(lbl, y, device) + if y.shape[1] > 3: + loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) + loss += loss3 + test_loss = loss.item() + test_loss *= len(imgi) + lavgt += test_loss + lavgt /= len(rperm) + if return_loss_arrays: + test_losses[iepoch] = lavgt + + # Log every epoch (more frequent logging for real-time tracking) if iepoch == 5 or iepoch % 10 == 0: - lavgt = 0. - if test_data is not None or test_files is not None: - np.random.seed(42) - if nimg_test != nimg_test_per_epoch: - rperm = np.random.choice(np.arange(0, nimg_test), - size=(nimg_test_per_epoch,), p=test_probs) - else: - rperm = np.random.permutation(np.arange(0, nimg_test)) - for ibatch in range(0, len(rperm), batch_size): - with torch.no_grad(): - net.eval() - inds = rperm[ibatch:ibatch + batch_size] - imgs, lbls = _get_batch(inds, data=test_data, - labels=test_labels, files=test_files, - labels_files=test_labels_files, - **kwargs) - diams = np.array([diam_test[i] for i in inds]) - rsc = diams / net.diam_mean.item() if rescale else np.ones( - len(diams), "float32") - imgi, lbl = random_rotate_and_resize( - imgs, Y=lbls, rescale=rsc, scale_range=scale_range, - xy=(bsize, bsize))[:2] - X = torch.from_numpy(imgi).to(device) - lbl = torch.from_numpy(lbl).to(device) - - if X.dtype != net.dtype: - X = X.to(net.dtype) - lbl = lbl.to(net.dtype) - - y = net(X)[0] - loss = _loss_fn_seg(lbl, y, device) - if y.shape[1] > 3: - loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) - loss += loss3 - test_loss = loss.item() - test_loss *= len(imgi) - lavgt += test_loss - lavgt /= len(rperm) - if return_loss_arrays: - test_losses[iepoch] = lavgt - train_logger.info( f"{iepoch}, train_loss={current_train_loss:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) - - # Call the callback function if provided (with test loss) - if loss_callback is not None: - loss_callback(iepoch, current_train_loss, lavgt) - else: - # For epochs without test evaluation, call callback with only train loss - if loss_callback is not None: - loss_callback(iepoch, current_train_loss, None) + + # Call the callback function every epoch with both train and test loss + if loss_callback is not None: + loss_callback(iepoch, current_train_loss, lavgt if (test_data is not None or test_files is not None) else None) # Reset accumulators after logging if iepoch == 5 or iepoch % 10 == 0: From 006b385350818b0a0168393bf8e8a8cf988cc49c Mon Sep 17 00:00:00 2001 From: yuguerten Date: Thu, 18 Dec 2025 10:38:40 +0100 Subject: [PATCH 4/6] early stopping added to cellpose --- cellpose/train.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index e47cec78..3d5a4994 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -234,7 +234,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, files=train_files, device=device) if test_files is not None: for k in trange(nimg_test): - tl = dynamics.labels_to_flows(io.imread(test_labels_files), + tl = dynamics.labels_to_flows(io.imread(testLabels_files), files=test_files, device=device) ### compute diameters @@ -315,7 +315,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256, min_train_masks=5, model_name=None, class_weights=None, - loss_callback=None, return_loss_arrays=True): + loss_callback=None, return_loss_arrays=True, early_stopping_patience=None): """ Train the network with images for segmentation. @@ -349,10 +349,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, model_name (str, optional): String - name of the network. Defaults to None. loss_callback (callable, optional): Function called after each epoch with (epoch, train_loss, test_loss). Defaults to None. return_loss_arrays (bool, optional): Whether to return full loss arrays or just the model path. Defaults to True. - - Returns: - tuple or str: If return_loss_arrays=True, returns (path, train_losses, test_losses). If False, returns just the path to saved model weights. - + early_stopping_patience (int, optional): Number of epochs without validation loss improvement before stopping. + If None, no early stopping. Defaults to None. """ if SGD: train_logger.warning("SGD is deprecated, using AdamW instead") @@ -436,6 +434,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavg, nsum = 0, 0 train_losses, test_losses = (np.zeros(n_epochs), np.zeros(n_epochs)) if return_loss_arrays else (None, None) + + # Early stopping variables + best_val_loss = float('inf') + patience_counter = 0 + best_model_path = None + for iepoch in range(n_epochs): np.random.seed(iepoch) if nimg != nimg_per_epoch: @@ -533,6 +537,23 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavgt /= len(rperm) if return_loss_arrays: test_losses[iepoch] = lavgt + + # Early stopping logic + if early_stopping_patience is not None: + if lavgt < best_val_loss: + best_val_loss = lavgt + patience_counter = 0 + # Save best model + best_model_path = str(filename) + "_best" + train_logger.info(f"New best validation loss: {lavgt:.4f}, saving model to {best_model_path}") + net.save_model(best_model_path) + else: + patience_counter += 1 + train_logger.info(f"No improvement in validation loss for {patience_counter} epoch(s)") + + if patience_counter >= early_stopping_patience: + train_logger.info(f"Early stopping triggered after {iepoch + 1} epochs") + break # Log every epoch (more frequent logging for real-time tracking) if iepoch == 5 or iepoch % 10 == 0: @@ -556,7 +577,14 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, train_logger.info(f"saving network parameters to {filename0}") net.save_model(filename0) - net.save_model(filename) + # Save final model if not using early stopping or if we completed all epochs + if early_stopping_patience is None or patience_counter < early_stopping_patience: + net.save_model(filename) + + # If early stopping was used and we have a best model, use that + if best_model_path is not None: + train_logger.info(f"Training finished. Best model saved at: {best_model_path}") + filename = best_model_path if original_net_dtype is not None: net.dtype = original_net_dtype From d55f8b4e38808a1f7060497175f53fde4bbd0946 Mon Sep 17 00:00:00 2001 From: "d.klockenbring" Date: Thu, 18 Dec 2025 13:05:30 +0100 Subject: [PATCH 5/6] Hello world collaborative test --- hello_world.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 hello_world.py diff --git a/hello_world.py b/hello_world.py new file mode 100644 index 00000000..6d95fe97 --- /dev/null +++ b/hello_world.py @@ -0,0 +1 @@ +print("Hello world") \ No newline at end of file From a0f53696899074d483d95e8e50a4d2dd5922d1c3 Mon Sep 17 00:00:00 2001 From: "d.klockenbring" Date: Thu, 18 Dec 2025 13:07:52 +0100 Subject: [PATCH 6/6] remove hello_world --- hello_world.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 hello_world.py diff --git a/hello_world.py b/hello_world.py deleted file mode 100644 index 6d95fe97..00000000 --- a/hello_world.py +++ /dev/null @@ -1 +0,0 @@ -print("Hello world") \ No newline at end of file