diff --git a/umbrella/attn/cache.py b/umbrella/attn/cache.py index 6358ae5..2e613e5 100644 --- a/umbrella/attn/cache.py +++ b/umbrella/attn/cache.py @@ -71,7 +71,8 @@ def compute_attention(self, layer_idx, storage_ids :torch.Tensor = None, attention_mask :torch.Tensor = None, - logits_soft_cap = 0): + logits_soft_cap = 0, + sm_scale=None): key_states, value_states = self.update_kv_cache(key_states[0], value_states[0], layer_idx, storage_ids) @@ -83,7 +84,8 @@ def compute_attention(self, kv_layout="NHD", custom_mask=attention_mask[:,:self.kv_offset], allow_fp16_qk_reduction=True, - logits_soft_cap = logits_soft_cap + logits_soft_cap = logits_soft_cap, + sm_scale=sm_scale, ) else: diff --git a/umbrella/models/auto_model.py b/umbrella/models/auto_model.py index 080e027..815da8b 100644 --- a/umbrella/models/auto_model.py +++ b/umbrella/models/auto_model.py @@ -2,6 +2,7 @@ from .qwen import Qwen, QwenOffload, QwenAwq, QwenAwqOffload, QwenCudagraph, QwenFBGEMM, QwenFBGEMMOffload from .gemma import Gemma2 from .mistral import Mistral, MistralAwqOffload, MistralOffload, MistralCudagraph, MistralAwq +from .granite import Granite class AutoModelLM: """ 自动模型加载器,根据模型类型动态加载对应的类。 @@ -140,7 +141,10 @@ class AutoModelLM: "mistralai/Mistral-Small-24B-Instruct-2501": Mistral, "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ": MistralAwq, "PyrTools/Ministral-8B-Instruct-2410-AWQ": MistralAwq, - "mistralai/Ministral-8B-Instruct-2410": Mistral + "mistralai/Ministral-8B-Instruct-2410": Mistral, + "ibm-granite/granite-3.2-8b-instruct-preview": Granite, + "ibm-granite/granite-3.1-8b-instruct": Granite, + } _CUDAGRAPH_MODEL_MAPPING = { diff --git a/umbrella/models/granite.py b/umbrella/models/granite.py new file mode 100644 index 0000000..dfe8eef --- /dev/null +++ b/umbrella/models/granite.py @@ -0,0 +1,180 @@ +from transformers import GraniteForCausalLM, GraniteConfig, AutoModelForCausalLM +import torch +import torch.nn.functional as F +import gc +import flashinfer +from ..attn.cache import KV_Cache, StaticKV_Cache +from .granite_layer import GraniteLayer +from .base import LLMBase +from .model_utils import apply_rotary_pos_emb, layer_norm, capture_graph +from tqdm import tqdm + +class Granite(LLMBase): + def __init__(self, + model_name: str, + batch_size :int = 1, + max_length :int = 256, + device :str = 'cuda:0', + dtype = torch.float16) -> None: + + super().__init__() + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = GraniteConfig.from_pretrained(model_name) + self.model_name = model_name + self.max_length = max_length + + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.eos_tokens = self.config.eos_token_id if (isinstance(self.config.eos_token_id, list)) else [self.config.eos_token_id] + + # granite specific + self.residual_multiplier = self.config.residual_multiplier + self.embedding_multiplier = self.config.embedding_multiplier + self.attention_multiplier = self.config.attention_multiplier + self.logits_scaling = self.config.logits_scaling + + def alloc(self, **kwargs): + + + self.kv_cache = KV_Cache(self.config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size) + + hf_model = GraniteForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype) + + # Initialize embedding and language modeling head + self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device) + self.lm_head = hf_model.lm_head.weight.detach().to(self.device) + + # Prepare rotary embeddings + position_ids = torch.arange(0, self.max_length).unsqueeze(0).to(self.device) + + # Compute cos and sin caches + rotary_emb = hf_model.model.rotary_emb + inv_freq = rotary_emb.inv_freq.detach().to(self.device) + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cache = emb.cos()[0] + self.sin_cache = emb.sin()[0] + + # Initialize norm + self.norm_weight = hf_model.model.norm.weight.detach().to(self.device) + self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon + + # Initialize layers + self.layers: list[GraniteLayer] = [] + for idx, hf_layer in enumerate(hf_model.model.layers): + layer = GraniteLayer(idx) + layer.init_parameters(hf_layer) + layer.to(self.device) + self.layers.append(layer) + hf_model.model.layers[idx] = None + gc.collect() + + self.num_layers = len(self.layers) + + + @torch.inference_mode() + def layer_compute(self, + buffer: GraniteLayer, + layer_idx :int, + hidden_states: torch.FloatTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + residual = hidden_states + bsz, q_len, _ = hidden_states.size() + + hidden_states = layer_norm(hidden_states, + buffer.input_layernorm_variance_epsilon, + buffer.input_layernorm_weight + ) + + bsz, q_len, _ = hidden_states.size() + query_states = F.linear(hidden_states, buffer.wq) + key_states = F.linear(hidden_states, buffer.wk) + value_states = F.linear(hidden_states, buffer.wv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + self.cos_cache, + self.sin_cache, + position_ids + ) + + hidden_states = self.kv_cache.compute_attention( + query_states.to(value_states.dtype), + key_states.to(value_states.dtype), + value_states, + layer_idx, + storage_ids, + attention_mask, + sm_scale=self.attention_multiplier + ) + + hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size) + hidden_states = F.linear(hidden_states, buffer.wo) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + + hidden_states = layer_norm(hidden_states, + buffer.post_attention_layernorm_variance_epsilon, + buffer.post_attention_layernorm_weight + ) + + up = F.linear(hidden_states, buffer.up_proj) + gate = F.linear(hidden_states, buffer.gate_proj) + gate = F.silu(gate) + + hidden_states = gate * up + hidden_states = F.linear(hidden_states, buffer.down_proj) + + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + + @torch.inference_mode() + def inference(self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + storage_ids: torch.LongTensor): + + hidden_states = F.embedding(input_ids, self.embed_tokens) * self.embedding_multiplier + for idx in range(self.num_layers): + hidden_states = self.layer_compute( + self.layers[idx], + idx, hidden_states, + position_ids, + attention_mask, + storage_ids + ) + + + hidden_states = layer_norm(hidden_states, self.norm_variance_epsilon, self.norm_weight) + logits = F.linear(hidden_states, self.lm_head).float() / self.logits_scaling + return logits + + def gather_kv_incremental(self, indices: torch.LongTensor, offset:int): + + self.kv_cache.gather_kv_incremental(indices=indices, offset=offset) + + def clear(self): + + self.kv_cache.clear() \ No newline at end of file diff --git a/umbrella/models/granite_layer.py b/umbrella/models/granite_layer.py new file mode 100644 index 0000000..21b9cea --- /dev/null +++ b/umbrella/models/granite_layer.py @@ -0,0 +1,92 @@ +from __future__ import annotations +import torch +from transformers.models.granite.modeling_granite import GraniteDecoderLayer +from ..quantization.awq_utils import AwqLinear + +class GraniteLayer: + def __init__(self, layer_idx, device = "cpu") -> None: + + self.wq :torch.Tensor = None + self.wk :torch.Tensor = None + self.wv :torch.Tensor = None + self.wo :torch.Tensor = None + + self.gate_proj :torch.Tensor = None + self.up_proj :torch.Tensor = None + self.down_proj :torch.Tensor = None + + self.input_layernorm_weight: torch.Tensor = None + self.input_layernorm_variance_epsilon: float = 1e-05 + + self.post_attention_layernorm_weight: torch.Tensor = None + self.post_attention_layernorm_variance_epsilon: float = 1e-05 + + self.layer_idx = layer_idx + self.device = device + + def init_parameters(self, hf_layer: GraniteDecoderLayer): + + self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach() + self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach() + self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach() + self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach() + + self.gate_proj = hf_layer.mlp.gate_proj.weight.detach() + self.up_proj = hf_layer.mlp.up_proj.weight.detach() + self.down_proj = hf_layer.mlp.down_proj.weight.detach() + + # Layer norm weights and epsilon + self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach() + self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon + + self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach() + self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon + + def to(self, device:str = 'cuda:0', non_blocking = True): + + self.device = device + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=non_blocking) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=non_blocking) + self.wq = self.wq.to(device, non_blocking=non_blocking) + self.wk = self.wk.to(device, non_blocking=non_blocking) + self.wv = self.wv.to(device, non_blocking=non_blocking) + self.wo = self.wo.to(device, non_blocking=non_blocking) + self.gate_proj = self.gate_proj.to(device, non_blocking=non_blocking) + self.up_proj = self.up_proj.to(device, non_blocking=non_blocking) + self.down_proj = self.down_proj.to(device, non_blocking=non_blocking) + + def copy(self, layer: GraniteDecoderLayer): + + self.wq.copy_(layer.wq, non_blocking=True) + self.wk.copy_(layer.wk, non_blocking=True) + self.wv.copy_(layer.wv, non_blocking=True) + self.wo.copy_(layer.wo, non_blocking=True) + + + self.gate_proj.copy_(layer.gate_proj, non_blocking=True) + self.up_proj.copy_(layer.up_proj, non_blocking=True) + self.down_proj.copy_(layer.down_proj, non_blocking=True) + + # Copy layer norm weights + self.input_layernorm_weight.copy_(layer.input_layernorm_weight, non_blocking=True) + self.post_attention_layernorm_weight.copy_(layer.post_attention_layernorm_weight, non_blocking=True) + + # Copy epsilon values + self.input_layernorm_variance_epsilon = layer.input_layernorm_variance_epsilon + self.post_attention_layernorm_variance_epsilon = layer.post_attention_layernorm_variance_epsilon + + def alloc_space(self, layer: GraniteDecoderLayer, device): + + self.device = device + self.wq = torch.zeros_like(layer.wq).to(device) + self.wk = torch.zeros_like(layer.wk).to(device) + self.wv = torch.zeros_like(layer.wv).to(device) + self.wo = torch.zeros_like(layer.wo).to(device) + + + self.gate_proj = torch.zeros_like(layer.gate_proj).to(device) + self.up_proj = torch.zeros_like(layer.up_proj).to(device) + self.down_proj = torch.zeros_like(layer.down_proj).to(device) + + self.input_layernorm_weight = torch.zeros_like(layer.input_layernorm_weight).to(device) + self.post_attention_layernorm_weight = torch.zeros_like(layer.post_attention_layernorm_weight).to(device) \ No newline at end of file diff --git a/umbrella/templates.py b/umbrella/templates.py index 4d1a0c7..c93adc8 100644 --- a/umbrella/templates.py +++ b/umbrella/templates.py @@ -23,7 +23,9 @@ 'gemma2': "{}", 'mistral': "[INST] {} [/INST]", -'qwq': "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n" +'qwq': "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n", +'ibm-granite': """\n<|start_of_role|>user<|end_of_role|>{}<|end_of_text|> +<|start_of_role|>assistant<|end_of_role|>""" } @@ -39,12 +41,17 @@ """, 'gemma2': "", 'gemma2-it': "", + 'mistral': """[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date is 2025-03-07. When you're not sure about some information, you say that you don't have the information and don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?")[/SYSTEM_PROMPT]""", 'qwq': "", + + 'ibm-granite': """<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024. +Today's Date: March 05, 2025. +You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>""" }