Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 28 additions & 40 deletions src/aind_exaspim_image_compression/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
import torch

from aind_exaspim_image_compression.machine_learning.unet3d import UNet
from aind_exaspim_image_compression.utils import img_util


def predict(
img,
model,
denoised=None,
batch_size=32,
normalization_percentiles=(0.5, 99.9),
patch_size=64,
Expand Down Expand Up @@ -64,59 +62,49 @@ def predict(
"""
# Preprocess image
mn, mx = np.percentile(img, normalization_percentiles)
img = (img - mn) / (mx - mn + 1e-5)
img = (img - mn) / (mx - mn + 1e-8)
img = np.clip(img, 0, 5)
while len(img.shape) < 5:
img = img[np.newaxis, ...]

# Initializations
patch_starts_generator = generate_patch_starts(img, patch_size, overlap)
n_starts = count_patches(img, patch_size, overlap)
if denoised is None:
denoised = np.zeros_like(img)
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None

# Main
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None
for i in range(0, n_starts, batch_size):
accum_pred = np.zeros(img.shape[2:])
accum_wgt = np.zeros(img.shape[2:])
for _ in range(0, n_starts, batch_size):
# Extract batch and run model
starts = list(itertools.islice(patch_starts_generator, batch_size))
patches = _predict_batch(img, model, starts, patch_size, trim=trim)

# Store result
# Add batch predictions to result
for patch, start in zip(patches, starts):
start = [max(s + trim, 0) for s in start]
end = [start[i] + patch.shape[i] for i in range(3)]
end = [min(e, s) for e, s in zip(end, img.shape[2:])]
denoised[
0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]
] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]]
# Compute start and end coordinates
s = [max(si + trim, 0) for si in start]
e = [
min(si + pi, di)
for si, pi, di in zip(s, patch.shape, img.shape[2:])
]

# Create slices
pred_slices = tuple(slice(si, ei) for si, ei in zip(s, e))
patch_slices = tuple(slice(0, ei - si) for si, ei in zip(s, e))

# Add patch prediction to result
accum_pred[pred_slices] += patch[patch_slices]
accum_wgt[pred_slices] += 1

pbar.update(len(starts)) if verbose else None

# Postprocess image
# Postprocess prediction
denoised = accum_pred[:, ...] / (accum_wgt + 1e-8)
denoised = np.clip(denoised * (mx - mn) + mn, 0, 2**16 - 1)
return denoised.astype(np.uint16)


def predict_largescale(
img,
model,
output_path,
compressor,
batch_size=32,
normalization_percentiles=(0.5, 99.9),
patch_size=64,
overlap=12,
output_chunks=(1, 1, 64, 128, 128),
trim=5,
verbose=True
):
# Initializations
denoised = img_util.init_ome_zarr(
img, output_path, compressor=compressor, chunks=output_chunks
)
predict(img, model, denoised=denoised)


def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)):
"""
Denoises a single 3D patch using the provided model.
Expand All @@ -133,15 +121,15 @@ def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)):

Returns
-------
numpy.ndarray
pred : numpy.ndarray
Denoised 3D patch with the same shape as input patch.
"""
# Preprocess image
mn, mx = np.percentile(patch, normalization_percentiles)
patch = (patch - mn) / (mx - mn + 1e-5)
patch = (patch - mn) / (mx - mn + 1e-8)
patch = np.clip(patch, 0, 5)
while len(img.shape) < 5:
img = img[np.newaxis, ...]
while len(patch.shape) < 5:
patch = patch[np.newaxis, ...]

# Run model
patch = to_tensor(patch)
Expand Down Expand Up @@ -269,7 +257,7 @@ def load_model(path, device="cuda"):

Returns
-------
torch.nn.Module
model : torch.nn.Module
UNet model loaded with weights and set to evaluation mode.
"""
model = UNet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def __init__(
patch_shape,
anisotropy=(0.748, 0.748, 1.0),
boundary_buffer=5000,
foreground_sampling_rate=0.2,
foreground_sampling_rate=0.3,
min_brightness=200,
n_examples_per_epoch=300,
normalization_percentiles=(0.5, 99.9),
prefetch_foreground_sampling=12,
sigma_bm4d=10,
prefetch_foreground_sampling=16,
sigma_bm4d=16,
):
# Call parent class
super(TrainDataset, self).__init__()
Expand Down Expand Up @@ -290,9 +290,9 @@ def sample_segmentation_voxel(self, brain_id):

Returns
-------
Tuple[int]
best_voxel : Tuple[int]
Voxel coordinate whose patch contains a sufficiently large object
or had the largest object after 32 attempts.
or had the largest object after 5 * self.prefetch attempts.
"""
cnt = 0
best_voxel = self.sample_interior_voxel(brain_id)
Expand Down Expand Up @@ -330,8 +330,7 @@ def sample_segmentation_voxel(self, brain_id):

def sample_bright_voxel(self, brain_id):
"""
Samples a voxel coordinate whose surrounding image patch is
sufficiently bright.
Samples a voxel coordinate whose image patch is sufficiently bright.

Parameters
----------
Expand All @@ -340,9 +339,9 @@ def sample_bright_voxel(self, brain_id):

Returns
-------
Tuple[int]
brightest_voxel : Tuple[int]
Voxel coordinate whose patch is sufficiently bright or is the
highest observed brightness after 32 attempts.
highest observed brightness after 5 * self.prefetch attempts.
"""
cnt = 0
brightest_voxel = self.sample_interior_voxel(brain_id)
Expand Down Expand Up @@ -544,12 +543,13 @@ def __getitem__(self, idx):

Returns
-------
tuple
A tuple containing:
- noise (ndarray): Noisy image patch at the given index.
- denoised (ndarray): Corresponding denoised image patch.
- mn_mx (tuple): Minimum and maximum values used for normalization
of the image patches.
noise : numpy.ndarray
Noisy image patch at the given index.
denoised : numpy.ndarray
Corresponding denoised image patch.
mn_mx : Tuple[int]
Minimum and maximum values used for normalization of the image
patches.
"""
return self.noise[idx], self.denoised[idx], self.mn_mxs[idx]

Expand All @@ -563,6 +563,15 @@ class DataLoader:
"""
DataLoader that uses multithreading to fetch image patches from the cloud
to form batches.

Attributes
----------
dataset : torch.utils.data.Dataset
Dataset to iterated over.
batch_size : int
Number of examples in each batch.
patch_shape : Tuple[int]
Shape of image patch expected by the model.
"""

def __init__(self, dataset, batch_size=16):
Expand Down Expand Up @@ -629,7 +638,7 @@ def init_datasets(
n_train_examples_per_epoch=100,
n_validate_examples=0,
segmentation_prefixes_path=None,
sigma_bm4d=10,
sigma_bm4d=16,
swc_pointers=None
):
# Initializations
Expand Down
22 changes: 11 additions & 11 deletions src/aind_exaspim_image_compression/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.batch_size = batch_size
self.device = device
self.max_epochs = max_epochs
self.log_dir = log_dir
self.log_dir = log_dir

self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE)
self.criterion = nn.L1Loss()
Expand Down Expand Up @@ -162,11 +162,12 @@ def validate_step(self, val_dataloader, epoch):

Returns
-------
tuple
A tuple containing the following:
- float: Average loss over the validation dataset.
- float: Average compression ratio over the validation dataset.
- bool: Indication of whether the model is the best so far.
loss : float
Average loss over the validation dataset.
cratio : float
Average compression ratio over the validation dataset.
is_best : bool
Indication of whether the model is the best so far.
"""
losses = list()
cratios = list()
Expand All @@ -186,12 +187,11 @@ def validate_step(self, val_dataloader, epoch):
self.writer.add_scalar("val_cratio", cratio, epoch)

# Check if current model is best so far
if loss < self.best_l1:
is_best = True if loss < self.best_l1 else False
if is_best:
self.best_l1 = loss
self.save_model(epoch)
return loss, cratio, True
else:
return loss, cratio, False
return loss, cratio, is_best

def forward_pass(self, x, y):
"""
Expand Down Expand Up @@ -224,7 +224,7 @@ def compute_cratios(self, imgs, mn_mx):
imgs = np.array(imgs.detach().cpu())
for i in range(imgs.shape[0]):
mn, mx = tuple(mn_mx[i, :])
img = imgs[i, 0, ...] * (mx - mn) + mn
img = np.clip(imgs[i, 0, ...] * (mx - mn) + mn, 0, 2**16 - 1)
cratios.append(img_util.compute_cratio(img, self.codec))
if i < 10:
tifffile.imwrite(f"{i}.tiff", img)
Expand Down
Loading
Loading