11"""Training/loss methods for GaussianDiffusion."""
22
3- import json
4- import os
5-
63import torch as th
74from geomloss import SamplesLoss
85
@@ -43,23 +40,6 @@ def _build_training_target(model_mean_type, x_start, noise):
4340 ModelMeanType .EPSILON : noise ,
4441 }[model_mean_type ]
4542
46-
47-
48- def _safe_stats (x ):
49- """Summarize `x` for debugging: None stays None, a non-tensor yields its type name, a tensor yields its shape, mean, std and norm."""
50- if x is None :
51- return None
52- if not isinstance (x , th .Tensor ):
53- return str (type (x ))
54- y = x .detach ().float ()
55- return {
56- "shape" : list (y .shape ),
57- "mean" : float (y .mean ().item ()),
58- "std" : float (y .std (unbiased = False ).item ()),
59- "norm" : float (y .norm ().item ()),
60- }
61-
62-
6343class GaussianDiffusionTrainingMixin :
6444 """Mixin providing the GaussianDiffusion training/loss methods used by the PerturbDiff pipeline."""
6545 def get_model_output (
@@ -209,15 +189,9 @@ def training_losses(
209189
210190 x_0 = th .zeros_like (x_t )
211191 control_0 = th .zeros_like (control_input_t )
212- selfcond_used = False
213- force_selfcond = os .getenv ("PDIFF_FORCE_SELFCOND" )
214- if force_selfcond is None :
215- use_selfcond_now = bool ((th .rand (1 ) > 0.5 ).item ())
216- else :
217- use_selfcond_now = force_selfcond .strip () in {"1" , "true" , "True" }
192+ use_selfcond_now = bool ((th .rand (1 ) > 0.5 ).item ())
218193
219194 if use_selfcond_now :
220- selfcond_used = True
221195 with th .no_grad ():
222196 out = self .get_model_output (
223197 model = model ,
@@ -230,7 +204,7 @@ def training_losses(
230204 x_0 = out ["x" ]
231205 control_0 = th .zeros_like (out ["x_control" ])
232206
233- terms , model_output = self .diffusion_loss (
207+ terms = self .diffusion_loss (
234208 model = model ,
235209 x_start = x_start ,
236210 x_t = x_t ,
@@ -242,35 +216,10 @@ def training_losses(
242216 x_0 = x_0 ,
243217 control_0 = control_0 ,
244218 MMD_loss_fn = MMD_loss_fn ,
245- return_model_output = True ,
246219 )
247220
248221 if model .model_cfg .no_mse_loss :
249222 terms ["mse1" ] = th .zeros_like (terms ["mse1" ])
250- trace_path = os .getenv ("PDIFF_TRACE_PATH" )
251- if trace_path and not getattr (self , "_pdiff_trace_dumped" , False ):
252- trace = {
253- "t_head" : [int (v ) for v in t .detach ().cpu ().view (- 1 )[:8 ]],
254- "x_start" : _safe_stats (x_start ),
255- "control_input_start" : _safe_stats (control_input_start ),
256- "noise" : _safe_stats (noise ),
257- "x_t" : _safe_stats (x_t ),
258- "control_input_t" : _safe_stats (control_input_t ),
259- "batch_emb_is_none" : self_condition .get ("batch_emb" ) is None ,
260- "cont_emb_is_none" : self_condition .get ("cont_emb" ) is None ,
261- "model_output_x" : _safe_stats (model_output .get ("x" )),
262- "model_output_x_control" : _safe_stats (model_output .get ("x_control" )),
263- "selfcond_used" : bool (selfcond_used ),
264- "x_0" : _safe_stats (x_0 ),
265- "control_0" : _safe_stats (control_0 ),
266- "terms_loss1_mean" : float (terms ["loss1" ].detach ().float ().mean ().item ()),
267- "terms_mse1_mean" : float (terms ["mse1" ].detach ().float ().mean ().item ()),
268- "terms_mmd1_mean" : float (terms ["mmd1" ].detach ().float ().mean ().item ()) if "mmd1" in terms else None ,
269- }
270- os .makedirs (os .path .dirname (trace_path ), exist_ok = True )
271- with open (trace_path , "w" , encoding = "utf-8" ) as fout :
272- json .dump (trace , fout , indent = 2 )
273- self ._pdiff_trace_dumped = True
274223 return terms
275224
276225
0 commit comments