Skip to content

Commit afbd39a

Browse files
committed
remove deterministic flags for refactor control
1 parent f843a8f commit afbd39a

6 files changed

Lines changed: 5 additions & 75 deletions

File tree

-658 Bytes
Binary file not shown.

src/data/dataset/dataset_core.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Module `data/dataset_core.py`."""
22
import ctypes
33
import ctypes.util
4-
import hashlib
5-
import os
64
from logging import Logger
75
from typing import Any, Dict, List, Tuple
86

@@ -123,17 +121,8 @@ def collate_fn(dataset, batch):
123121
def mapping_cells(dataset, ds_name, local_idx):
124122
"""Execute `mapping_cells` and return values used by downstream logic."""
125123
assert dataset.data_args.mapping_strategy == "random"
126-
deterministic_mapping = os.getenv("PDIFF_DETERMINISTIC_MAPPING", "0") == "1"
127-
if deterministic_mapping:
128-
base_seed = int(os.getenv("PDIFF_DETERMINISTIC_MAPPING_SEED", "20260304"))
129-
key = f"{base_seed}|{ds_name}|{int(local_idx)}"
130-
digest = hashlib.blake2b(key.encode("utf-8"), digest_size=8).digest()
131-
rng = np.random.default_rng(int.from_bytes(digest, "little", signed=False))
132-
randint = lambda high: int(rng.integers(0, high))
133-
choice = lambda arr: arr[int(rng.integers(0, len(arr)))]
134-
else:
135-
randint = lambda high: int(np.random.randint(0, high, size=(1))[0])
136-
choice = lambda arr: np.random.choice(arr, 1)[0]
124+
randint = lambda high: int(np.random.randint(0, high, size=(1))[0])
125+
choice = lambda arr: np.random.choice(arr, 1)[0]
137126

138127
cache = dataset.meta_cache._cache[dataset.dataset_path_map[ds_name]]
139128

-1.54 KB
Binary file not shown.

src/models/diffusion/diffusion_training.py

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
"""Training/loss methods for GaussianDiffusion."""
22

3-
import json
4-
import os
5-
63
import torch as th
74
from geomloss import SamplesLoss
85

@@ -43,23 +40,6 @@ def _build_training_target(model_mean_type, x_start, noise):
4340
ModelMeanType.EPSILON: noise,
4441
}[model_mean_type]
4542

46-
47-
48-
def _safe_stats(x):
49-
"""Execute `_safe_stats` and return values used by downstream logic."""
50-
if x is None:
51-
return None
52-
if not isinstance(x, th.Tensor):
53-
return str(type(x))
54-
y = x.detach().float()
55-
return {
56-
"shape": list(y.shape),
57-
"mean": float(y.mean().item()),
58-
"std": float(y.std(unbiased=False).item()),
59-
"norm": float(y.norm().item()),
60-
}
61-
62-
6343
class GaussianDiffusionTrainingMixin:
6444
"""Gaussiandiffusiontrainingmixin implementation used by the PerturbDiff pipeline."""
6545
def get_model_output(
@@ -209,15 +189,9 @@ def training_losses(
209189

210190
x_0 = th.zeros_like(x_t)
211191
control_0 = th.zeros_like(control_input_t)
212-
selfcond_used = False
213-
force_selfcond = os.getenv("PDIFF_FORCE_SELFCOND")
214-
if force_selfcond is None:
215-
use_selfcond_now = bool((th.rand(1) > 0.5).item())
216-
else:
217-
use_selfcond_now = force_selfcond.strip() in {"1", "true", "True"}
192+
use_selfcond_now = bool((th.rand(1) > 0.5).item())
218193

219194
if use_selfcond_now:
220-
selfcond_used = True
221195
with th.no_grad():
222196
out = self.get_model_output(
223197
model=model,
@@ -230,7 +204,7 @@ def training_losses(
230204
x_0 = out["x"]
231205
control_0 = th.zeros_like(out["x_control"])
232206

233-
terms, model_output = self.diffusion_loss(
207+
terms = self.diffusion_loss(
234208
model=model,
235209
x_start=x_start,
236210
x_t=x_t,
@@ -242,35 +216,10 @@ def training_losses(
242216
x_0=x_0,
243217
control_0=control_0,
244218
MMD_loss_fn=MMD_loss_fn,
245-
return_model_output=True,
246219
)
247220

248221
if model.model_cfg.no_mse_loss:
249222
terms["mse1"] = th.zeros_like(terms["mse1"])
250-
trace_path = os.getenv("PDIFF_TRACE_PATH")
251-
if trace_path and not getattr(self, "_pdiff_trace_dumped", False):
252-
trace = {
253-
"t_head": [int(v) for v in t.detach().cpu().view(-1)[:8]],
254-
"x_start": _safe_stats(x_start),
255-
"control_input_start": _safe_stats(control_input_start),
256-
"noise": _safe_stats(noise),
257-
"x_t": _safe_stats(x_t),
258-
"control_input_t": _safe_stats(control_input_t),
259-
"batch_emb_is_none": self_condition.get("batch_emb") is None,
260-
"cont_emb_is_none": self_condition.get("cont_emb") is None,
261-
"model_output_x": _safe_stats(model_output.get("x")),
262-
"model_output_x_control": _safe_stats(model_output.get("x_control")),
263-
"selfcond_used": bool(selfcond_used),
264-
"x_0": _safe_stats(x_0),
265-
"control_0": _safe_stats(control_0),
266-
"terms_loss1_mean": float(terms["loss1"].detach().float().mean().item()),
267-
"terms_mse1_mean": float(terms["mse1"].detach().float().mean().item()),
268-
"terms_mmd1_mean": float(terms["mmd1"].detach().float().mean().item()) if "mmd1" in terms else None,
269-
}
270-
os.makedirs(os.path.dirname(trace_path), exist_ok=True)
271-
with open(trace_path, "w", encoding="utf-8") as fout:
272-
json.dump(trace, fout, indent=2)
273-
self._pdiff_trace_dumped = True
274223
return terms
275224

276225

-288 Bytes
Binary file not shown.

src/models/lightning/lightning_module.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Lightning training module split from lightning_module (logic-preserving)."""
22

33
import gc
4-
import os
54
import pickle
65
import sys
76
import time
@@ -231,22 +230,15 @@ def _compute_loss(self, batch):
231230
"ds_name": batch["ds_name"],
232231
}
233232

234-
align_rng = os.getenv("PDIFF_FORCE_ALIGNED_RNG", "0") == "1"
235-
if align_rng:
236-
step_seed = int(os.getenv("PDIFF_SEED_BASE", "12345")) + int(self.global_step)
237-
np.random.seed(step_seed)
238-
torch.manual_seed(step_seed)
239-
240233
t, weights = self.schedule_sampler.sample(pert_emb.shape[0], device)
241-
noise = torch.randn_like(pert_emb, dtype=torch.float64) if align_rng else None
242234

243235
losses = self.diffusion.training_losses(
244236
self.model,
245237
pert_emb,
246238
t,
247239
self_condition=cond,
248240
model_kwargs=None,
249-
noise=noise,
241+
noise=None,
250242
p_drop_cond=self.model_cfg.p_drop_cond,
251243
MMD_loss_fn=self.loss_fn,
252244
)

0 commit comments

Comments
 (0)