From cdc95cbfef885467da1627fc8ce3c38a6dbf992a Mon Sep 17 00:00:00 2001 From: SwaroopTha <80303869+SwaroopTha@users.noreply.github.com> Date: Wed, 5 Mar 2025 11:30:29 -0600 Subject: [PATCH] "Granite Functionality" --- umbrella/attn/cache.py | 6 +- umbrella/models/auto_model.py | 6 +- umbrella/models/granite.py | 180 +++++++++++++++++++++++++++++++ umbrella/models/granite_layer.py | 92 ++++++++++++++++ umbrella/templates.py | 8 +- 5 files changed, 288 insertions(+), 4 deletions(-) create mode 100644 umbrella/models/granite.py create mode 100644 umbrella/models/granite_layer.py diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 5f6ccd1..605d2ce 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -71,7 +71,8 @@ def compute_attention(self, layer_idx, storage_ids :torch.Tensor = None, attention_mask :torch.Tensor = None, - logits_soft_cap = 0): + logits_soft_cap = 0, + sm_scale=None): key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) @@ -83,7 +84,8 @@ def compute_attention(self, kv_layout="NHD", custom_mask=attention_mask[:,:self.kv_offset], allow_fp16_qk_reduction=True, - logits_soft_cap = logits_soft_cap + logits_soft_cap = logits_soft_cap, + sm_scale=sm_scale, ) else: diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 819738d..a43fee3 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -2,6 +2,7 @@ from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph from .gemma import Gemma2 from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq +from .granite import Granite class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -123,7 +124,10 @@ class AutoModelLM: "mistralai/Mistral-Small-24B-Instruct-2501": Mistral, "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq, "PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq, - "mistralai/Ministral-8B-Instruct-2410": Mistral + "mistralai/Ministral-8B-Instruct-2410": Mistral, + "ibm-granite/granite-3.2-8b-instruct-preview": Granite, + "ibm-granite/granite-3.1-8b-instruct": Granite, + } _CUDAGRAPH_MODEL_MAPPING = { diff --git a/umbrella/models/granite.py b/umbrella/models/granite.py new file mode 100644 index 0000000..dfe8eef --- /dev/null +++ b/umbrella/models/granite.py @@ -0,0 +1,180 @@ +from transformers import GraniteForCausalLM, GraniteConfig, AutoModelForCausalLM +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .granite_layer import GraniteLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +class Granite(LLMBase): + def __init__(self, + model_name: str, + batch_size :int = 1, + max_length :int = 256, + device :str = 'cuda:0', + dtype = torch.float16) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = GraniteConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + # granite specific + self.residual_multiplier = self.config.residual_multiplier + self.embedding_multiplier = self.config.embedding_multiplier + self.attention_multiplier = self.config.attention_multiplier + self.logits_scaling = self.config.logits_scaling + + def alloc(self, **kwargs): + + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = GraniteForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + + # Initialize embedding and language modeling head + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + # Prepare rotary embeddings + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + + # Compute cos and sin caches + rotary_emb = hf_model.model.rotary_emb + inv_freq = rotary_emb.inv_freq.detach().to(self.device) + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + + # Initialize norm + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + # Initialize layers + self.layers: list[GraniteLayer] = [] + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = GraniteLayer(idx) + layer.init_parameters(hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def layer_compute(self, + buffer: GraniteLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, + buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight + ) + + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + self.cos_cache, + self.sin_cache, + position_ids + ) + + hidden_states = self.kv_cache.compute_attention( + query_states.to(value_states.dtype), + key_states.to(value_states.dtype), + value_states, + layer_idx, + storage_ids, + attention_mask, + sm_scale=self.attention_multiplier + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + + hidden_states = layer_norm(hidden_states, + buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight + ) + + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) * self.embedding_multiplier + for idx in range(self.num_layers): + hidden_states = self.layer_compute( + self.layers[idx], + idx, hidden_states, + position_ids, + attention_mask, + storage_ids + ) + + + hidden_states = layer_norm(hidden_states, self.norm_variance_epsilon, self.norm_weight) + logits = F.linear(hidden_states, self.lm_head).float() / self.logits_scaling + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() \ No newline at end of file diff --git a/umbrella/models/granite_layer.py b/umbrella/models/granite_layer.py new file mode 100644 index 0000000..21b9cea --- /dev/null +++ b/umbrella/models/granite_layer.py @@ -0,0 +1,92 @@ +from __future__ import annotations +import torch +from transformers.models.granite.modeling_granite import GraniteDecoderLayer +from ..quantization.awq_utils import AwqLinear + +class GraniteLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight: torch.Tensor = None + self.input_layernorm_variance_epsilon: float = 1e-05 + + self.post_attention_layernorm_weight: torch.Tensor = None + self.post_attention_layernorm_variance_epsilon: float = 1e-05 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: GraniteDecoderLayer): + + self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + # Layer norm weights and epsilon + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: GraniteDecoderLayer): + + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + + + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + # Copy layer norm weights + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + + # Copy epsilon values + self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + + def alloc_space(self, layer: GraniteDecoderLayer, device): + + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) \ No newline at end of file diff --git a/umbrella/templates.py b/umbrella/templates.py index ac30a80..47b9f2d 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -22,7 +22,9 @@ """, 'gemma2': "{}", -'mistral': "[INST] {} [/INST]" +'mistral': "[INST] {} [/INST]", +'ibm-granite': """\n<|start_of_role|>user<|end_of_role|>{}<|end_of_text|> +<|start_of_role|>assistant<|end_of_role|>""" } @@ -39,6 +41,10 @@ 'gemma2': "", 'gemma2-it': "", 'mistral': "", + 'ibm-granite': """<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024. +Today's Date: March 05, 2025. +You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>""" + }