-
Notifications
You must be signed in to change notification settings - Fork 687
Expand file tree
/
Copy pathmodule.py
More file actions
635 lines (510 loc) · 24.8 KB
/
module.py
File metadata and controls
635 lines (510 loc) · 24.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from functools import partial
import torch
from .parameter import get_params
from .source_model.base import BaseReader
from .target_model.base import BaseOutputModel
def permute_v2(x: torch.Tensor, size_per_head: int = 128):
    """Interleave the two halves of every head into TurboMind's rotary layout.

    The last dimension of ``x`` is treated as ``head_num * size_per_head``.
    Within each head the halves ``[a0 .. a{h/2-1} | b0 .. b{h/2-1}]`` are
    rearranged into pairs ``[a0, b0, a1, b1, ...]``. The returned tensor has
    the same shape as ``x``.

    Contract: x.size(-1) is output dims
    """
    assert x.size(-1) > 1
    half = size_per_head // 2
    heads = x.size(-1) // size_per_head
    # (rows, heads, 2 halves, half) -> (rows, heads, half, 2): pairs a_j with b_j
    halves = x.view(-1, heads, 2, half)
    paired = halves.transpose(2, 3)
    return paired.reshape(x.shape)
def permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int):
    """Permute only the first ``rotary_dim`` elements of each head.

    Used when partial_rotary_factor < 1.0: only the rotary portion needs
    interleaving for TurboMind's RoPE kernel layout. The leading
    ``rotary_dim`` elements of every head get the same half-interleave as
    :func:`permute_v2`. The remaining ``size_per_head - rotary_dim`` elements
    pass through unchanged. The returned tensor has the same shape as ``x``
    (1-D or 2-D, with the last dim being ``head_num * size_per_head``).
    """
    assert x.size(-1) > 1
    assert rotary_dim % 2 == 0, f'rotary_dim must be even, got {rotary_dim}'
    assert rotary_dim <= size_per_head, f'rotary_dim ({rotary_dim}) must be <= size_per_head ({size_per_head})'
    output_dims = x.size(-1)
    assert output_dims % size_per_head == 0, (f'output_dims ({output_dims}) must be divisible by '
                                              f'size_per_head ({size_per_head})')
    shape_in = x.shape
    heads = output_dims // size_per_head
    half = rotary_dim // 2

    # Work on a 2-D view so that 1-D inputs (e.g. biases / norms) share the path.
    mat = x.unsqueeze(0) if x.dim() == 1 else x
    rows = mat.size(0)
    per_head = mat.view(rows, heads, size_per_head)

    rot_part = per_head[..., :rotary_dim]
    keep_part = per_head[..., rotary_dim:]
    # [2, half] -> [half, 2] within the rotary slice of every head
    rot_part = (rot_part.view(rows, heads, 2, half)
                .transpose(2, 3)
                .contiguous()
                .view(rows, heads, rotary_dim))
    return torch.cat([rot_part, keep_part], dim=-1).reshape(shape_in)
def merge_qkv_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int):
    """Concatenate Q, K and V so that each of ``tp`` contiguous shards is ``[Q | K | V]``.

    Each input's last dim is cut into ``tp`` equal pieces. Piece ``j`` of Q, K
    and V are placed next to each other, so a later even split of the last dim
    into ``tp`` parts gives every rank its own Q/K/V slice. 1-D inputs
    (biases) come back 1-D; 2-D inputs keep their leading dim.

    Contract: x.size(-1) is output dims
    """
    two_d = q.dim() == 2

    def shard(t):
        if two_d:
            return t.view(t.size(0), tp, -1)
        return t.view(tp, -1)

    merged = torch.cat([shard(t) for t in (q, k, v)], dim=-1)
    merged = merged.view(-1, merged.size(-1) * tp)
    if q.dim() == 1:
        merged = merged.squeeze()
    return merged
def merge_qkvg_v2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, gate: torch.Tensor, tp: int):
    """Merge Q, K, V, and Gate with gate appended after V.

    Layout per tp-shard: [Q | K | V | Gate]. Each input's last dim is cut into
    ``tp`` equal pieces and piece ``j`` of every input is grouped together.
    1-D inputs (biases) come back 1-D; 2-D inputs keep their leading dim.
    """
    two_d = q.dim() == 2

    def shard(t):
        if two_d:
            return t.view(t.size(0), tp, -1)
        return t.view(tp, -1)

    merged = torch.cat([shard(t) for t in (q, k, v, gate)], dim=-1)
    merged = merged.view(-1, merged.size(-1) * tp)
    if q.dim() == 1:
        merged = merged.squeeze()
    return merged
def transpose(x):
    """Return ``x.t()``; ``None`` is passed straight through."""
    if x is None:
        return None
    return x.t()
def pad_out_dims(x: torch.Tensor, dims: int):
    """Right-pad the last (output) dimension of ``x`` with zeros up to ``dims``.

    Asserts that ``x`` is not already wider than ``dims``.
    """
    shortfall = dims - x.size(-1)
    assert shortfall >= 0
    return torch.nn.functional.pad(x, (0, shortfall), mode='constant', value=0)
def pad_in_dims(x: torch.Tensor, dims: int):
    """Zero-pad the first (input) dimension of a 2-D tensor up to ``dims`` rows.

    1-D tensors (e.g. biases) have no input dimension and are returned as-is.
    Asserts that a non-1-D ``x`` is 2-D and not already taller than ``dims``.
    """
    if x.dim() != 1:
        shortfall = dims - x.size(0)
        assert x.dim() == 2
        assert shortfall >= 0
        x = torch.nn.functional.pad(x, (0, 0, 0, shortfall), mode='constant', value=0)
    return x
# LoRA split rules used by the exporters below:
#   split out dims -> copy A, split-out-dims B (qkv, w1, w3)
#   split in dims  -> split-in-dims A, copy B  (o, w2)
def get_lora_flags(kind: str):
    """Return ``(is_lora_a, is_lora_b)``: whether ``kind`` mentions 'lora_a' / 'lora_b'."""
    return tuple(tag in kind for tag in ('lora_a', 'lora_b'))
class Module(ABC):
    """Base class for per-layer weight exporters.

    Subclasses implement :meth:`apply`; calling an instance forwards all
    arguments to it.
    """

    def __init__(self, model: BaseOutputModel):
        # output model that subclasses use to save tensors
        self.model = model

    def __call__(self, *args, **kwargs):
        return self.apply(*args, **kwargs)

    @abstractmethod
    def apply(self, idx: int, r: BaseReader):
        """Export layer ``idx`` using tensors read from ``r``."""
class LayerNorm(Module):
    """Export the attention and FFN norm weights of a layer."""

    def apply(self, i: int, r: BaseReader):
        # Both norms are read before either one is saved.
        pairs = ((r.attn_norm(i), 'attention_norm'), (r.ffn_norm(i), 'ffn_norm'))
        for weight, name in pairs:
            self.model.save_split(weight, f'layers.{i}.{name}.weight')
class Ffn(Module):
    """Export the dense feed-forward weights (w1, w2, w3) of each layer.

    requires:
        r.ffn(i, kind)
    """

    _ffn = 'layers.{0}.feed_forward.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        super().__init__(model)
        # With expert parallelism (ep_size != 1) the split count is 1,
        # i.e. the tensors are not divided across MLP TP ranks.
        self.tp = model.mlp_tp_size if model.model_config.ep_size == 1 else 1
        # inter_sizes in config are padded and may be different from what's
        # in the weights
        self.inter_size = model.model_config.inter_size
        self.group_size = max(1, model.model_config.group_size)

    def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn, apply_gs=(), **kwargs):
        """Pad, pack and save one parameter kind of w1/w2/w3.

        Args:
            inter_size: (padded) intermediate size the tensors are padded to.
            fmt: name template formatted as ``fmt.format(idx, 'w1'|'w2'|'w3', kind)``.
            idx: layer index.
            w123: ``(w1, w2, w3)`` as returned by the reader; each is
                transposed before padding.
            kind: parameter kind string; LoRA A/B kinds are detected with
                :func:`get_lora_flags`.
            pack_fn: applied to every padded tensor before saving.
            apply_gs: names among 'w1'/'w2'/'w3' whose padded size is
                ``inter_size // group_size`` instead of ``inter_size``.
                Only membership is tested, so any container works; the
                default is an immutable tuple to avoid a shared mutable default.
        """
        is_lora_a, is_lora_b = get_lora_flags(kind)
        w1, w2, w3 = map(transpose, w123)

        def padded_size(name):
            gs = self.group_size if name in apply_gs else 1
            return inter_size // gs

        # w1/w3 are padded on their output (last) dim, w2 on its input (first) dim
        w1 = pad_out_dims(w1, padded_size('w1'))
        w3 = pad_out_dims(w3, padded_size('w3'))
        w2 = pad_in_dims(w2, padded_size('w2'))
        w1, w2, w3 = map(pack_fn, (w1, w2, w3))

        # w1/w3: split along the last dim, LoRA A copied to every rank.
        # w2: split along dim 0, LoRA B copied to every rank.
        self.model.save_split(w1, fmt.format(idx, 'w1', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)
        self.model.save_split(w3, fmt.format(idx, 'w3', kind), split_dim=-1, split_num=self.tp, copy=is_lora_a)
        self.model.save_split(w2, fmt.format(idx, 'w2', kind), split_dim=0, split_num=self.tp, copy=is_lora_b)

    def apply(self, i: int, r: BaseReader):
        """Export every parameter kind of layer ``i``'s dense FFN.

        Layers beyond the configured ``inter_size`` list, or whose entry is
        falsy (e.g. 0), have no dense FFN and are skipped.
        """
        if i >= len(self.inter_size) or not self.inter_size[i]:
            return
        keys = r.ffn(i, None)
        for e in get_params(keys):
            e(partial(self._export, self.inter_size[i], self._ffn), partial(r.ffn, i), i)
class MoeFfn(Ffn):
    """Export mixture-of-experts FFN weights: per-expert w1/w2/w3, router and shared gate.

    requires:
        r.moe_ffn_expert(e, i, kind)
        r.moe_ffn_gate(i)
        r.moe_ffn_shared_gate(i)
    """

    _moe_ffn_expert = 'layers.{0}.moe_ffn.experts.E.{1}.{2}'
    _moe_ffn_gate = 'layers.{0}.moe_ffn.gate.{1}'
    _moe_ffn_shared_gate = 'layers.{0}.moe_ffn.shared_gate.weight'

    def __init__(self, model: BaseOutputModel):
        super().__init__(model)
        cfg = model.model_config
        self.expert_num = cfg.expert_num
        # experts use a single inter size (overrides the per-layer list set by Ffn)
        self.inter_size = cfg.expert_inter_size
        self.shared_gate = cfg.moe_shared_gate

    def apply(self, i: int, r: BaseReader):
        if i >= len(self.expert_num) or self.expert_num[i] == 0:
            return

        # Experts form the outer loop (params the inner one) so each expert's
        # full weight set is exported together.
        for expert in range(self.expert_num[i]):
            # substitute the literal 'E' placeholder with the expert index
            name_fmt = self._moe_ffn_expert.replace('E', str(expert))
            for param in get_params(r.moe_ffn_expert(), 1):
                param(partial(self._export, self.inter_size, name_fmt), partial(r.moe_ffn_expert, expert, i), i)

        # router weight (transposed) and optional bias
        router_w = transpose(r.moe_ffn_gate(i, 'weight'))
        self.model.save_split(router_w, self._moe_ffn_gate.format(i, 'weight'))
        router_b = r.moe_ffn_gate(i, 'bias')
        if router_b is not None:
            self.model.save_split(router_b, self._moe_ffn_gate.format(i, 'bias'))

        # score_correction_bias for noaux_tc routing (GLM 4.7 Flash); only
        # readers that define moe_ffn_gate_correction_bias provide it.
        read_correction = getattr(r, 'moe_ffn_gate_correction_bias', None)
        if callable(read_correction):
            correction = read_correction(i)
            if correction is not None:
                self.model.save_split(correction, self._moe_ffn_gate.format(i, 'score_correction_bias'))

        if self.shared_gate:
            self.model.save_split(transpose(r.moe_ffn_shared_gate(i)), self._moe_ffn_shared_gate.format(i))
class Attn(Module):
    """Export standard (MHA/GQA) attention weights, qk-norms and attention sinks.

    requires:
        r.attn(i, kind)
    """

    _attn = 'layers.{0}.attention.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        self.model = model
        self.tp = model.attn_tp_size
        self.head_dim = model.model_config.size_per_head
        self.attn_bias = model.model_config.attn_bias
        self.qk_norm = model.model_config.qk_norm
        self.attn_sink = model.model_config.attn_sink
        self.group_size = max(1, model.model_config.group_size)
        self.attn_output_gate = model.model_config.attn_output_gate
        rope_param = model.attention_config.rope_param
        # without rope params the whole head is treated as rotary
        self.rope_dim = rope_param.dim if rope_param else self.head_dim
        self.head_num = model.model_config.head_num

    def _split_q_gate(self, q):
        """Split interleaved Q+gate tensor into separate Q and gate.

        HF layout: [Q_head0, Gate_head0, Q_head1, Gate_head1, ...]
        Returns: (q_real, gate) each with shape [..., num_heads * head_dim]
        """
        output_dims = q.size(-1)
        head_num = output_dims // (self.head_dim * 2)
        orig_shape = list(q.shape)
        if q.dim() == 1:
            q = q.unsqueeze(0)
        q = q.view(q.size(0), head_num, 2, self.head_dim)
        q_real = q[:, :, 0, :].contiguous()
        gate = q[:, :, 1, :].contiguous()
        new_last_dim = head_num * self.head_dim
        q_real = q_real.reshape(-1, new_last_dim)
        gate = gate.reshape(-1, new_last_dim)
        if len(orig_shape) == 1:
            q_real = q_real.squeeze(0)
            gate = gate.squeeze(0)
        return q_real, gate

    def _permute_qk(self, q, k):
        """Reorder q/k into TurboMind's rotary layout.

        Only the first ``rope_dim`` elements of each head are permuted when
        ``rope_dim < head_dim`` (partial rotary); otherwise whole heads are.
        """
        if self.rope_dim < self.head_dim:
            return (permute_v2_partial(q, self.head_dim, self.rope_dim),
                    permute_v2_partial(k, self.head_dim, self.rope_dim))
        return permute_v2(q, self.head_dim), permute_v2(k, self.head_dim)

    def _reorder_and_merge(self, qkvo, gs: int):
        """Permute q/k for rotary, merge q/k/v(/gate) and fill a missing ``wo`` bias.

        Returns ``(qkv, o)``.
        """
        q, k, v, o = qkvo
        gate = None
        # When attn_output_gate, Q is interleaved [Q0, G0, Q1, G1, ...]
        # Split into separate Q and gate before permuting
        if self.attn_output_gate and q is not None:
            q, gate = self._split_q_gate(q)
        # reorder output dim for tm's rotary embedding layout
        if self.model.permute_qk:
            if gs == 1:
                q, k = self._permute_qk(q, k)
            else:
                assert gs % self.head_dim == 0
        # Merge QKV with gate appended at end if present
        if gate is not None:
            qkv = merge_qkvg_v2(q, k, v, gate, self.tp)
        else:
            qkv = merge_qkv_v2(q, k, v, self.tp)
        # Zero bias for `wo` when `w_qkv` has bias but `wo` doesn't. `wo`
        # projects back to hidden_units, so the bias has that size (same as
        # the `_repeat_kv` path); zeros_like(q) is only right when
        # head_num * head_dim == hidden_units.
        if o is None and q.dim() == 1:
            hidden_dim = self.model.model_config.hidden_units
            o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
        return qkv, o

    def _repeat_kv(self, qkvo, gs: int, kind: str):
        """Replicate each kv head ``model.repeat_kv`` times (adjacent copies)."""
        q, k, v, o = qkvo
        head_dim = self.model.model_config.size_per_head // gs
        # config kv_head_num already includes the replication factor
        kv_head_num = self.model.model_config.kv_head_num // self.model.repeat_kv
        hidden_dim = self.model.model_config.hidden_units

        def _repeat(x):
            n = self.model.repeat_kv
            x = x.reshape(-1, kv_head_num, head_dim)
            x = x.repeat(1, 1, n)
            x = x.reshape(-1, kv_head_num * n * head_dim)
            return x

        k, v = map(_repeat, (k, v))
        if kind == 'bias':
            if o is None:
                o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
            # _repeat made k/v 2-D; biases must be 1-D again
            q, k, v, o = map(torch.squeeze, (q, k, v, o))
        return (q, k, v, o)

    def _export(self, idx: int, qkvo, kind: str, pack_fn, apply_gs=(), **kwargs):
        """Save one parameter kind of attention as merged ``w_qkv`` and ``wo``.

        Args:
            idx: layer index.
            qkvo: ``(q, k, v, o)`` from the reader; transposed before use.
            kind: parameter kind; LoRA kinds are not supported here.
            pack_fn: applied to each tensor before saving.
            apply_gs: only membership of 'w1' is tested, to decide whether the
                group size applies. Immutable default avoids a shared mutable
                default argument. NOTE(review): presumably get_params uses
                'w1' as the generic group-size tag here — confirm.
        """
        if all(x is None for x in qkvo):
            return
        is_lora_a, is_lora_b = get_lora_flags(kind)
        assert not (is_lora_a or is_lora_b)
        qkvo = tuple(map(transpose, qkvo))
        gs = self.group_size if ('w1' in apply_gs) else 1
        if self.model.repeat_kv:
            qkvo = self._repeat_kv(qkvo, gs, kind)
        qkv, o = self._reorder_and_merge(qkvo, gs)
        self.model.save_split(pack_fn(qkv),
                              self._attn.format(idx, 'w_qkv', kind),
                              split_dim=-1,
                              split_num=self.tp,
                              copy=is_lora_a)
        self.model.save_split(pack_fn(o),
                              self._attn.format(idx, 'wo', kind),
                              split_dim=0,
                              split_num=self.tp,
                              copy=is_lora_b)

    def apply(self, i: int, r: BaseReader):
        """Export attention projections, optional qk-norms and optional sinks of layer ``i``."""
        for e in get_params(r.attn(i, None), bias=self.attn_bias):
            e(self._export, partial(r.attn, i), i)
        if self.qk_norm:
            q, k = r.qk_norm(i)
            if q is not None and k is not None:
                # norm weights follow the same rotary reordering as q/k
                if self.model.permute_qk:
                    q, k = self._permute_qk(q, k)
                # format(..., '')[:-1] drops the trailing '.' of the empty kind
                self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1])
                self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1])
        if self.attn_sink:
            sinks = r.attn_sinks(i)
            self.model.save_split(sinks, self._attn.format(i, 'sinks', '')[:-1], split_dim=-1, split_num=self.tp)
class MLA(Module):
    """Export multi-head latent attention (MLA) projections and their norms.

    requires:
        r.mla(i, kind)
        r.mla_norm(i)
    """

    # Name template formatted with (layer index, tensor name, kind).
    _mla = 'layers.{0}.attention.{1}.{2}'

    def __init__(self, model: BaseOutputModel):
        self.model = model

    def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs):
        """Optionally fold kv_b into q_b/o, then transpose, pad and save MLA tensors.

        Args:
            idx: layer index.
            xs: ``(q_a, q_b, q, kv_a, kv_b, o)`` as returned by ``r.mla``; any
                entry may be None. When ``q`` is present it replaces ``q_b``
                and is saved as ``q_proj`` (used when ``q_a`` is None).
            kind: parameter kind (e.g. 'weight'); folding only runs for 'weight'.
            pack_fn: applied to each tensor before saving.

        Raises:
            ValueError: when folding is attempted on non-floating-point
                weights, or the inferred/target dimensions are inconsistent.
        """
        if all(x is None for x in xs):
            return
        q_a, q_b, q, kv_a, kv_b, o = xs

        cfg = self.model.model_config
        head_num = cfg.head_num
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_dim = cfg.qk_rope_dim
        size_per_head = cfg.size_per_head
        v_head_dim = cfg.v_head_dim

        # ========== MLA Weight Folding for Dimension Mismatch ==========
        # The checkpoint's per-head nope / v dims are inferred from q_b, kv_b
        # and o below (dim 0 treated as output features, dim 1 as input, since
        # the transpose happens only after this section). If they differ from
        # the target layout (nope = size_per_head - qk_rope_dim, v = v_head_dim),
        # both target dims must equal kv_lora_rank. The kc half of kv_b_proj is
        # then folded into q_b_proj and the vc half into o_proj at conversion
        # time, to avoid runtime overhead. The original note cites GLM 4.7 Flash
        # as a model that needs this.
        if kind == 'weight' and kv_lora_rank and q is None and q_b is not None and kv_b is not None and o is not None:
            if not (torch.is_floating_point(q_b) and torch.is_floating_point(kv_b) and torch.is_floating_point(o)):
                raise ValueError('MLA weight folding requires floating-point attention weights.')
            orig_q_head_dim = q_b.size(0) // head_num
            orig_qk_nope_dim = orig_q_head_dim - qk_rope_dim
            orig_kv_dim_total = kv_b.size(0) // head_num
            orig_v_head_dim = o.size(1) // head_num
            actual_orig_qk_nope_dim = orig_kv_dim_total - orig_v_head_dim
            # Tolerate an off-by-one between the two inferences; the
            # kv_b-derived value is used from here on.
            if abs(orig_qk_nope_dim - actual_orig_qk_nope_dim) > 1:
                raise ValueError(f'Dimension mismatch: inferred qk_nope from q_b ({orig_qk_nope_dim}) != '
                                 f'inferred from kv_b ({actual_orig_qk_nope_dim})')
            orig_qk_nope_dim = actual_orig_qk_nope_dim
            target_nope_dim = size_per_head - qk_rope_dim
            target_v_head_dim = v_head_dim
            if orig_qk_nope_dim != target_nope_dim or orig_v_head_dim != target_v_head_dim:
                if target_nope_dim != kv_lora_rank or target_v_head_dim != kv_lora_rank:
                    raise ValueError(f'MLA folding expects v_head_dim and nope_dim to equal kv_lora_rank, '
                                     f'got nope={target_nope_dim}, v_head={target_v_head_dim}, rank={kv_lora_rank}')
                if kv_b.size(1) != kv_lora_rank:
                    raise ValueError(f'kv_b_proj second dim must equal kv_lora_rank for MLA folding, '
                                     f'got {kv_b.size(1)} != {kv_lora_rank}')
                # Split kv_b into kc and vc: per head, (nope + v, rank) -> kc (nope, rank), vc (v, rank)
                kv_b_per_head = kv_b.reshape(head_num, orig_qk_nope_dim + orig_v_head_dim, kv_lora_rank)
                kc_w = kv_b_per_head[:, :orig_qk_nope_dim, :]
                vc_w = kv_b_per_head[:, orig_qk_nope_dim:, :]
                # Fold kc into q_b_proj: per head, (rank, nope) @ (nope, in) -> (rank, in);
                # the rope rows are appended unchanged, giving rank + rope rows per head.
                q_b_per_head = q_b.reshape(head_num, orig_q_head_dim, q_b.size(1))
                q_nope_w = q_b_per_head[:, :orig_qk_nope_dim, :]
                q_rope_w = q_b_per_head[:, orig_qk_nope_dim:, :]
                q_nope_expanded = torch.bmm(kc_w.transpose(1, 2), q_nope_w)
                q_b_folded = torch.cat([q_nope_expanded, q_rope_w], dim=1)
                q_b = q_b_folded.reshape(head_num * size_per_head, q_b.size(1))
                # Fold vc into o_proj: per head, (out, v) @ (v, rank) -> (out, rank)
                o_per_head = o.reshape(o.size(0), head_num, orig_v_head_dim)
                o_folded = torch.bmm(o_per_head.permute(1, 0, 2), vc_w)
                o = o_folded.permute(1, 0, 2).reshape(o.size(0), head_num * kv_lora_rank)
                # Set kv_b to identity (kc/vc are now absorbed). NOTE: kv_b is
                # not saved below (its export is commented out).
                eye = torch.eye(kv_lora_rank, dtype=kv_b.dtype, device=kv_b.device)
                kv_b = torch.cat([eye, eye], dim=0).repeat(head_num, 1)
        # ========== End MLA Weight Folding ==========

        # Transpose after folding
        q_a, q_b, q, kv_a, kv_b, o = map(transpose, (q_a, q_b, q, kv_a, kv_b, o))
        # A full q projection (no q_a/q_b split) takes q_b's place
        if q is not None:
            q_b = q
        # Pad o_proj to size_per_head if present. The pad tuple zero-fills the
        # per-head v rows on the LEADING side (size_per_head - v_head_dim rows
        # before the v rows). NOTE(review): presumably this matches TurboMind's
        # MLA value layout — confirm.
        if o is not None:
            o = o.reshape(head_num, v_head_dim, -1)
            o = torch.nn.functional.pad(o, (0, 0, size_per_head - v_head_dim, 0, 0, 0))
            o = o.view(head_num * size_per_head, cfg.hidden_units)
        tp = self.model.attn_tp_size
        # Export MLA weights (handle None for folded-away tensors).
        # q_a / kv_a are saved without a TP split; q_b is split on the last
        # dim and wo on dim 0.
        if q_a is not None:
            self.model.save_split(pack_fn(q_a), self._mla.format(idx, 'q_a_proj', kind))
        q_b_name = 'q_proj' if q_a is None else 'q_b_proj'
        if q_b is not None:
            self.model.save_split(pack_fn(q_b), self._mla.format(idx, q_b_name, kind), split_dim=-1, split_num=tp)
        if kv_a is not None:
            self.model.save_split(pack_fn(kv_a), self._mla.format(idx, 'kv_a_proj', kind))
        # if kv_b is not None:
        #     self.model.save_split(pack_fn(kv_b), self._mla.format(idx, 'kv_b_proj', kind), split_dim=-1, split_num=tp)
        if o is not None:
            self.model.save_split(pack_fn(o), self._mla.format(idx, 'wo', kind), split_dim=0, split_num=tp)

    # Name template formatted with (layer index, 'q' | 'kv').
    _layernorm = 'layers.{0}.attention.{1}_a_layernorm'

    def apply(self, i: int, r: BaseReader):
        """Export MLA projections (keys from ``r.attn``, tensors from ``r.mla``) and norms of layer ``i``."""
        for f in get_params(r.attn(i, None), bias=False):
            f(self._export, partial(r.mla, i), i)
        q, k = r.mla_norm(i)
        # the q norm is optional; the kv norm is always saved
        if q is not None:
            self.model.save_split(q, self._layernorm.format(i, 'q'))
        self.model.save_split(k, self._layernorm.format(i, 'kv'))
class LinearAttn(Module):
    """Export linear-attention layers (``layer_types[i] == 'linear_attention'``)."""

    _linear_attn = 'layers.{0}.linear_attn.{1}.{2}'
    # Order in which r.linear_attn(i, kind) is assumed to return its tensors.
    _names = ('conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias')

    def __init__(self, model: BaseOutputModel):
        self.model = model
        self.tp = model.attn_tp_size
        cfg = model.model_config
        self.key_dim = cfg.linear_num_key_heads * cfg.linear_key_head_dim
        self.value_dim = cfg.linear_num_value_heads * cfg.linear_value_head_dim

    def _tp_interleave_qkv(self, tensor, dim):
        """Split a concatenated [Q, K, V] tensor into components, reshape each
        for TP interleaving, and re-concatenate.

        in_proj_qkv layout along ``dim``: Q(key_dim) | K(key_dim) | V(value_dim).
        A naive split doesn't respect component boundaries when key_dim and
        value_dim differ. This method splits Q/K/V, reshapes each to
        ``(tp, -1)`` along ``dim``, concatenates per-TP-shard, then flattens
        so that a subsequent ``save_split(split_dim=dim)`` gives each rank the
        correct portion.
        """
        if dim < 0:
            dim = tensor.dim() + dim
        q, k, v = torch.split(tensor, [self.key_dim, self.key_dim, self.value_dim], dim=dim)

        def reshape(x):
            # Move TP axis to a new dimension right after ``dim``
            shape = list(x.shape)
            d = shape[dim]
            new_shape = shape[:dim] + [self.tp, d // self.tp] + shape[dim + 1:]
            return x.view(new_shape)

        parts = torch.cat([reshape(q), reshape(k), reshape(v)], dim=dim + 1)
        # Collapse tp and per-shard dims back
        shape = list(parts.shape)
        final_shape = shape[:dim] + [shape[dim] * shape[dim + 1]] + shape[dim + 2:]
        return parts.reshape(final_shape)

    def _export_tensor(self, i: int, name: str, kind: str, tensor):
        """Reorder (where needed) and save one linear-attention tensor with its TP split rule."""
        key = self._linear_attn.format(i, name, kind)
        if name == 'conv1d':
            # conv1d weight: (conv_dim, 1, d_conv) with conv_dim = key_dim*2 +
            # value_dim. Interleave Q/K/V channels along dim 0 before the TP split.
            tensor = self._tp_interleave_qkv(tensor, dim=0)
            # A 1-D conv tensor (bias) must also be split per channel; dim -1
            # is the same axis but avoids the 1-D copy shortcut in save_split
            # (same reason as A_log / dt_bias below).
            split_dim = 0 if tensor.dim() > 1 else -1
            self.model.save_split(tensor, key, split_dim=split_dim, split_num=self.tp)
        elif name in ('A_log', 'dt_bias'):
            # Split per-head params across TP ranks (use -1 to
            # avoid the 1-D copy shortcut in save_split).
            self.model.save_split(tensor, key, split_dim=-1, split_num=self.tp)
        elif name == 'out_proj':
            self.model.save_split(transpose(tensor), key, split_dim=0, split_num=self.tp)
        elif name == 'in_proj_qkv':
            # in_proj_qkv: (conv_dim, hidden). After transpose the QKV
            # components lie along dim -1; interleave for TP so each shard
            # gets the correct Q/K/V slice.
            t = self._tp_interleave_qkv(transpose(tensor), dim=-1)
            self.model.save_split(t, key, split_dim=-1, split_num=self.tp)
        else:
            # in_proj_z / in_proj_b / in_proj_a
            self.model.save_split(transpose(tensor), key, split_dim=-1, split_num=self.tp)

    def apply(self, i: int, r: BaseReader):
        """Export the linear-attention weights/biases and output norm of layer ``i``.

        Layers that are not ``'linear_attention'`` (or lie beyond
        ``layer_types``) are skipped.
        """
        layer_types = getattr(self.model.model_config, 'layer_types', [])
        if i >= len(layer_types) or layer_types[i] != 'linear_attention':
            return
        for kind in ['weight', 'bias']:
            weights = r.linear_attn(i, kind)
            if not weights:
                continue
            for name, tensor in zip(self._names, weights):
                if tensor is not None:
                    self._export_tensor(i, name, kind, tensor)
        norm = r.linear_norm(i, 'weight')
        if norm is not None:
            self.model.export_weight(norm, f'layers.{i}.linear_attn.norm.weight')
class Misc(Module):
    """Export the non-layer tensors: token embedding, final norm and output head.

    requires:
        r.tok_embeddings()
        r.norm_weight()
        r.output_weight()
    """

    def apply(self, i: int, r: BaseReader):
        """Export embedding, norm, output weight."""
        emb = r.tok_embeddings()
        norm_weight = r.norm_weight()
        output_weight = r.output_weight()

        def pad_vocab(tensor: torch.Tensor, tp: int):
            # Zero-pad dim 0 (vocab rows) up to the next multiple of tp.
            vocab_size = self.model.model_config.vocab_size
            remainder = vocab_size % tp
            if remainder == 0:
                return tensor
            return torch.nn.functional.pad(tensor, (0, 0, 0, tp - remainder), 'constant', 0)

        tp = self.model.attn_tp_size * self.model.attn_cp_size
        if emb is not None:
            self.model.save_split(pad_vocab(emb, tp=tp), 'tok_embeddings.weight', split_dim=1, split_num=tp)
        if norm_weight is not None:
            self.model.export_weight(norm_weight, 'norm.weight')
        if output_weight is not None:
            # saved transposed, so split_dim=1 divides the padded vocab dim
            padded = pad_vocab(output_weight, tp=tp)
            self.model.save_split(padded.t(), 'output.weight', split_dim=1, split_num=tp)
class Transformer:
    """Run the configured per-layer exporters, or the Misc exporter for negative indices."""

    def __init__(self, model: BaseOutputModel):
        self.model = model
        cfg = model.model_config
        # MLA replaces standard attention when kv_lora_rank is set
        exporters = [LayerNorm, MLA if cfg.kv_lora_rank else Attn]
        if getattr(cfg, 'layer_types', []):
            exporters.append(LinearAttn)
        if cfg.inter_size:
            exporters.append(Ffn)
        if cfg.expert_num:
            exporters.append(MoeFfn)
        self.modules = [cls(model) for cls in exporters]
        self.misc = Misc(model)

    def __call__(self, i: int, r: BaseReader):
        """Export layer ``i`` and return 1; a negative ``i`` exports Misc and returns None."""
        if i < 0:
            self.misc(i, r)
            return None
        for exporter in self.modules:
            exporter(i, r)
        return 1