diff --git a/src/aind_exaspim_image_compression/inference.py b/src/aind_exaspim_image_compression/inference.py index 8efcf99..9c3203a 100644 --- a/src/aind_exaspim_image_compression/inference.py +++ b/src/aind_exaspim_image_compression/inference.py @@ -18,13 +18,11 @@ import torch from aind_exaspim_image_compression.machine_learning.unet3d import UNet -from aind_exaspim_image_compression.utils import img_util def predict( img, model, - denoised=None, batch_size=32, normalization_percentiles=(0.5, 99.9), patch_size=64, @@ -64,7 +62,7 @@ def predict( """ # Preprocess image mn, mx = np.percentile(img, normalization_percentiles) - img = (img - mn) / (mx - mn + 1e-5) + img = (img - mn) / (mx - mn + 1e-8) img = np.clip(img, 0, 5) while len(img.shape) < 5: img = img[np.newaxis, ...] @@ -72,51 +70,41 @@ def predict( # Initializations patch_starts_generator = generate_patch_starts(img, patch_size, overlap) n_starts = count_patches(img, patch_size, overlap) - if denoised is None: - denoised = np.zeros_like(img) + pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None # Main - pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None - for i in range(0, n_starts, batch_size): + accum_pred = np.zeros(img.shape[2:]) + accum_wgt = np.zeros(img.shape[2:]) + for _ in range(0, n_starts, batch_size): # Extract batch and run model starts = list(itertools.islice(patch_starts_generator, batch_size)) patches = _predict_batch(img, model, starts, patch_size, trim=trim) - # Store result + # Add batch predictions to result for patch, start in zip(patches, starts): - start = [max(s + trim, 0) for s in start] - end = [start[i] + patch.shape[i] for i in range(3)] - end = [min(e, s) for e, s in zip(end, img.shape[2:])] - denoised[ - 0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2] - ] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]] + # Compute start and end coordinates + s = [max(si + trim, 0) for si in start] + e = [ + min(si + pi, di) + for si, 
pi, di in zip(s, patch.shape, img.shape[2:]) + ] + + # Create slices + pred_slices = tuple(slice(si, ei) for si, ei in zip(s, e)) + patch_slices = tuple(slice(0, ei - si) for si, ei in zip(s, e)) + + # Add patch prediction to result + accum_pred[pred_slices] += patch[patch_slices] + accum_wgt[pred_slices] += 1 + pbar.update(len(starts)) if verbose else None - # Postprocess image + # Postprocess prediction + denoised = accum_pred[:, ...] / (accum_wgt + 1e-8) denoised = np.clip(denoised * (mx - mn) + mn, 0, 2**16 - 1) return denoised.astype(np.uint16) -def predict_largescale( - img, - model, - output_path, - compressor, - batch_size=32, - normalization_percentiles=(0.5, 99.9), - patch_size=64, - overlap=12, - output_chunks=(1, 1, 64, 128, 128), - trim=5, - verbose=True -): - # Initializations - denoised = img_util.init_ome_zarr( - img, output_path, compressor=compressor, chunks=output_chunks - ) - predict(img, model, denoised=denoised) - - def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)): """ Denoises a single 3D patch using the provided model. @@ -133,15 +121,15 @@ def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)): Returns ------- - numpy.ndarray + pred : numpy.ndarray Denoised 3D patch with the same shape as input patch. """ # Preprocess image mn, mx = np.percentile(patch, normalization_percentiles) - patch = (patch - mn) / (mx - mn + 1e-5) + patch = (patch - mn) / (mx - mn + 1e-8) patch = np.clip(patch, 0, 5) - while len(img.shape) < 5: - img = img[np.newaxis, ...] + while len(patch.shape) < 5: + patch = patch[np.newaxis, ...] # Run model patch = to_tensor(patch) @@ -269,7 +257,7 @@ def load_model(path, device="cuda"): Returns ------- - torch.nn.Module + model : torch.nn.Module UNet model loaded with weights and set to evaluation mode. 
""" model = UNet() diff --git a/src/aind_exaspim_image_compression/machine_learning/data_handling.py b/src/aind_exaspim_image_compression/machine_learning/data_handling.py index c156177..c521907 100644 --- a/src/aind_exaspim_image_compression/machine_learning/data_handling.py +++ b/src/aind_exaspim_image_compression/machine_learning/data_handling.py @@ -44,12 +44,12 @@ def __init__( patch_shape, anisotropy=(0.748, 0.748, 1.0), boundary_buffer=5000, - foreground_sampling_rate=0.2, + foreground_sampling_rate=0.3, min_brightness=200, n_examples_per_epoch=300, normalization_percentiles=(0.5, 99.9), - prefetch_foreground_sampling=12, - sigma_bm4d=10, + prefetch_foreground_sampling=16, + sigma_bm4d=16, ): # Call parent class super(TrainDataset, self).__init__() @@ -290,9 +290,9 @@ def sample_segmentation_voxel(self, brain_id): Returns ------- - Tuple[int] + best_voxel : Tuple[int] Voxel coordinate whose patch contains a sufficiently large object - or had the largest object after 32 attempts. + or had the largest object after 5 * self.prefetch attempts. """ cnt = 0 best_voxel = self.sample_interior_voxel(brain_id) @@ -330,8 +330,7 @@ def sample_segmentation_voxel(self, brain_id): def sample_bright_voxel(self, brain_id): """ - Samples a voxel coordinate whose surrounding image patch is - sufficiently bright. + Samples a voxel coordinate whose image patch is sufficiently bright. Parameters ---------- @@ -340,9 +339,9 @@ def sample_bright_voxel(self, brain_id): Returns ------- - Tuple[int] + brightest_voxel : Tuple[int] Voxel coordinate whose patch is sufficiently bright or is the - highest observed brightness after 32 attempts. + highest observed brightness after 5 * self.prefetch attempts. """ cnt = 0 brightest_voxel = self.sample_interior_voxel(brain_id) @@ -544,12 +543,13 @@ def __getitem__(self, idx): Returns ------- - tuple - A tuple containing: - - noise (ndarray): Noisy image patch at the given index. - - denoised (ndarray): Corresponding denoised image patch. 
- mn_mx (tuple): Minimum and maximum values used for normalization - of the image patches. + noise : numpy.ndarray + Noisy image patch at the given index. + denoised : numpy.ndarray + Corresponding denoised image patch. + mn_mx : Tuple[int] + Minimum and maximum values used for normalization of the image + patches. """ return self.noise[idx], self.denoised[idx], self.mn_mxs[idx] @@ -563,6 +563,15 @@ class DataLoader: """ DataLoader that uses multithreading to fetch image patches from the cloud to form batches. + + Attributes + ---------- + dataset : torch.utils.data.Dataset + Dataset to be iterated over. + batch_size : int + Number of examples in each batch. + patch_shape : Tuple[int] + Shape of image patch expected by the model. """ def __init__(self, dataset, batch_size=16): @@ -629,7 +638,7 @@ def init_datasets( n_train_examples_per_epoch=100, n_validate_examples=0, segmentation_prefixes_path=None, - sigma_bm4d=10, + sigma_bm4d=16, swc_pointers=None ): # Initializations diff --git a/src/aind_exaspim_image_compression/machine_learning/train.py b/src/aind_exaspim_image_compression/machine_learning/train.py index eebf601..d97ffd8 100644 --- a/src/aind_exaspim_image_compression/machine_learning/train.py +++ b/src/aind_exaspim_image_compression/machine_learning/train.py @@ -67,7 +67,7 @@ def __init__( self.batch_size = batch_size self.device = device self.max_epochs = max_epochs - self.log_dir = log_dir + self.log_dir = log_dir self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE) self.criterion = nn.L1Loss() @@ -162,11 +162,12 @@ def validate_step(self, val_dataloader, epoch): Returns ------- - tuple - A tuple containing the following: - - float: Average loss over the validation dataset. - - float: Average compression ratio over the validation dataset. - - bool: Indication of whether the model is the best so far. + loss : float + Average loss over the validation dataset. + cratio : float + Average compression ratio over the validation dataset. 
+ is_best : bool + Indication of whether the model is the best so far. """ losses = list() cratios = list() @@ -186,12 +187,11 @@ def validate_step(self, val_dataloader, epoch): self.writer.add_scalar("val_cratio", cratio, epoch) # Check if current model is best so far - if loss < self.best_l1: + is_best = True if loss < self.best_l1 else False + if is_best: self.best_l1 = loss self.save_model(epoch) - return loss, cratio, True - else: - return loss, cratio, False + return loss, cratio, is_best def forward_pass(self, x, y): """ @@ -224,7 +224,7 @@ def compute_cratios(self, imgs, mn_mx): imgs = np.array(imgs.detach().cpu()) for i in range(imgs.shape[0]): mn, mx = tuple(mn_mx[i, :]) - img = imgs[i, 0, ...] * (mx - mn) + mn + img = np.clip(imgs[i, 0, ...] * (mx - mn) + mn, 0, 2**16 - 1) cratios.append(img_util.compute_cratio(img, self.codec)) if i < 10: tifffile.imwrite(f"{i}.tiff", img) diff --git a/src/aind_exaspim_image_compression/utils/img_util.py b/src/aind_exaspim_image_compression/utils/img_util.py index 01ae8f4..adaaadf 100644 --- a/src/aind_exaspim_image_compression/utils/img_util.py +++ b/src/aind_exaspim_image_compression/utils/img_util.py @@ -42,7 +42,7 @@ def read(img_path): Returns ------- - ArrayLike + img : ArrayLike Image volume. """ # Read image @@ -125,7 +125,7 @@ def _read_tiff(img_path, storage_options=None): Returns ------- - np.ndarray + numpy.ndarray Image volume. """ if _is_gcs_path(img_path): @@ -143,7 +143,7 @@ def _is_gcs_path(path): Parameters ---------- path : str - Path to an object. + Path to be checked. Returns ------- @@ -160,7 +160,7 @@ def _is_s3_path(path): Parameters ---------- path : str - Path to an object. + Path to be checked. Returns ------- @@ -185,7 +185,7 @@ def get_patch(img, voxel, shape, is_center=True): shape : Tuple[int] Shape of the image patch to extract. 
is_center : bool, optional - Indicates whether the given voxel is the center or top, left, front + Indicates whether the given voxel is the center or front-top-left corner of the patch to be extracted. Returns @@ -194,6 +194,7 @@ def get_patch(img, voxel, shape, is_center=True): Patch extracted from the given image. """ # Get patch coordinates + assert len(img.shape) == 5, "Error: Image must have shape TxCxDxHxW!" start, end = get_start_end(voxel, shape, is_center=is_center) valid_start = any([s >= 0 for s in start]) valid_end = any([e < img.shape[i + 2] for i, e in enumerate(end)]) @@ -207,43 +208,6 @@ def get_patch(img, voxel, shape, is_center=True): return np.ones(shape) -def calculate_offsets(img, window_shape, overlap): - """ - Generates a list of 3D coordinates representing the front-top-left corner - by sliding a window over a 3D image, given a specified window size and - overlap between adjacent windows. - - Parameters - ---------- - img : zarr.core.Array - Input 3D image. - window_shape : Tuple[int] - Shape of the sliding window. - overlap : Tuple[int] - Overlap between adjacent sliding windows. - - Returns - ------- - List[Tuple[int]] - 3D voxel coordinates that represent the front-top-left corner. - """ - # Calculate stride based on the overlap and window size - stride = tuple(w - o for w, o in zip(window_shape, overlap)) - i_stride, j_stride, k_stride = stride - - # Get dimensions of the window - _, _, i_dim, j_dim, k_dim = img.shape - i_win, j_win, k_win = window_shape - - # Loop over the with the sliding window - coords = [] - for i in range(0, i_dim - i_win + 1, i_stride): - for j in range(0, j_dim - j_win + 1, j_stride): - for k in range(0, k_dim - k_win + 1, k_stride): - coords.append((i, j, k)) - return coords - - def get_start_end(voxel, shape, is_center=True): """ Gets the start and end indices of the image patch to be read. @@ -278,9 +242,9 @@ def to_physical(voxel, anisotropy): ---------- voxel : ArrayLike Voxel coordinate to be converted. 
- multiscale - Level in the image pyramid that the voxel coordinate must index into. - + anisotropy : Tuple[float] + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. Returns ------- @@ -299,8 +263,9 @@ def to_voxels(xyz, anisotropy): ---------- xyz : ArrayLike Physical coordinate to be converted to a voxel coordinate. - multiscale : int - Level in the image pyramid that the voxel coordinate must index into. + anisotropy : Tuple[float] + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. Returns ------- @@ -488,10 +453,11 @@ def plot_mips(img, output_path=None, vmax=None): ---------- img : numpy.ndarray Input image to generate MIPs from. - - Returns - ------- - None + output_path : None or str, optional + Path that plot is saved to if provided. Default is None. + vmax : None or float, optional + Brightness intensity used as upper limit of the colormap. Default is + None. """ vmax = vmax or np.percentile(img, 99.9) fig, axs = plt.subplots(1, 3, figsize=(10, 4)) @@ -522,6 +488,11 @@ def plot_slices(img, output_path=None, vmax=None): ---------- img : numpy.ndarray Image to generate MIPs from. + output_path : None or str, optional + Path that plot is saved to if provided. Default is None. + vmax : None or float, optional + Brightness intensity used as upper limit of the colormap. Default is + None. """ # Get middle slice shape = img.shape[2:] if len(img.shape) == 5 else img.shape