diff --git a/src/aind_exaspim_image_compression/machine_learning/train.py b/src/aind_exaspim_image_compression/machine_learning/train.py index 9056f5f..eebf601 100644 --- a/src/aind_exaspim_image_compression/machine_learning/train.py +++ b/src/aind_exaspim_image_compression/machine_learning/train.py @@ -8,6 +8,7 @@ """ +from contextlib import nullcontext from datetime import datetime from numcodecs import blosc from torch.optim.lr_scheduler import CosineAnnealingLR @@ -34,7 +35,8 @@ def __init__( device="cuda:0", lr=1e-3, max_epochs=200, - model=None + model=None, + use_amp=True, ): """ Instantiates a Trainer object. @@ -53,6 +55,8 @@ def __init__( Maximum number of training epochs. Default is 200. model : None or nn.Module, optional Model to be trained on the given datasets. Default is None. + use_amp : bool, optional + Indication of whether to use mixed precision. Default is True. """ # Initializations exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M") @@ -63,18 +67,19 @@ def __init__( self.batch_size = batch_size self.device = device self.max_epochs = max_epochs - self.log_dir = log_dir + self.log_dir = log_dir self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE) self.criterion = nn.L1Loss() + self.model = model.to(device) if model else UNet().to(device) self.optimizer = optim.AdamW(self.model.parameters(), lr=lr) self.scheduler = CosineAnnealingLR(self.optimizer, T_max=25) self.writer = SummaryWriter(log_dir=log_dir) - if model is None: - self.model = UNet().to("cuda") + if use_amp: + self.autocast = torch.autocast(device_type="cuda", dtype=torch.float16) else: - self.model = model + self.autocast = nullcontext() # --- Core Routines --- def run(self, train_dataset, val_dataset): @@ -176,7 +181,7 @@ def validate_step(self, val_dataloader, epoch): losses.append(loss.detach().cpu()) # Log results - loss, cratio = np.mean(losses), np.mean(cratios) + loss, cratio = np.mean(losses), np.median(cratios) self.writer.add_scalar("val_loss", loss, epoch) self.writer.add_scalar("val_cratio", cratio, epoch) @@ -206,11 +211,12 @@ def forward_pass(self, x, y): loss : torch.Tensor Computed loss value. """ - x = x.to("cuda") - y = y.to("cuda") - hat_y = self.model(x) - loss = self.criterion(hat_y, y) - return hat_y, loss + with self.autocast: + x = x.to("cuda") + y = y.to("cuda") + hat_y = self.model(x) + loss = self.criterion(hat_y, y) + return hat_y, loss # --- Helpers --- def compute_cratios(self, imgs, mn_mx): diff --git a/src/aind_exaspim_image_compression/machine_learning/unet3d.py b/src/aind_exaspim_image_compression/machine_learning/unet3d.py index 9ec10ce..6a6d351 100644 --- a/src/aind_exaspim_image_compression/machine_learning/unet3d.py +++ b/src/aind_exaspim_image_compression/machine_learning/unet3d.py @@ -138,10 +138,10 @@ def __init__(self, in_channels, out_channels, mid_channels=None): # Instance attributes self.double_conv = nn.Sequential( - nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.Conv3d(in_channels, mid_channels, kernel_size=4, padding=1), nn.BatchNorm3d(mid_channels), nn.LeakyReLU(negative_slope=0.01, inplace=True), - nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.Conv3d(mid_channels, out_channels, kernel_size=4, padding=1), nn.BatchNorm3d(out_channels), nn.LeakyReLU(negative_slope=0.01, inplace=True) ) diff --git a/src/aind_exaspim_image_compression/machine_learning/vit3d.py b/src/aind_exaspim_image_compression/machine_learning/vit3d.py deleted file mode 100644 index 6897a17..0000000 --- a/src/aind_exaspim_image_compression/machine_learning/vit3d.py +++ /dev/null @@ -1,105 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange - - -class PatchEmbedding3D(nn.Module): - def __init__(self, in_channels, patch_shape, emb_size, img_shape): - # Call parent class - super().__init__() - - # Class attributes - self.patch_shape = patch_shape - self.emb_size = emb_size - - # Embedding - n_patches = np.prod([img_shape[i] // patch_shape[i] for i in range(3)]) - self.pos_embedding = nn.Parameter(torch.randn(1, n_patches, emb_size)) - self.proj = nn.Conv3d( - in_channels, emb_size, kernel_size=patch_shape, stride=patch_shape - ) - self.dropout = nn.Dropout(0.1) - - def forward(self, x): - x = self.proj(x) - x = rearrange(x, "b c d h w -> b (d h w) c") - x = x + self.pos_embedding - return self.dropout(x) - - -class TransformerEncoderBlock(nn.Module): - def __init__(self, emb_size, heads, mlp_dim, dropout=0.1): - # Call parent class - super().__init__() - - # Attention head - self.norm1 = nn.LayerNorm(emb_size) - self.attn = nn.MultiheadAttention( - emb_size, heads, dropout=dropout, batch_first=True - ) - self.norm2 = nn.LayerNorm(emb_size) - self.mlp = nn.Sequential( - nn.Linear(emb_size, mlp_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_dim, emb_size), - nn.Dropout(dropout), - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0] - x = x + self.mlp(self.norm2(x)) - return x - - -class ViT3D(nn.Module): - def __init__( - self, - in_channels=1, - img_shape=(64, 64, 64), - patch_shape=(8, 8, 8), - emb_size=512, - depth=6, - heads=8, - mlp_dim=1024, - ): - # Call parent class - super().__init__() - - # Class attributes - self.patch_shape = patch_shape - self.grid_size = [img_shape[i] // patch_shape[i] for i in range(3)] - - # Transformer Layers - self.patch_embed = PatchEmbedding3D( - in_channels, patch_shape, emb_size, img_shape - ) - self.transformer = nn.Sequential( - *[ - TransformerEncoderBlock(emb_size, heads, mlp_dim) - for _ in range(depth) - ] - ) - self.output_head = nn.Linear( - emb_size, np.prod(patch_shape) * in_channels - ) - - def forward(self, x): - batch_size = x.size(0) - x = self.patch_embed(x) - x = self.transformer(x) - x = self.output_head(x) - x = x.view(batch_size, -1, *self.patch_shape) - x = rearrange( - x, - "(b d h w) c pd ph pw -> b c (d pd) (h ph) (w pw)", - b=batch_size, - d=self.grid_size[0], - h=self.grid_size[1], - w=self.grid_size[2], - pd=self.patch_shape[0], - ph=self.patch_shape[1], - pw=self.patch_shape[2], - ) - return x