Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 88 additions & 34 deletions src/aind_exaspim_image_compression/machine_learning/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def __init__(
patch_shape,
anisotropy=(0.748, 0.748, 1.0),
boundary_buffer=5000,
foreground_sampling_rate=0.3,
foreground_sampling_rate=0.5,
n_examples_per_epoch=300,
normalization_percentiles=(1, 99.9),
normalization_percentiles=(0.5, 99.9),
normalized_brightness_clip=8,
prefetch_foreground_sampling=16,
sigma_bm4d=16,
):
Expand All @@ -59,6 +60,7 @@ def __init__(
self.foreground_sampling_rate = foreground_sampling_rate
self.n_examples_per_epoch = n_examples_per_epoch
self.normalization_percentiles = normalization_percentiles
self.normalized_brightness_clip = normalized_brightness_clip
self.patch_shape = patch_shape
self.prefetch_foreground_sampling = prefetch_foreground_sampling
self.sigma_bm4d = sigma_bm4d
Expand Down Expand Up @@ -86,22 +88,20 @@ def ingest_brain(self, brain_id, img_path, segmentation_path, swc_pointer):
swc_pointer : str
Path to SWC files.
"""
self.segmentations[brain_id] = self._load_segmentation(segmentation_path)
self.imgs[brain_id] = img_util.read(img_path)
self.skeletons[brain_id] = self._load_swcs(swc_pointer)
self._load_segmentation(brain_id, segmentation_path)
self._load_swcs(brain_id, swc_pointer)

def _load_segmentation(self, segmentation_path):
def _load_segmentation(self, brain_id, segmentation_path):
"""
Reads a segmentation mask generated by Google Applied Sciences (GAS).

Parameters
----------
brain_id : str
Unique identifier for the brain corresponding to the given path.
segmentation_path : str
Path to segmentation.

Returns
-------
...
"""
if segmentation_path:
# Load image
Expand All @@ -126,11 +126,9 @@ def _load_segmentation(self, segmentation_path):
label_mask = label_mask[ts.d["channel"][0]]
label_mask = label_mask[ts.d[0].transpose[2]]
label_mask = label_mask[ts.d[0].transpose[1]]
return label_mask
else:
return None
self.segmentations[brain_id] = label_mask

def _load_swcs(self, swc_pointer):
def _load_swcs(self, brain_id, swc_pointer):
if swc_pointer:
# Initializations
swc_dicts = self.swc_reader.read(swc_pointer)
Expand All @@ -144,8 +142,7 @@ def _load_swcs(self, swc_pointer):
end = start + len(swc_dict["xyz"])
skeletons[start:end] = self.to_voxels(swc_dict["xyz"])
start = end
return skeletons
return None
self.skeletons[brain_id] = skeletons

# --- Sample Image Patches ---
def __getitem__(self, dummy_input):
Expand Down Expand Up @@ -177,8 +174,8 @@ def __getitem__(self, dummy_input):
denoised = bm4d(noise, self.sigma_bm4d)

# Normalize image patches
noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5)
denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5)
noise = self.normalize(noise, mn, mx)
denoised = self.normalize(denoised, mn, mx)
return noise, denoised, (mn, mx)

def sample_brain(self):
Expand Down Expand Up @@ -226,9 +223,9 @@ def sample_foreground_voxel(self, brain_id):
Tuple[int]
Voxel coordinate representing a likely foreground location.
"""
if self.skeletons[brain_id] is not None:
if brain_id in self.skeletons and np.random.random() > 0.5:
return self.sample_skeleton_voxel(brain_id)
elif self.segmentations[brain_id] is not None:
elif brain_id in self.segmentations:
return self.sample_segmentation_voxel(brain_id)
else:
return self.sample_bright_voxel(brain_id)
Expand Down Expand Up @@ -343,7 +340,7 @@ def sample_bright_voxel(self, brain_id):
best_voxel = self.sample_interior_voxel(brain_id)
cnt = 0
with ThreadPoolExecutor() as executor:
while best_brightness < 1600:
while best_brightness < 1000:
# Read random image patches
pending = dict()
for _ in range(self.prefetch_foreground_sampling):
Expand Down Expand Up @@ -380,6 +377,29 @@ def __len__(self):
"""
return self.n_examples_per_epoch

def normalize(self, img, mn, mx):
    """
    Rescales an image so that "mn" maps to 0 and "mx" maps to
    (approximately) 1, then clips the result to the interval
    [0, self.normalized_brightness_clip].

    Parameters
    ----------
    img : numpy.ndarray
        Image to be normalized.
    mn : float
        Intensity mapped to 0; callers pass the lower percentile value.
    mx : float
        Intensity mapped to roughly 1; callers pass the upper percentile
        value.

    Returns
    -------
    numpy.ndarray
        Normalized and clipped image.
    """
    # Small epsilon keeps the division finite when mx == mn
    scale = mx - mn + 1e-8
    rescaled = (img - mn) / scale
    return np.clip(rescaled, 0, self.normalized_brightness_clip)

def read_patch(self, brain_id, center):
"""
Reads an image patch from a Zarr array.
Expand Down Expand Up @@ -415,15 +435,8 @@ def read_precomputed_patch(self, brain_id, center):
numpy.ndarray
Image patch.
"""
try:
s = img_util.get_slices(center, self.patch_shape)
return self.segmentations[brain_id][s].read().result()
except Exception as e:
print("Exception:", e)
print("Brain ID:", brain_id)
print("img.shape:", self.imgs[brain_id].shape)
print("label_mask.shape:", self.segmentations[brain_id].shape)
return np.zeros(self.patch_shape)
s = img_util.get_slices(center, self.patch_shape)
return self.segmentations[brain_id][s].read().result()

def to_voxels(self, xyz_arr):
"""
Expand All @@ -449,7 +462,8 @@ class ValidateDataset(Dataset):
def __init__(
self,
patch_shape,
normalization_percentiles=(1, 99.9),
normalization_percentiles=(0.5, 99.9),
normalized_brightness_clip=8,
sigma_bm4d=16,
):
"""
Expand All @@ -461,7 +475,7 @@ def __init__(
Shape of image patches to be extracted.
normalization_percentiles : Tuple[float], optional
Lower and upper percentiles used to normalize the input image.
Default is (0.5, 99.9).
Default is (0.5, 99.9).
sigma_bm4d : float, optional
Smoothing parameter used in the BM4D denoising algorithm. Default
is 16.
Expand All @@ -471,6 +485,7 @@ def __init__(

# Instance attributes
self.normalization_percentiles = normalization_percentiles
self.normalized_brightness_clip = normalized_brightness_clip
self.patch_shape = patch_shape
self.sigma_bm4d = sigma_bm4d

Expand Down Expand Up @@ -498,7 +513,7 @@ def ingest_brain(self, brain_id, img_path):

Parameters
----------
brain_id : hashable
brain_id : str
Unique identifier for the brain corresponding to the image.
img_path : str or Path
Path to whole-brain image to be read.
Expand All @@ -523,8 +538,8 @@ def ingest_example(self, brain_id, voxel):
denoised = bm4d(noise, self.sigma_bm4d)

# Normalize image patches
noise = np.clip((noise - mn) / (mx - mn + 1e-8), 0, 5)
denoised = np.clip((denoised - mn) / (mx - mn + 1e-8), 0, 5)
noise = self.normalize(noise, mn, mx)
denoised = self.normalize(denoised, mn, mx)

# Store results
self.example_ids.append((brain_id, voxel))
Expand Down Expand Up @@ -553,7 +568,46 @@ def __getitem__(self, idx):
"""
return self.noise[idx], self.denoised[idx], self.mn_mxs[idx]

# --- Helpers ---
def normalize(self, img, mn, mx):
    """
    Applies a min-max style rescaling to the given image using the
    provided bounds and caps the brightness of the result.

    Parameters
    ----------
    img : numpy.ndarray
        Image to be normalized.
    mn : float
        Lower bound of the rescaling; callers pass the lower percentile
        value.
    mx : float
        Upper bound of the rescaling; callers pass the upper percentile
        value.

    Returns
    -------
    numpy.ndarray
        Rescaled image with values clipped to the range
        [0, self.normalized_brightness_clip].
    """
    # Epsilon in the denominator guards against mx == mn
    shifted = img - mn
    normalized = shifted / (mx - mn + 1e-8)
    upper = self.normalized_brightness_clip
    return np.clip(normalized, 0, upper)

def read_patch(self, brain_id, center):
    """
    Reads the image patch centered at the given voxel from the image
    stored for the given brain.

    Parameters
    ----------
    brain_id : str
        Unique identifier of the sampled brain.
    center : Tuple[int]
        Center of image patch to be read.

    Returns
    -------
    numpy.ndarray
        Image patch.
    """
    patch_slices = img_util.get_slices(center, self.patch_shape)
    # The two leading axes are fixed at index 0
    # NOTE(review): presumably time and channel axes -- confirm
    index = (0, 0) + tuple(patch_slices)
    img = self.imgs[brain_id]
    return img[index]

Expand Down
14 changes: 7 additions & 7 deletions src/aind_exaspim_image_compression/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class Trainer:
def __init__(
self,
output_dir,
batch_size=8,
device="cuda:0",
batch_size=16,
device="cuda",
lr=1e-3,
max_epochs=200,
max_epochs=400,
model=None,
use_amp=True,
):
Expand All @@ -46,13 +46,13 @@ def __init__(
output_dir : str
Directory that model checkpoints and tensorboard are written to.
batch_size : int, optional
Number of samples per batch during training. Default is 32.
Number of samples per batch during training. Default is 16.
device : str, optional
GPU device that model is trained on. Default is "cuda:0".
GPU device that model is trained on. Default is "cuda".
lr : float, optional
Learning rate. Default is 1e-3.
max_epochs : int, optional
Maximum number of training epochs. Default is 200.
Maximum number of training epochs. Default is 400.
model : None or nn.Module, optional
Model to be trained on the given datasets. Default is None.
use_amp : bool, optional
Expand Down Expand Up @@ -253,6 +253,6 @@ def save_model(self, epoch):
Current training epoch.
"""
date = datetime.today().strftime("%Y%m%d")
filename = f"BM4DNet-{date}-{epoch}-{self.best_l1:.4f}.pth"
filename = f"BM4DNet-{date}-{epoch}-{self.best_l1:.6f}.pth"
path = os.path.join(self.log_dir, filename)
torch.save(self.model.state_dict(), path)
20 changes: 8 additions & 12 deletions src/aind_exaspim_image_compression/machine_learning/unet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, width_multiplier=1, trilinear=True):
super(UNet, self).__init__()

# Initializations
_channels = (32, 64, 128, 256, 512)
_channels = (32, 64, 128, 256)
factor = 2 if trilinear else 1

# Instance attributes
Expand All @@ -62,14 +62,12 @@ def __init__(self, width_multiplier=1, trilinear=True):
self.inc = DoubleConv(1, self.channels[0])
self.down1 = Down(self.channels[0], self.channels[1])
self.down2 = Down(self.channels[1], self.channels[2])
self.down3 = Down(self.channels[2], self.channels[3])
self.down4 = Down(self.channels[3], self.channels[4] // factor)
self.down3 = Down(self.channels[2], self.channels[3] // factor)

# Expanding layers
self.up1 = Up(self.channels[4], self.channels[3] // factor, trilinear)
self.up2 = Up(self.channels[3], self.channels[2] // factor, trilinear)
self.up3 = Up(self.channels[2], self.channels[1] // factor, trilinear)
self.up4 = Up(self.channels[1], self.channels[0], trilinear)
self.up1 = Up(self.channels[3], self.channels[2] // factor, trilinear)
self.up2 = Up(self.channels[2], self.channels[1] // factor, trilinear)
self.up3 = Up(self.channels[1], self.channels[0], trilinear)
self.outc = OutConv(self.channels[0], 1)

def forward(self, x):
Expand All @@ -92,13 +90,11 @@ def forward(self, x):
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)

# Expanding layers
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.up1(x4, x3)
x = self.up2(x, x2)
x = self.up3(x, x1)
logits = self.outc(x)
return logits

Expand Down
Loading