Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/aind_exaspim_image_compression/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

"""

from contextlib import nullcontext
from datetime import datetime
from numcodecs import blosc
from torch.optim.lr_scheduler import CosineAnnealingLR
Expand All @@ -34,7 +35,8 @@ def __init__(
device="cuda:0",
lr=1e-3,
max_epochs=200,
model=None
model=None,
use_amp=True,
):
"""
Instantiates a Trainer object.
Expand All @@ -53,6 +55,8 @@ def __init__(
Maximum number of training epochs. Default is 200.
model : None or nn.Module, optional
Model to be trained on the given datasets. Default is None.
use_amp : bool, optional
Whether to run the forward pass under mixed precision (torch.autocast with float16 on CUDA). Default is True.
"""
# Initializations
exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M")
Expand All @@ -63,18 +67,19 @@ def __init__(
self.batch_size = batch_size
self.device = device
self.max_epochs = max_epochs
self.log_dir = log_dir
self.log_dir = log_dir

self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE)
self.criterion = nn.L1Loss()
self.model = model.to(device) if model else UNet().to(device)
self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=25)
self.writer = SummaryWriter(log_dir=log_dir)

if model is None:
self.model = UNet().to("cuda")
if use_amp:
self.autocast = torch.autocast(device_type="cuda", dtype=torch.float16)
else:
self.model = model
self.autocast = nullcontext()

# --- Core Routines ---
def run(self, train_dataset, val_dataset):
Expand Down Expand Up @@ -176,7 +181,7 @@ def validate_step(self, val_dataloader, epoch):
losses.append(loss.detach().cpu())

# Log results
loss, cratio = np.mean(losses), np.mean(cratios)
loss, cratio = np.mean(losses), np.median(cratios)
self.writer.add_scalar("val_loss", loss, epoch)
self.writer.add_scalar("val_cratio", cratio, epoch)

Expand Down Expand Up @@ -206,11 +211,12 @@ def forward_pass(self, x, y):
loss : torch.Tensor
Computed loss value.
"""
x = x.to("cuda")
y = y.to("cuda")
hat_y = self.model(x)
loss = self.criterion(hat_y, y)
return hat_y, loss
with self.autocast:
x = x.to("cuda")
y = y.to("cuda")
hat_y = self.model(x)
loss = self.criterion(hat_y, y)
return hat_y, loss

# --- Helpers ---
def compute_cratios(self, imgs, mn_mx):
Expand Down
4 changes: 2 additions & 2 deletions src/aind_exaspim_image_compression/machine_learning/unet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ def __init__(self, in_channels, out_channels, mid_channels=None):

# Instance attributes
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.Conv3d(in_channels, mid_channels, kernel_size=4, padding=1),
nn.BatchNorm3d(mid_channels),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.Conv3d(mid_channels, out_channels, kernel_size=4, padding=1),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(negative_slope=0.01, inplace=True)
)
Expand Down
105 changes: 0 additions & 105 deletions src/aind_exaspim_image_compression/machine_learning/vit3d.py

This file was deleted.

Loading