Skip to content

Commit bf8f0f3

Browse files
committed
[Refactor] Refactor loss context API to support multiple loss types
Change loss context from single object to dict-based API:

- Update loss_cfg.build() to accept data parameter as dict
- Change ModelItem.loss_ctx to dict with loss type keys (e.g. 'lm')
- Update model forward pass to accept loss_ctx_dict parameter
- Update all tests to use new dict-based loss context API

ghstack-source-id: c938145
Pull-Request: InternLM#1569
1 parent fb28789 commit bf8f0f3

29 files changed

+659
-181
lines changed

tests/engine/test_dense_train_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ def warmup_fn(x):
8383
seq_ctx = seq_ctx.split(sequence_parallel_mesh=sp_mesh)
8484
seq_ctx_list = [seq_ctx]
8585
LossContext = loss_cfg.loss_ctx_cls
86-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=sp_mesh)
86+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=sp_mesh)
8787
loss_ctx_list = [loss_ctx]
8888
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
8989

9090
seq_ctx = seq_ctx_list[0]
9191
loss_ctx = loss_ctx_list[0]
92-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
92+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
9393
loss_log = engine.train_step(engine_input)["logs_info"]
9494
grad_norm = engine.clip_grad_norm()
9595
engine.step_optimizer(grad_norm)

tests/engine/test_moe_train_engine.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ def warmup_fn(x):
9393
seq_ctx.num_padding = pack_len
9494
seq_ctx_list = [seq_ctx]
9595
LossContext = loss_cfg.loss_ctx_cls
96-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
96+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
9797
loss_ctx_list = [loss_ctx]
9898
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
9999
loss_ctx = loss_ctx_list[0]
100100
seq_ctx = seq_ctx_list[0]
101-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
101+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
102102
loss_log = engine.train_step(engine_input)["logs_info"]
103103
grad_norm = engine.clip_grad_norm()
104104
engine.step_optimizer(grad_norm)
@@ -184,12 +184,12 @@ def warmup_fn(x):
184184
seq_ctx.num_padding = pack_len
185185
seq_ctx_list = [seq_ctx]
186186
LossContext = loss_cfg.loss_ctx_cls
187-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
187+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
188188
loss_ctx_list = [loss_ctx]
189189
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
190190
loss_ctx = loss_ctx_list[0]
191191
seq_ctx = seq_ctx_list[0]
192-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
192+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
193193
loss_log = engine.train_step(engine_input)["logs_info"]
194194
grad_norm = engine.clip_grad_norm()
195195
engine.step_optimizer(grad_norm)

tests/engine/test_moe_train_engine_float8.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def warmup_fn(x):
8787
seq_ctx.num_padding = pack_len
8888
seq_ctx_list = [seq_ctx]
8989
LossContext = loss_cfg.loss_ctx_cls
90-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
90+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
9191
loss_ctx_list = [loss_ctx]
9292
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
9393
loss_ctx = loss_ctx_list[0]
9494
seq_ctx = seq_ctx_list[0]
95-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
95+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
9696
loss_log = engine.train_step(engine_input)["logs_info"]
9797
grad_norm = engine.clip_grad_norm()
9898
engine.step_optimizer(grad_norm)
@@ -165,12 +165,12 @@ def warmup_fn(x):
165165
seq_ctx.num_padding = pack_len
166166
seq_ctx_list = [seq_ctx]
167167
LossContext = loss_cfg.loss_ctx_cls
168-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
168+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
169169
loss_ctx_list = [loss_ctx]
170170
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
171171
loss_ctx = loss_ctx_list[0]
172172
seq_ctx = seq_ctx_list[0]
173-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
173+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
174174
loss_log = engine.train_step(engine_input)["logs_info"]
175175
grad_norm = engine.clip_grad_norm()
176176
engine.step_optimizer(grad_norm)
@@ -264,12 +264,12 @@ def warmup_fn(x):
264264
seq_ctx.to('cuda')
265265
seq_ctx_list = [seq_ctx]
266266
LossContext = loss_cfg.loss_ctx_cls
267-
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
267+
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
268268
loss_ctx_list = [loss_ctx]
269269
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
270270
loss_ctx = loss_ctx_list[0]
271271
seq_ctx = seq_ctx_list[0]
272-
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
272+
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
273273
logs_info = engine.train_step(engine_input)["logs_info"]
274274
grad_norm = engine.clip_grad_norm()
275275
engine.step_optimizer(grad_norm)

tests/loss/test_ce_loss.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_global_loss_reduction(self, loss_mode, grad_accumulation_steps, chunk_s
7272
for data in data_batch:
7373
seq_ctx = data["seq_ctx"]
7474
seq_ctx_list.append(seq_ctx)
75-
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
75+
loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
7676
loss_ctx_list.append(loss_ctx)
7777
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])
7878

@@ -172,7 +172,7 @@ def test_other_loss_reduction(self, loss_reduction, loss_mode, grad_accumulation
172172
for data in data_batch:
173173
seq_ctx = data["seq_ctx"]
174174
seq_ctx_list.append(seq_ctx)
175-
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
175+
loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
176176
loss_ctx_list.append(loss_ctx)
177177
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])
178178

@@ -310,7 +310,7 @@ def test_sp_global_loss_reduction(self, loss_mode, sp_size, grad_accumulation_st
310310
sp_mesh = data_mesh['sp']
311311
seq_ctx.sequence_parallel_mesh = sp_mesh
312312
seq_ctx_list = [seq_ctx]
313-
loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
313+
loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
314314
loss_ctx_list = [loss_ctx]
315315
if sp_size > 1:
316316
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)
@@ -397,7 +397,7 @@ def test_sp_others_loss_reduction(self, loss_reduction, loss_mode, sp_size, grad
397397
sp_mesh = data_mesh['sp']
398398
seq_ctx.sequence_parallel_mesh = sp_mesh
399399
seq_ctx_list = [seq_ctx]
400-
loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
400+
loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
401401
loss_ctx_list = [loss_ctx]
402402
if sp_size > 1:
403403
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)

tests/loss/test_grpo_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
147147
if sp_size > 1:
148148
seq_ctx = seq_ctx.split(sp_mesh)
149149
seq_ctx_list.append(seq_ctx)
150-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels_list_rank[iter_idx], advantages=advantages_list_rank[iter_idx], sp_mesh=sp_mesh)
150+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh)
151151
loss_ctx_list.append(loss_ctx)
152152

153153
with torch.no_grad():

tests/loss/test_oreal_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
216216
seq_ctx = seq_ctx.split(sp_mesh)
217217
seq_ctx_list.append(seq_ctx)
218218
loss_ctx = loss_cfg.build(
219-
shifted_labels=shifted_labels_list_rank[iter_idx],
220-
advantages=advantages_list_rank[iter_idx],
219+
data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]},
221220
sp_mesh=sp_mesh,
222221
)
223222
loss_ctx_list.append(loss_ctx)

tests/model/test_gpt_oss_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
7878
loss_cfg = CELossConfig()
7979
seq_ctx_list = [seq_ctx]
8080
LossContext = loss_cfg.loss_ctx_cls
81-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
81+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
8282
loss_ctx_list = [loss_ctx]
8383
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
8484
loss_ctx = loss_ctx_list[0]
@@ -87,7 +87,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
8787
with torch.no_grad():
8888
output = gpt_oss_model(
8989
seq_ctx=seq_ctx,
90-
loss_ctx=loss_ctx,
90+
loss_ctx={"lm": loss_ctx},
9191
)
9292
loss = output["loss"]
9393
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -141,7 +141,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
141141
loss_cfg = CELossConfig()
142142
seq_ctx_list = [seq_ctx]
143143
LossContext = loss_cfg.loss_ctx_cls
144-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
144+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
145145
loss_ctx_list = [loss_ctx]
146146
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
147147
loss_ctx = loss_ctx_list[0]
@@ -152,7 +152,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
152152
with torch.no_grad():
153153
output = gpt_oss_model(
154154
seq_ctx=seq_ctx,
155-
loss_ctx=loss_ctx,
155+
loss_ctx={"lm": loss_ctx},
156156
)
157157
loss = output["loss"]
158158
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2))

tests/model/test_intern_s1.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol):
7878

7979
seq_ctx_list = [seq_ctx]
8080
LossContext = loss_cfg.loss_ctx_cls
81-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
81+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
8282
loss_ctx_list = [loss_ctx]
8383
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
8484
loss_ctx = loss_ctx_list[0]
@@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol):
8787
with torch.no_grad():
8888
output = interns1_model(
8989
seq_ctx=seq_ctx,
90-
loss_ctx=loss_ctx,
90+
loss_ctx={"lm": loss_ctx},
9191
)
9292
loss = output["loss"]
9393
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
186186

187187
seq_ctx_list = [seq_ctx]
188188
LossContext = loss_cfg.loss_ctx_cls
189-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
189+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
190190
loss_ctx_list = [loss_ctx]
191191
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
192192
loss_ctx = loss_ctx_list[0]
@@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
195195
with torch.no_grad():
196196
output = interns1_model(
197197
seq_ctx=seq_ctx,
198-
loss_ctx=loss_ctx,
198+
loss_ctx={"lm": loss_ctx},
199199
)
200200
loss = output["loss"]
201201
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol):
256256
seq_ctx_list = [seq_ctx]
257257
loss_cfg = CELossConfig()
258258
LossContext = loss_cfg.loss_ctx_cls
259-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
259+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
260260
loss_ctx_list = [loss_ctx]
261261
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
262262
loss_ctx = loss_ctx_list[0]
@@ -265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol):
265265
with torch.no_grad():
266266
output = interns1_model(
267267
seq_ctx=seq_ctx,
268-
loss_ctx=loss_ctx,
268+
loss_ctx={"lm": loss_ctx},
269269
)
270270
loss = output["loss"]
271271
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
@@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
370370
seq_ctx_list = [seq_ctx]
371371
loss_cfg = CELossConfig()
372372
LossContext = loss_cfg.loss_ctx_cls
373-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
373+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
374374
loss_ctx_list = [loss_ctx]
375375
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
376376
loss_ctx = loss_ctx_list[0]
@@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
379379
with torch.no_grad():
380380
output = interns1_model(
381381
seq_ctx=seq_ctx,
382-
loss_ctx=loss_ctx,
382+
loss_ctx={"lm": loss_ctx},
383383
)
384384
loss = output["loss"]
385385
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))

tests/model/test_moe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ def test_moe_config(self, dtype, device):
6262

6363
seq_ctx_list = [seq_ctx]
6464
LossContext = loss_cfg.loss_ctx_cls
65-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
65+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
6666
loss_ctx_list = [loss_ctx]
6767
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
6868
loss_ctx = loss_ctx_list[0]
6969
seq_ctx = seq_ctx_list[0]
70-
model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
70+
model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})
7171

7272

7373
class TestDistributedMoE(DeterministicDDPTestCase):
@@ -135,15 +135,15 @@ def test_parallel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi
135135

136136
seq_ctx_list = [seq_ctx]
137137
LossContext = loss_cfg.loss_ctx_cls
138-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
138+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
139139
loss_ctx_list = [loss_ctx]
140140
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
141141
loss_ctx = loss_ctx_list[0]
142142
seq_ctx = seq_ctx_list[0]
143143

144-
loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
144+
loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]
145145

146-
loss_expected = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
146+
loss_expected = model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]
147147

148148
torch.allclose(loss_expected, loss_parallel, atol=1e-6, rtol=1e-4)
149149

tests/model/test_qwen3_5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _forward(self, model, type, device, sp_size):
138138

139139
seq_ctx_list = [seq_ctx]
140140
LossContext = loss_cfg.loss_ctx_cls
141-
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
141+
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
142142
loss_ctx_list = [loss_ctx]
143143
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
144144
loss_ctx = loss_ctx_list[0]

0 commit comments

Comments
 (0)