diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index db9738e6e..97452cabd 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -920,34 +920,60 @@ def process_weights_after_loading(self, layer): layer.w2_weight_scale = None return - # shuffle weight + E13, N13, K13 = layer.w13_weight_scale.data.shape layer.w13_weight.data = shuffle_weight( - layer.w13_weight, - is_guinterleave=self.is_guinterleave, - gate_up=True, + layer.w13_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=True ) - layer.w2_weight.data = shuffle_weight( - layer.w2_weight, - is_guinterleave=self.is_guinterleave, - gate_up=False, - ) - layer.w13_weight.is_shuffled = True - layer.w2_weight.is_shuffled = True - - # shuffle scale - w13_scale_2d = layer.w13_weight_scale.reshape( - -1, layer.w13_weight_scale.shape[-1] + layer.w13_weight_scale = atom_parameter( + shuffle_scale( + layer.w13_weight_scale.data.reshape(E13 * N13, K13), + experts_cnt=E13, is_guinterleave=True, gate_up=True, + ).reshape(E13, N13, K13) ) - w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1]) - shuffled_w13_scale = shuffle_scale( - w13_scale_2d, self.num_experts, self.is_guinterleave, True + E2, N2, K2 = layer.w2_weight_scale.data.shape + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.data, layout=(16, 16), is_guinterleave=True, gate_up=False ) - shuffled_w2_scale = shuffle_scale( - w2_scale_2d, self.num_experts, self.is_guinterleave, False + layer.w2_weight_scale = atom_parameter( + shuffle_scale( + layer.w2_weight_scale.data.reshape(E2 * N2, K2), + experts_cnt=E2, is_guinterleave=True, gate_up=False, + ).reshape(E2, N2, K2) ) - layer.w13_weight_scale = atom_parameter(shuffled_w13_scale) - layer.w2_weight_scale = atom_parameter(shuffled_w2_scale) + + layer.w13_weight.is_shuffled = True + layer.w2_weight.is_shuffled = True + layer.w13_weight.shuffle_kind = layer.w13_weight_scale.shuffle_kind = "mxfp4_moe" + layer.w2_weight.shuffle_kind = layer.w2_weight_scale.shuffle_kind = "mxfp4_moe" + + # layer.w13_weight.data = shuffle_weight( + # layer.w13_weight, + # is_guinterleave=self.is_guinterleave, + # gate_up=True, + # ) + # layer.w2_weight.data = shuffle_weight( + # layer.w2_weight, + # is_guinterleave=self.is_guinterleave, + # gate_up=False, + # ) + # layer.w13_weight.is_shuffled = True + # layer.w2_weight.is_shuffled = True + + # # shuffle scale + # w13_scale_2d = layer.w13_weight_scale.reshape( + # -1, layer.w13_weight_scale.shape[-1] + # ) + # w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1]) + + # shuffled_w13_scale = shuffle_scale( + # w13_scale_2d, self.num_experts, self.is_guinterleave, True + # ) + # shuffled_w2_scale = shuffle_scale( + # w2_scale_2d, self.num_experts, self.is_guinterleave, False + # ) + # layer.w13_weight_scale = atom_parameter(shuffled_w13_scale) + # layer.w2_weight_scale = atom_parameter(shuffled_w2_scale) def get_fused_moe_quant_config( self, layer: torch.nn.Module