diff --git a/src/aind_exaspim_image_compression/machine_learning/data_handling.py b/src/aind_exaspim_image_compression/machine_learning/data_handling.py index c521907..61a6492 100644 --- a/src/aind_exaspim_image_compression/machine_learning/data_handling.py +++ b/src/aind_exaspim_image_compression/machine_learning/data_handling.py @@ -45,9 +45,8 @@ def __init__( anisotropy=(0.748, 0.748, 1.0), boundary_buffer=5000, foreground_sampling_rate=0.3, - min_brightness=200, n_examples_per_epoch=300, - normalization_percentiles=(0.5, 99.9), + normalization_percentiles=(1, 99.9), prefetch_foreground_sampling=16, sigma_bm4d=16, ): @@ -58,7 +57,6 @@ def __init__( self.anisotropy = anisotropy self.boundary_buffer = boundary_buffer self.foreground_sampling_rate = foreground_sampling_rate - self.min_brightness = min_brightness self.n_examples_per_epoch = n_examples_per_epoch self.normalization_percentiles = normalization_percentiles self.patch_shape = patch_shape @@ -163,15 +161,13 @@ def __getitem__(self, dummy_input): Returns ------- - tuple - A tuple containing: - - noise : numpy.ndarray - Noisy image patch, normalized and clipped. - - denoised : numpy.ndarray - Denoised image patch, normalized and clipped using the same - scale as the noisy patch. - - (mn, mx) : Tuple[float] - Lower and upper percentiles used for normalization. + noise : numpy.ndarray + Noisy image patch, normalized and clipped. + denoised : numpy.ndarray + Denoised image patch, normalized and clipped using the same scale + as the noisy patch. + (mn, mx) : Tuple[float] + Lower and upper percentiles used for normalization. """ # Get image patches brain_id = self.sample_brain() @@ -181,8 +177,8 @@ def __getitem__(self, dummy_input): denoised = bm4d(noise, self.sigma_bm4d) # Normalize image patches - noise = np.clip((noise - mn) / max(mx - mn, 1), 0, 5) - denoised = np.clip((denoised - mn) / max(mx - mn, 1), 0, 5) + noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5) + denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5) return noise, denoised, (mn, mx) def sample_brain(self): @@ -230,7 +226,7 @@ def sample_foreground_voxel(self, brain_id): Tuple[int] Voxel coordinate representing a likely foreground location. """ - if self.skeletons[brain_id] is not None and np.random.random() > 0.5: + if self.skeletons[brain_id] is not None: return self.sample_skeleton_voxel(brain_id) elif self.segmentations[brain_id] is not None: return self.sample_segmentation_voxel(brain_id) @@ -294,11 +290,11 @@ def sample_segmentation_voxel(self, brain_id): Voxel coordinate whose patch contains a sufficiently large object or had the largest object after 5 * self.prefetch attempts. """ - cnt = 0 + best_volume = 0 best_voxel = self.sample_interior_voxel(brain_id) - max_volume = 0 - while max_volume < 3000: - with ThreadPoolExecutor() as executor: + cnt = 0 + with ThreadPoolExecutor() as executor: + while best_volume < 1600: # Read random image patches pending = dict() for _ in range(self.prefetch_foreground_sampling): @@ -318,14 +314,14 @@ def sample_segmentation_voxel(self, brain_id): if len(cnts) > 1: volume = np.max(cnts[1:]) - if volume > max_volume: + if volume > best_volume: best_voxel = voxel - max_volume = volume + best_volume = volume - # Check number of tries - cnt += 1 - if cnt > 5: - break + # Check number of tries + cnt += 1 + if cnt > 5: + break return best_voxel def sample_bright_voxel(self, brain_id): @@ -339,15 +335,15 @@ def sample_bright_voxel(self, brain_id): Returns ------- - brightest_voxel : Tuple[int] + best_voxel : Tuple[int] Voxel coordinate whose patch is sufficiently bright or is the - highest observed brightness after 5 * self.prefetch attempts. + highest observed brightness after 4 * self.prefetch attempts. """ + best_brightness = 0 + best_voxel = self.sample_interior_voxel(brain_id) cnt = 0 - brightest_voxel = self.sample_interior_voxel(brain_id) - max_brightness = 0 - while max_brightness < self.min_brightness: - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor() as executor: + while best_brightness < 1600: # Read random image patches pending = dict() for _ in range(self.prefetch_foreground_sampling): @@ -361,19 +357,16 @@ def sample_bright_voxel(self, brain_id): for thread in as_completed(pending.keys()): voxel = pending.pop(thread) img_patch = thread.result() - brightness = np.sum(img_patch > 500) - if brightness > 100: - brightest_voxel = voxel - max_brightness = brightness - - if max_brightness > self.min_brightness: - break - - # Check number of tries - cnt += 1 - if cnt > 5: - break - return brightest_voxel + brightness = np.sum(img_patch > 100) + if brightness > best_brightness: + best_voxel = voxel + best_brightness = brightness + + # Check number of tries + cnt += 1 + if cnt > 5: + break + return best_voxel # --- Helpers --- def __len__(self): @@ -422,8 +415,15 @@ def read_precomputed_patch(self, brain_id, center): numpy.ndarray Image patch. """ - s = img_util.get_slices(center, self.patch_shape) - return self.segmentations[brain_id][s].read().result() + try: + s = img_util.get_slices(center, self.patch_shape) + return self.segmentations[brain_id][s].read().result() + except Exception as e: + print("Exception:", e) + print("Brain ID:", brain_id) + print("img.shape:", self.imgs[brain_id].shape) + print("label_mask.shape:", self.segmentations[brain_id].shape) + return np.zeros(self.patch_shape) def to_voxels(self, xyz_arr): """ @@ -449,8 +449,8 @@ class ValidateDataset(Dataset): def __init__( self, patch_shape, - normalization_percentiles=[0.5, 99.9], - sigma_bm4d=10, + normalization_percentiles=(1, 99.9), + sigma_bm4d=16, ): """ Instantiates a ValidateDataset object. @@ -459,12 +459,12 @@ def __init__( ---------- patch_shape : Tuple[int] Shape of image patches to be extracted. - normalization_percentiles : List[float], optional + normalization_percentiles : Tuple[float], optional Upper and lower percentiles used to normalize the input image. - Default is [0.5, 99.9]. + Default is (0.5, 99.9). sigma_bm4d : float, optional Smoothing parameter used in the BM4D denoising algorithm. Default - is 10. + is 16. """ # Call parent class super(ValidateDataset, self).__init__() @@ -523,8 +523,8 @@ def ingest_example(self, brain_id, voxel): denoised = bm4d(noise, self.sigma_bm4d) # Normalize image patches - noise = np.clip((noise - mn) / max(mx - mn, 1), 0, 5) - denoised = np.clip((denoised - mn) / max(mx - mn, 1), 0, 5) + noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5) + denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5) # Store results self.example_ids.append((brain_id, voxel)) diff --git a/src/aind_exaspim_image_compression/machine_learning/unet3d.py b/src/aind_exaspim_image_compression/machine_learning/unet3d.py index 185633c..9fdeb2a 100644 --- a/src/aind_exaspim_image_compression/machine_learning/unet3d.py +++ b/src/aind_exaspim_image_compression/machine_learning/unet3d.py @@ -1,4 +1,3 @@ - """ Created on Fri Aug 14 15:00:00 2025 @@ -193,7 +192,14 @@ def __init__(self, in_channels, out_channels): # Instance attributes self.maxpool_conv = nn.Sequential( - nn.MaxPool3d(2), DoubleConv(in_channels, out_channels) + nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1 + ), + DoubleConv(out_channels, out_channels) ) def forward(self, x): @@ -273,7 +279,7 @@ def forward(self, x1, x2): Returns ------- - torch.Tensor + x : torch.Tensor Output tensor after upsampling, concatenation with the skip connection, and double convolution. The output shape is (B, out_channels, D, H2, W2). @@ -287,7 +293,8 @@ def forward(self, x1, x2): [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2], ) x = torch.cat([x2, x1], dim=1) - return self.conv(x) + x = self.conv(x) + return x class OutConv(nn.Module):