Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,34 +920,60 @@ def process_weights_after_loading(self, layer):
layer.w2_weight_scale = None
return

# shuffle weight
E13, N13, K13 = layer.w13_weight_scale.data.shape
layer.w13_weight.data = shuffle_weight(
layer.w13_weight,
is_guinterleave=self.is_guinterleave,
gate_up=True,
layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True
)
Comment on lines +923 to 926
layer.w2_weight.data = shuffle_weight(
layer.w2_weight,
is_guinterleave=self.is_guinterleave,
gate_up=False,
)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True

# shuffle scale
w13_scale_2d = layer.w13_weight_scale.reshape(
-1, layer.w13_weight_scale.shape[-1]
layer.w13_weight_scale = atom_parameter(
shuffle_scale(
layer.w13_weight_scale.data.reshape(E13 * N13, K13),
experts_cnt=E13, is_guinterleave=True, gate_up=True,
).reshape(E13, N13, K13)
)
w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1])

shuffled_w13_scale = shuffle_scale(
w13_scale_2d, self.num_experts, self.is_guinterleave, True
E2, N2, K2 = layer.w2_weight_scale.data.shape
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False
)
Comment on lines +934 to 937
shuffled_w2_scale = shuffle_scale(
w2_scale_2d, self.num_experts, self.is_guinterleave, False
layer.w2_weight_scale = atom_parameter(
shuffle_scale(
layer.w2_weight_scale.data.reshape(E2 * N2, K2),
experts_cnt=E2, is_guinterleave=True, gate_up=False,
).reshape(E2, N2, K2)
)
Comment on lines +923 to 943
layer.w13_weight_scale = atom_parameter(shuffled_w13_scale)
layer.w2_weight_scale = atom_parameter(shuffled_w2_scale)

layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
layer.w13_weight.shuffle_kind = layer.w13_weight_scale.shuffle_kind = "mxfp4_moe"
layer.w2_weight.shuffle_kind = layer.w2_weight_scale.shuffle_kind = "mxfp4_moe"

# layer.w13_weight.data = shuffle_weight(
# layer.w13_weight,
# is_guinterleave=self.is_guinterleave,
# gate_up=True,
# )
# layer.w2_weight.data = shuffle_weight(
# layer.w2_weight,
# is_guinterleave=self.is_guinterleave,
# gate_up=False,
# )
# layer.w13_weight.is_shuffled = True
# layer.w2_weight.is_shuffled = True

# # shuffle scale
# w13_scale_2d = layer.w13_weight_scale.reshape(
# -1, layer.w13_weight_scale.shape[-1]
# )
# w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1])

# shuffled_w13_scale = shuffle_scale(
# w13_scale_2d, self.num_experts, self.is_guinterleave, True
# )
# shuffled_w2_scale = shuffle_scale(
# w2_scale_2d, self.num_experts, self.is_guinterleave, False
# )
# layer.w13_weight_scale = atom_parameter(shuffled_w13_scale)
# layer.w2_weight_scale = atom_parameter(shuffled_w2_scale)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
Expand Down