-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathtest.py
More file actions
67 lines (54 loc) · 1.75 KB
/
test.py
File metadata and controls
67 lines (54 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import config
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from dataset import HorseZebraDataset
from generator_model import Generator
from utils import load_checkpoint
def test_fn(gen_Z, gen_H, loader):
    """Translate every validation batch in both directions and save the results.

    Args:
        gen_Z: generator mapping horse images -> zebra images.
        gen_H: generator mapping zebra images -> horse images.
        loader: DataLoader yielding ``(zebra, horse)`` tensor pairs,
            presumably normalized to [-1, 1] (the ``* 0.5 + 0.5`` rescale
            below assumes so — confirm against the dataset transform).

    Side effects:
        Writes ``saved_images/horse_{idx}.png`` and
        ``saved_images/zebra_{idx}.png`` for each batch index.
    """
    import os

    # save_image raises FileNotFoundError if the target directory is missing.
    os.makedirs("saved_images", exist_ok=True)

    loop = tqdm(loader, leave=True)
    # Pure inference: disable autograd so no computation graphs are built
    # (otherwise memory grows with each forward pass for no benefit).
    with torch.no_grad():
        for idx, (zebra, horse) in enumerate(loop):
            zebra = zebra.to(config.DEVICE)
            horse = horse.to(config.DEVICE)
            with torch.cuda.amp.autocast():
                fake_horse = gen_H(zebra)
                fake_zebra = gen_Z(horse)
            # Map outputs from [-1, 1] back to [0, 1] before writing PNGs.
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
def main():
    """Restore both trained CycleGAN generators from checkpoints and run them
    over the validation split, saving the translated images to disk.

    Reads all paths and hyperparameters from the project ``config`` module.
    """
    gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    # The optimizer exists only because load_checkpoint's signature requires
    # one to restore its state; it is never stepped during inference.
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    load_checkpoint(
        config.CHECKPOINT_GEN_H,
        gen_H,
        opt_gen,
        config.LEARNING_RATE,
    )
    load_checkpoint(
        config.CHECKPOINT_GEN_Z,
        gen_Z,
        opt_gen,
        config.LEARNING_RATE,
    )
    # Switch to evaluation mode so dropout / norm layers use inference
    # behavior instead of their training-time behavior.
    gen_Z.eval()
    gen_H.eval()
    val_dataset = HorseZebraDataset(
        root_horse=config.VAL_DIR + "/testA",
        root_zebra=config.VAL_DIR + "/testB",
        transform=config.transforms,
    )
    loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,  # keep deterministic file naming across runs
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    test_fn(gen_Z, gen_H, loader)


if __name__ == "__main__":
    main()