Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 53 additions & 53 deletions src/aind_exaspim_image_compression/machine_learning/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def __init__(
anisotropy=(0.748, 0.748, 1.0),
boundary_buffer=5000,
foreground_sampling_rate=0.3,
min_brightness=200,
n_examples_per_epoch=300,
normalization_percentiles=(0.5, 99.9),
normalization_percentiles=(1, 99.9),
prefetch_foreground_sampling=16,
sigma_bm4d=16,
):
Expand All @@ -58,7 +57,6 @@ def __init__(
self.anisotropy = anisotropy
self.boundary_buffer = boundary_buffer
self.foreground_sampling_rate = foreground_sampling_rate
self.min_brightness = min_brightness
self.n_examples_per_epoch = n_examples_per_epoch
self.normalization_percentiles = normalization_percentiles
self.patch_shape = patch_shape
Expand Down Expand Up @@ -163,15 +161,13 @@ def __getitem__(self, dummy_input):

Returns
-------
tuple
A tuple containing:
- noise : numpy.ndarray
Noisy image patch, normalized and clipped.
- denoised : numpy.ndarray
Denoised image patch, normalized and clipped using the same
scale as the noisy patch.
- (mn, mx) : Tuple[float]
Lower and upper percentiles used for normalization.
noise : numpy.ndarray
Noisy image patch, normalized and clipped.
denoised : numpy.ndarray
Denoised image patch, normalized and clipped using the same scale
as the noisy patch.
(mn, mx) : Tuple[float]
Lower and upper percentiles used for normalization.
"""
# Get image patches
brain_id = self.sample_brain()
Expand All @@ -181,8 +177,8 @@ def __getitem__(self, dummy_input):
denoised = bm4d(noise, self.sigma_bm4d)

# Normalize image patches
noise = np.clip((noise - mn) / max(mx - mn, 1), 0, 5)
denoised = np.clip((denoised - mn) / max(mx - mn, 1), 0, 5)
noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5)
denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5)
return noise, denoised, (mn, mx)

def sample_brain(self):
Expand Down Expand Up @@ -230,7 +226,7 @@ def sample_foreground_voxel(self, brain_id):
Tuple[int]
Voxel coordinate representing a likely foreground location.
"""
if self.skeletons[brain_id] is not None and np.random.random() > 0.5:
if self.skeletons[brain_id] is not None:
return self.sample_skeleton_voxel(brain_id)
elif self.segmentations[brain_id] is not None:
return self.sample_segmentation_voxel(brain_id)
Expand Down Expand Up @@ -294,11 +290,11 @@ def sample_segmentation_voxel(self, brain_id):
Voxel coordinate whose patch contains a sufficiently large object
or had the largest object after 5 * self.prefetch attempts.
"""
cnt = 0
best_volume = 0
best_voxel = self.sample_interior_voxel(brain_id)
max_volume = 0
while max_volume < 3000:
with ThreadPoolExecutor() as executor:
cnt = 0
with ThreadPoolExecutor() as executor:
while best_volume < 1600:
# Read random image patches
pending = dict()
for _ in range(self.prefetch_foreground_sampling):
Expand All @@ -318,14 +314,14 @@ def sample_segmentation_voxel(self, brain_id):

if len(cnts) > 1:
volume = np.max(cnts[1:])
if volume > max_volume:
if volume > best_volume:
best_voxel = voxel
max_volume = volume
best_volume = volume

# Check number of tries
cnt += 1
if cnt > 5:
break
# Check number of tries
cnt += 1
if cnt > 5:
break
return best_voxel

def sample_bright_voxel(self, brain_id):
Expand All @@ -339,15 +335,15 @@ def sample_bright_voxel(self, brain_id):

Returns
-------
brightest_voxel : Tuple[int]
best_voxel : Tuple[int]
Voxel coordinate whose patch is sufficiently bright or is the
highest observed brightness after 5 * self.prefetch attempts.
highest observed brightness after 4 * self.prefetch attempts.
"""
best_brightness = 0
best_voxel = self.sample_interior_voxel(brain_id)
cnt = 0
brightest_voxel = self.sample_interior_voxel(brain_id)
max_brightness = 0
while max_brightness < self.min_brightness:
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor() as executor:
while best_brightness < 1600:
# Read random image patches
pending = dict()
for _ in range(self.prefetch_foreground_sampling):
Expand All @@ -361,19 +357,16 @@ def sample_bright_voxel(self, brain_id):
for thread in as_completed(pending.keys()):
voxel = pending.pop(thread)
img_patch = thread.result()
brightness = np.sum(img_patch > 500)
if brightness > 100:
brightest_voxel = voxel
max_brightness = brightness

if max_brightness > self.min_brightness:
break

# Check number of tries
cnt += 1
if cnt > 5:
break
return brightest_voxel
brightness = np.sum(img_patch > 100)
if brightness > best_brightness:
best_voxel = voxel
best_brightness = brightness

# Check number of tries
cnt += 1
if cnt > 5:
break
return best_voxel

# --- Helpers ---
def __len__(self):
Expand Down Expand Up @@ -422,8 +415,15 @@ def read_precomputed_patch(self, brain_id, center):
numpy.ndarray
Image patch.
"""
s = img_util.get_slices(center, self.patch_shape)
return self.segmentations[brain_id][s].read().result()
try:
s = img_util.get_slices(center, self.patch_shape)
return self.segmentations[brain_id][s].read().result()
except Exception as e:
print("Exception:", e)
print("Brain ID:", brain_id)
print("img.shape:", self.imgs[brain_id].shape)
print("label_mask.shape:", self.segmentations[brain_id].shape)
return np.zeros(self.patch_shape)

def to_voxels(self, xyz_arr):
"""
Expand All @@ -449,8 +449,8 @@ class ValidateDataset(Dataset):
def __init__(
self,
patch_shape,
normalization_percentiles=[0.5, 99.9],
sigma_bm4d=10,
normalization_percentiles=(1, 99.9),
sigma_bm4d=16,
):
"""
Instantiates a ValidateDataset object.
Expand All @@ -459,12 +459,12 @@ def __init__(
----------
patch_shape : Tuple[int]
Shape of image patches to be extracted.
normalization_percentiles : List[float], optional
normalization_percentiles : Tuple[float], optional
Upper and lower percentiles used to normalize the input image.
Default is [0.5, 99.9].
Default is (0.5, 99.9).
sigma_bm4d : float, optional
Smoothing parameter used in the BM4D denoising algorithm. Default
is 10.
is 16.
"""
# Call parent class
super(ValidateDataset, self).__init__()
Expand Down Expand Up @@ -523,8 +523,8 @@ def ingest_example(self, brain_id, voxel):
denoised = bm4d(noise, self.sigma_bm4d)

# Normalize image patches
noise = np.clip((noise - mn) / max(mx - mn, 1), 0, 5)
denoised = np.clip((denoised - mn) / max(mx - mn, 1), 0, 5)
noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5)
denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5)

# Store results
self.example_ids.append((brain_id, voxel))
Expand Down
15 changes: 11 additions & 4 deletions src/aind_exaspim_image_compression/machine_learning/unet3d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Created on Fri Aug 14 15:00:00 2025

Expand Down Expand Up @@ -193,7 +192,14 @@ def __init__(self, in_channels, out_channels):

# Instance attributes
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
nn.Conv3d(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1
),
DoubleConv(out_channels, out_channels)
)

def forward(self, x):
Expand Down Expand Up @@ -273,7 +279,7 @@ def forward(self, x1, x2):

Returns
-------
torch.Tensor
x : torch.Tensor
Output tensor after upsampling, concatenation with the skip
connection, and double convolution. The output shape is
(B, out_channels, D, H2, W2).
Expand All @@ -287,7 +293,8 @@ def forward(self, x1, x2):
[diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2],
)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
x = self.conv(x)
return x


class OutConv(nn.Module):
Expand Down
Loading