-
Notifications
You must be signed in to change notification settings - Fork 740
Expand file tree
/
Copy path glm4_moe.py
More file actions
649 lines (549 loc) · 24 KB
/
glm4_moe.py
File metadata and controls
649 lines (549 loc) · 24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import re
from functools import partial
from typing import Dict
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
import fastdeploy
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
class Glm4MoeMLP(nn.Layer):
    """Gated feed-forward block used for dense layers and for shared experts.

    The input goes through a fused gate/up projection (``up_gate_proj``), then
    ``SiluAndMul`` (configured with ``model_config.hidden_act``), then
    ``down_proj`` back to ``hidden_size``.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        intermediate_size: int,
        layer_id: int,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """
        Args:
            fd_config (FDConfig): Global FastDeploy configuration.
            intermediate_size (int): Width of the gate/up projections.
            layer_id (int): Index of the owning decoder layer.
            prefix (str): Weight-name prefix for the sublayers.
            reduce_results (bool): Forwarded to the row-parallel ``down_proj``
                (tensor-parallel path only); the replicated path ignores it.
        """
        super().__init__()
        model_cfg = fd_config.model_config
        hidden_size = model_cfg.hidden_size
        # With sequence-parallel MoE (EP + TP), the shared experts of MoE layers
        # are kept whole on every rank instead of being split across TP ranks.
        keep_replicated = (
            fd_config.parallel_config.use_sequence_parallel_moe
            and layer_id >= model_cfg.moe_layer_start_index
        )
        if not keep_replicated:
            self.up_gate_proj = MergedColumnParallelLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.up_gate_proj",
                input_size=hidden_size,
                output_size=intermediate_size * 2,
                with_bias=False,
                activation=model_cfg.hidden_act,
            )
            self.down_proj = RowParallelLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.down_proj",
                input_size=intermediate_size,
                output_size=hidden_size,
                with_bias=False,
                reduce_results=reduce_results,
            )
        else:
            self.up_gate_proj = MergedReplicatedLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.up_gate_proj",
                input_size=hidden_size,
                output_size=[intermediate_size, intermediate_size],
                with_bias=False,
            )
            self.down_proj = ReplicatedLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.down_proj",
                input_size=intermediate_size,
                output_size=hidden_size,
                with_bias=False,
            )
        self.act_fn = SiluAndMul(
            fd_config=fd_config,
            bias=None,
            act_method=model_cfg.hidden_act,
        )

    def forward(self, x, forward_meta=None):
        """Apply up_gate_proj -> SiluAndMul -> down_proj to ``x``.

        ``forward_meta`` is accepted for call-site compatibility and is unused.
        """
        return self.down_proj(self.act_fn(self.up_gate_proj(x)))
class Glm4Moe(nn.Layer):
    """Mixture-of-experts FFN: a routed ``FusedMoE`` plus optional shared experts.

    The router (``gate``) is a replicated linear layer. It carries an
    ``e_score_correction_bias`` parameter that is handed to ``FusedMoE`` as the
    gate correction bias for ``noaux_tc`` top-k selection.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        layer_id: int = -1,
        prefix: str = "",
    ) -> None:
        """
        Args:
            fd_config (FDConfig): Global FastDeploy configuration.
            layer_id (int): Index of the owning decoder layer.
            prefix (str): Weight-name prefix, e.g. ``model.layers.3.mlp``.
        """
        super().__init__()
        parallel_cfg = fd_config.parallel_config
        model_cfg = fd_config.model_config

        self.expert_parallel_size = parallel_cfg.expert_parallel_size
        self.tensor_parallel_size = parallel_cfg.tensor_parallel_size
        self.tensor_parallel_rank = parallel_cfg.tensor_parallel_rank
        self.tp_group = parallel_cfg.tp_group
        self.use_ep = self.expert_parallel_size > 1
        self.use_tp = self.tensor_parallel_size > 1
        self.n_routed_experts: int = model_cfg.n_routed_experts
        self.n_shared_experts: int = model_cfg.n_shared_experts
        self.norm_topk_prob = model_cfg.norm_topk_prob

        # "{}" is left as a placeholder for the expert index.
        weight_key_map = {
            "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
            "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
            "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
        }

        self.gate = ReplicatedLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.gate",
            input_size=model_cfg.hidden_size,
            output_size=model_cfg.n_routed_experts,
            with_bias=False,
            skip_quant=True,
            weight_dtype="float32" if model_cfg.moe_gate_fp32 else "",
        )
        self.gate.e_score_correction_bias = self.create_parameter(
            shape=[1, model_cfg.n_routed_experts],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        # Pure tensor parallelism (tp > 1, ep == 1): the routed and shared
        # branches both return partial sums, so the all-reduce is postponed
        # until they are added together (one collective instead of two). In
        # every other mode each branch reduces its own output and no extra
        # all-reduce may be issued here.
        self.merge_ffn_tp = self.use_tp and not self.use_ep

        def _sum_topk_weights(topk_weights):
            # Presumably the renormalization denominator for top-k weights; the
            # tiny epsilon keeps it from being exactly zero.
            return topk_weights.sum(axis=-1, keepdim=True) + 1e-20

        self.experts = FusedMoE(
            fd_config,
            hidden_size=model_cfg.hidden_size,
            reduce_results=not self.merge_ffn_tp,
            renormalize=self.norm_topk_prob,
            moe_intermediate_size=model_cfg.moe_intermediate_size,
            num_experts=model_cfg.n_routed_experts,
            top_k=model_cfg.num_experts_per_tok,
            topk_method="noaux_tc",
            topk_group=model_cfg.topk_group,
            n_group=model_cfg.n_group,
            routed_scaling_factor=model_cfg.routed_scaling_factor,
            layer_idx=layer_id,
            gate_correction_bias=self.gate.e_score_correction_bias,
            weight_key_map=weight_key_map,
            topk_reduce_func=_sum_topk_weights,
        )

        if self.n_shared_experts > 0:
            self.shared_experts = Glm4MoeMLP(
                fd_config=fd_config,
                intermediate_size=self.n_shared_experts * model_cfg.moe_intermediate_size,
                layer_id=layer_id,
                prefix=f"{prefix}.shared_experts",
                reduce_results=not self.merge_ffn_tp,
            )

    def forward(self, x, forward_meta: ForwardMeta = None):
        """Return routed-expert output plus shared-expert output (if any).

        In pure-TP mode the summed partial results are all-reduced once over
        ``self.tp_group``.
        """
        combined = self.experts(x, self.gate, forward_meta)
        if self.n_shared_experts > 0:
            combined = combined + self.shared_experts(x)
        if not self.merge_ffn_tp:
            return combined
        return tensor_model_parallel_all_reduce(combined, self.tp_group)
class Glm4MoeAttention(nn.Layer):
    """Self-attention with fused QKV projection, optional QK RMSNorm and a
    row-parallel output projection.

    Attention and key/value heads are divided evenly by the tensor-parallel size.
    """

    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
        """
        Args:
            fd_config (FDConfig): Global FastDeploy configuration.
            layer_id (int): Index of the owning decoder layer.
            prefix (str): Weight-name prefix, e.g. ``model.layers.3.self_attn``.
        """
        super().__init__()
        model_cfg = fd_config.model_config
        tp_size = fd_config.parallel_config.tensor_parallel_size

        self.fd_config = fd_config
        self.head_dim = model_cfg.head_dim
        # Per-rank head counts.
        self.num_heads = model_cfg.num_attention_heads // tp_size
        self.num_kv_heads = model_cfg.num_key_value_heads // tp_size
        self.attention_bias = model_cfg.attention_bias
        self.use_qk_norm = model_cfg.use_qk_norm
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=self.attention_bias)
        self.o_proj = RowParallelLinear(
            fd_config,
            prefix=f"{prefix}.o_proj",
            input_size=model_cfg.num_attention_heads * model_cfg.head_dim,
            output_size=model_cfg.hidden_size,
            layer_id=layer_id,
            enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion,
        )
        self.attn = Attention(
            fd_config,
            layer_id=layer_id,
            prefix=prefix,
            use_neox_rotary_style=True,
            rms_norm_eps=model_cfg.rms_norm_eps,
        )
        if self.use_qk_norm:
            self.qk_norm = QKRMSNorm(
                fd_config,
                head_dim=self.head_dim,
                q_size=self.q_size,
                kv_size=self.kv_size,
                eps=model_cfg.rms_norm_eps,
                prefix=prefix,
                begin_norm_axis=2,
            )

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
    ):
        """Project to QKV, optionally QK-normalize, attend and project back."""
        qkv = self.qkv_proj(hidden_states)
        if self.use_qk_norm:
            qkv = self.qk_norm(qkv)
        attn_out = self.attn(qkv=qkv, forward_meta=forward_meta)
        return self.o_proj(attn_out)
def rms_norm_func(x, weight, eps):
    """RMS-normalize ``x`` over its last axis with Paddle's functional kernel.

    Used as a drop-in ``proxy_rmsnorm`` for ``RMSNorm`` layers.

    Args:
        x: Input tensor; normalization covers ``x.shape[-1:]``.
        weight: Scale tensor; its dtype is also the output dtype.
        eps: Numerical-stability epsilon.

    Returns:
        The normalized tensor cast to ``weight.dtype``.
    """
    normed = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps)
    # The kernel may return a tuple/list; the normalized tensor is its first item.
    if isinstance(normed, (tuple, list)):
        normed = normed[0]
    return normed.astype(weight.dtype)
class Glm4MoeDecoderLayer(nn.Layer):
    """One decoder block: input RMSNorm -> attention -> post-attention RMSNorm -> MLP/MoE.

    Layers with index ``>= first_k_dense_replace`` (or any MTP layer) use the
    MoE FFN when the model defines routed experts; the others use a dense MLP.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        is_mtp: bool = False,
    ) -> None:
        """
        Args:
            fd_config (FDConfig): Global FastDeploy configuration.
            prefix (str): Layer prefix whose last dotted component is the layer
                index, e.g. ``model.layers.3``.
            is_mtp (bool): Force the MoE FFN regardless of the layer index.
        """
        super().__init__()
        model_cfg = fd_config.model_config
        layer_id = int(prefix.rsplit(".", 1)[-1])

        self.self_attn = Glm4MoeAttention(
            fd_config=fd_config,
            layer_id=layer_id,
            prefix=f"{prefix}.self_attn",
        )

        use_moe = model_cfg.n_routed_experts is not None and (
            layer_id >= model_cfg.first_k_dense_replace or is_mtp
        )
        if use_moe:
            self.mlp = Glm4Moe(fd_config, layer_id, prefix=f"{prefix}.mlp")
        else:
            self.mlp = Glm4MoeMLP(
                fd_config,
                intermediate_size=model_cfg.intermediate_size,
                layer_id=layer_id,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(
            fd_config,
            hidden_size=model_cfg.hidden_size,
            eps=model_cfg.rms_norm_eps,
            prefix=f"{prefix}.input_layernorm",
            layer_id=layer_id,
        )
        self.post_attention_layernorm = RMSNorm(
            fd_config,
            hidden_size=model_cfg.hidden_size,
            eps=model_cfg.rms_norm_eps,
            prefix=f"{prefix}.post_attention_layernorm",
            layer_id=layer_id,
        )

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
        residual: paddle.Tensor = None,
    ):
        """Run the block; returns ``(hidden_states, residual)`` for the next layer.

        Each RMSNorm takes the running residual and returns the updated one
        alongside its normalized output.
        """
        proxy_rmsnorm = None
        if fastdeploy.envs.FD_USE_PHI_RMSNORM:
            proxy_rmsnorm = rms_norm_func

        # Attention sub-block.
        normed, residual = self.input_layernorm(
            hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm
        )
        attn_out = self.self_attn(hidden_states=normed, forward_meta=forward_meta)

        # Fully connected sub-block.
        normed, residual = self.post_attention_layernorm(attn_out, residual, proxy_rmsnorm=proxy_rmsnorm)
        return self.mlp(normed, forward_meta), residual
@support_graph_optimization
class Glm4MoeModel(nn.Layer):
    """
    GLM-4 MoE backbone: vocab-parallel token embedding, ``num_hidden_layers``
    decoder layers and a final RMSNorm.
    """

    def __init__(
        self,
        fd_config: FDConfig = None,
    ):
        """
        Initializer for the Glm4MoeModel class.

        Args:
            fd_config (FDConfig): Configurations for the LLM model. Note that this
                sets ``fd_config.model_config.pretrained_config.prefix_name`` to
                ``"model"``, which is then used to build every sublayer prefix.
        """
        super().__init__()
        self.num_layers = fd_config.model_config.num_hidden_layers
        fd_config.model_config.pretrained_config.prefix_name = "model"
        prefix_name = fd_config.model_config.pretrained_config.prefix_name

        self.embed_tokens = VocabParallelEmbedding(
            fd_config,
            num_embeddings=fd_config.model_config.vocab_size,
            embedding_dim=fd_config.model_config.hidden_size,
            # Pass the default dtype itself, not the getter function object.
            params_dtype=paddle.get_default_dtype(),
            prefix=f"{prefix_name}.embed_tokens",
        )

        self.layers = nn.LayerList(
            [
                Glm4MoeDecoderLayer(
                    fd_config,
                    prefix=f"{prefix_name}.layers.{i}",
                )
                for i in range(self.num_layers)
            ]
        )

        self.norm = RMSNorm(
            fd_config,
            hidden_size=fd_config.model_config.hidden_size,
            eps=fd_config.model_config.rms_norm_eps,
            prefix=f"{prefix_name}.norm",
        )

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
    ):
        """
        Embed the tokens and run them through all decoder layers and the final norm.

        Args:
            ids_remove_padding (paddle.Tensor): Token ids with padding removed.
            forward_meta (ForwardMeta): Per-step metadata passed to every layer.

        Returns:
            paddle.Tensor: Final normalized hidden states. If the final norm is the
            last norm and sequence-parallel MoE is enabled, the output is passed
            through ``self.norm.allgather`` first.
        """
        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)

        residual = None
        for i in range(self.num_layers):
            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)

        # The norm returns (normed, residual); only the normalized tensor is needed.
        out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]

        if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
            out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])

        return out
@ModelRegistry.register_model_class(
    architecture="Glm4MoeForCausalLM",
    module_name="glm4_moe",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
class Glm4MoeForCausalLM(ModelForCasualLM):
    """
    Glm4MoeForCausalLM: the GLM-4 MoE backbone (``Glm4MoeModel``) plus a
    parallel LM head.
    """

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super().__init__(fd_config)
        self.model = Glm4MoeModel(fd_config)
        # Logits for vocabulary ids >= ori_vocab_size are masked in compute_logits.
        self.ori_vocab_size = fd_config.model_config.ori_vocab_size
        self.lm_head = ParallelLMHead(
            fd_config,
            embedding_dim=fd_config.model_config.hidden_size,
            num_embeddings=fd_config.model_config.vocab_size,
            prefix="lm_head",
        )

    @classmethod
    def name(cls):
        """Return the architecture name of this model."""
        return "Glm4MoeForCausalLM"

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """
        Load model parameters from a given weights_iterator object.

        Each checkpoint tensor name is resolved in three stages; the first match wins:

        1. ``stacked_params_mapping``: q/k/v into the fused ``qkv_proj``,
           gate/up into the fused ``up_gate_proj``, plus a few renamed tensors.
           Routed-expert tensors (``mlp.experts``) are skipped in this stage.
        2. ``expert_params_mapping``: per-expert weights of the fused MoE.
        3. Otherwise the checkpoint name is used unchanged.

        Tensors that match no model parameter are skipped. After each tensor is
        loaded, ``process_weights_after_loading`` runs for the owning sublayer.

        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """
        from fastdeploy.model_executor.utils import (
            default_weight_loader,
            process_weights_after_loading,
        )

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("up_gate_proj", "gate_proj", "gate"),
            ("up_gate_proj", "up_proj", "up"),
            ("embed_tokens.embeddings", "embed_tokens", None),
            ("lm_head.linear", "lm_head", None),
            ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
        ]
        if self.fd_config.model_config.use_qk_norm:
            stacked_params_mapping.append(("qk_norm.q_norm", "q_norm", None))
            stacked_params_mapping.append(("qk_norm.k_norm", "k_norm", None))

        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            num_experts=self.fd_config.model_config.n_routed_experts,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
        )

        params_dict = dict(self.named_parameters())
        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
        # The fallback loader depends only on the config, so build it once
        # instead of once per checkpoint tensor.
        fallback_loader = default_weight_loader(self.fd_config)

        for loaded_weight_name, loaded_weight in weights_iterator:
            logger.debug(f"Loading weight: {loaded_weight_name}")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                # Routed-expert tensors are handled by expert_params_mapping below.
                if "mlp.experts" in loaded_weight_name:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                if model_param_name not in params_dict:
                    continue
                param = params_dict[model_param_name]
                weight_loader = getattr(param, "weight_loader", fallback_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                    if weight_name not in loaded_weight_name:
                        continue
                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
                    if model_param_name not in params_dict:
                        continue
                    param = params_dict[model_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
                    break
                else:
                    model_param_name = loaded_weight_name
                    if model_param_name not in params_dict:
                        # Not a parameter of this model: skip the tensor entirely.
                        continue
                    param = params_dict[model_param_name]
                    weight_loader = getattr(param, "weight_loader", fallback_loader)
                    weight_loader(param, loaded_weight)

            # Strip the parameter suffix to get the name of the owning sublayer.
            model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
            process_weights_after_loading_fn(model_sublayer_name, param)

    @paddle.no_grad()
    def set_state_dict(self, state_dict):
        """
        glm4_moe only support loader_v1.

        Raises:
            AssertionError: Always. Raised explicitly (not via ``assert``) so the
                guard still fires when Python runs with ``-O``.
        """
        raise AssertionError("glm4_moe only support --load-choices default_v1.")

    def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
        """
        Project hidden states to float32 vocabulary logits.

        Columns at index ``>= ori_vocab_size`` are set to ``-inf`` so they can
        never be selected.
        """
        logits = self.lm_head(hidden_states)
        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")
        return logits

    def empty_input_forward(self, forward_meta):
        """
        empty_input_forward: run the routed experts of every MoE layer on an
        empty ``[0, hidden_size]`` batch.

        NOTE(review): presumably lets a rank with no tokens still take part in
        the MoE layers' collective communication — confirm with callers.
        """
        fake_hidden_states = paddle.empty(
            shape=[0, self.fd_config.model_config.hidden_size],
            dtype=paddle.get_default_dtype(),
        )
        for i in range(
            self.fd_config.model_config.first_k_dense_replace,
            self.fd_config.model_config.num_hidden_layers,
        ):
            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)

    def forward(
        self,
        inputs: Dict,
        forward_meta: ForwardMeta,
    ):
        """Run the backbone on ``inputs["ids_remove_padding"]`` and return hidden states."""
        ids_remove_padding = inputs["ids_remove_padding"]
        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
        return hidden_states

    def clear_grpah_opt_backend(self):
        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
class Glm4MoePretrainedModel(PretrainedModel):
    """
    Glm4MoePretrainedModel: ``PretrainedModel`` wrapper for GLM-4 MoE that
    supplies the tensor-parallel split/merge action for each checkpoint tensor.
    """

    config_class = FDConfig

    def _init_weight(self, layer):
        """
        _init_weight: intentionally a no-op; returns None.
        """
        return None

    @classmethod
    def arch_name(cls):
        """Return the architecture name handled by this class."""
        return "Glm4MoeForCausalLM"

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        """
        Build the tensor-name -> split/merge action table for tensor parallelism.

        Args:
            config: Model config. Reads the TP size/rank, attention head counts,
                ``head_dim``, ``n_routed_experts`` and ``num_hidden_layers``.
            is_split (bool): Forwarded to ``split_or_merge_func_v1``; True to split
                full tensors across ranks, False to merge shards back.

        Returns:
            dict: Tensor name (e.g. ``layers.3.mlp.down_proj.weight``) -> callable.
        """
        logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")

        from fastdeploy.model_executor.models.tp_utils import split_or_merge_func_v1

        fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_model_parallel_size=config.tensor_model_parallel_size,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )

        def get_tensor_parallel_split_mappings(num_layers):
            final_actions = {}

            base_actions = {
                "lm_head.weight": partial(fn, is_column=True),
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
            }

            # Self Attention Layer which are need TP.
            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

            # MLP Layer
            base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)

            # Moe Layer
            for expert_idx in range(config.n_routed_experts):
                base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)

            # Shared Expert Layer
            base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)

            # MTP parts. The MTP layer follows the last decoder layer, so its index
            # is ``num_layers``. This was previously hard-coded to 46, which only
            # matches a 46-layer model (GLM-4.5-Air).
            # NOTE(review): confirm the index for every GLM-4 MoE checkpoint variant.
            mtp_prefix = f"layers.{num_layers}."
            base_actions[f"{mtp_prefix}embed_tokens.weight"] = partial(fn, is_column=False)
            base_actions[f"{mtp_prefix}eh_proj.weight"] = partial(fn, is_column=True)
            base_actions[f"{mtp_prefix}shared_head.head.weight"] = partial(fn, is_column=True)

            # Expand every "layers.0." template to all decoder layers.
            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                final_actions[key] = action

            return final_actions

        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
        return mappings