@@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol):
7878
7979 seq_ctx_list = [seq_ctx ]
8080 LossContext = loss_cfg .loss_ctx_cls
81- loss_ctx = loss_cfg .build (shifted_labels = shifted_labels , sp_mesh = None )
81+ loss_ctx = loss_cfg .build (data = { "shifted_labels" : shifted_labels } , sp_mesh = None )
8282 loss_ctx_list = [loss_ctx ]
8383 loss_ctx_list = LossContext .build_batches (loss_ctx_list )
8484 loss_ctx = loss_ctx_list [0 ]
@@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol):
8787 with torch .no_grad ():
8888 output = interns1_model (
8989 seq_ctx = seq_ctx ,
90- loss_ctx = loss_ctx ,
90+ loss_ctx = { "lm" : loss_ctx } ,
9191 )
9292 loss = output ["loss" ]
9393 self .assertTrue (torch .allclose (loss , expected_loss .to (loss .dtype ), atol = tol , rtol = tol ))
@@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
186186
187187 seq_ctx_list = [seq_ctx ]
188188 LossContext = loss_cfg .loss_ctx_cls
189- loss_ctx = loss_cfg .build (shifted_labels = shifted_labels , sp_mesh = sp_mesh )
189+ loss_ctx = loss_cfg .build (data = { "shifted_labels" : shifted_labels } , sp_mesh = sp_mesh )
190190 loss_ctx_list = [loss_ctx ]
191191 loss_ctx_list = LossContext .build_batches (loss_ctx_list )
192192 loss_ctx = loss_ctx_list [0 ]
@@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
195195 with torch .no_grad ():
196196 output = interns1_model (
197197 seq_ctx = seq_ctx ,
198- loss_ctx = loss_ctx ,
198+ loss_ctx = { "lm" : loss_ctx } ,
199199 )
200200 loss = output ["loss" ]
201201 self .assertTrue (torch .allclose (loss , expected_loss .to (loss .dtype ), atol = tol , rtol = tol ))
@@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol):
256256 seq_ctx_list = [seq_ctx ]
257257 loss_cfg = CELossConfig ()
258258 LossContext = loss_cfg .loss_ctx_cls
259- loss_ctx = loss_cfg .build (shifted_labels = shifted_labels , sp_mesh = None )
259+ loss_ctx = loss_cfg .build (data = { "shifted_labels" : shifted_labels } , sp_mesh = None )
260260 loss_ctx_list = [loss_ctx ]
261261 loss_ctx_list = LossContext .build_batches (loss_ctx_list )
262262 loss_ctx = loss_ctx_list [0 ]
@@ -265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol):
265265 with torch .no_grad ():
266266 output = interns1_model (
267267 seq_ctx = seq_ctx ,
268- loss_ctx = loss_ctx ,
268+ loss_ctx = { "lm" : loss_ctx } ,
269269 )
270270 loss = output ["loss" ]
271271 self .assertTrue (torch .allclose (loss , expected_loss .to (loss .dtype ), atol = tol , rtol = tol ))
@@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
370370 seq_ctx_list = [seq_ctx ]
371371 loss_cfg = CELossConfig ()
372372 LossContext = loss_cfg .loss_ctx_cls
373- loss_ctx = loss_cfg .build (shifted_labels = shifted_labels , sp_mesh = sp_mesh )
373+ loss_ctx = loss_cfg .build (data = { "shifted_labels" : shifted_labels } , sp_mesh = sp_mesh )
374374 loss_ctx_list = [loss_ctx ]
375375 loss_ctx_list = LossContext .build_batches (loss_ctx_list )
376376 loss_ctx = loss_ctx_list [0 ]
@@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
379379 with torch .no_grad ():
380380 output = interns1_model (
381381 seq_ctx = seq_ctx ,
382- loss_ctx = loss_ctx ,
382+ loss_ctx = { "lm" : loss_ctx } ,
383383 )
384384 loss = output ["loss" ]
385385 self .assertTrue (torch .allclose (loss , expected_loss .to (loss .dtype ), atol = tol , rtol = tol ))
0 commit comments