# Reconstructed from a whitespace-collapsed git patch:
#   From a792907b157f4084a55a36aafdde2c27b2544a4a Mon Sep 17 00:00:00 2001
#   From: Haizhong Zheng
#   Date: Wed, 5 Feb 2025 01:21:39 -0500
#   Subject: [PATCH] init shadowKV code
#   Files added by the patch:
#     umbrella/attn/cache_shadowKV.py   (266 lines)
#     umbrella/models/llama_shadowKV.py (570 lines)
#     umbrella/models/shadow_kv_ops.py  (296 lines)
#
# ---- umbrella/attn/cache_shadowKV.py (new file, index 0000000..a011e34) ----
import torch.nn as nn
import torch
import math
import gc
from torch import nn


class ShadowKVCache:
    """ShadowKV, only for accuracy measurement and understanding, not for efficiency, please refer to ShadowKV_CPU for the efficient implementation

    Per layer the cache keeps:
      * ``U`` / ``SV``: rank-``rank`` SVD factors of the pre-RoPE key matrix,
      * ``k_landmark`` / ``k_landmark_idx``: per-chunk mean keys (post-RoPE)
        and the chunk ids they stand for,
      * ``v_cache_cpu``: the full value cache (allocated on ``device`` in this
        reference implementation, despite the name),
      * ``k_cache_buffer`` / ``v_cache_buffer``: working buffers laid out as
        ``[local window | outlier chunks | retrieved sparse budget | generated]``.
    """

    # Slots added on top of ``sparse_budget`` in the working buffers; they hold
    # the local window, the outlier chunks and every generated token.
    _BUFFER_HEADROOM = 4096

    def __init__(self,
        config :object,
        batch_size :int = 1,
        max_length :int = 32*1024,
        device :str = 'cuda:0',
        dtype = torch.bfloat16,
        sparse_budget: int = 2048,
        chunk_size=8,
        rank=160,
        ) -> None:
        """Allocate all cache tensors.

        Args:
            config: model config exposing ``num_attention_heads``,
                ``num_key_value_heads``, ``hidden_size`` and ``num_hidden_layers``.
            batch_size: must be 1 for this class.
            max_length: number of positions reserved in the full value cache.
            device: device on which every tensor is allocated.
            dtype: dtype of the cached keys/values.
            sparse_budget: number of retrieved tokens per decode step; must be
                a multiple of ``chunk_size`` (checked at prefill time).
            chunk_size: tokens per landmark chunk.
            rank: number of singular components kept for the key cache.
        """
        self.config = config
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        self.dtype = dtype
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads

        self.sparse_budget = int(sparse_budget)
        self.chunk_size = chunk_size
        self.rank = rank
        self.local_chunk = 4
        self.outlier_chunk = 48

        assert self.batch_size == 1, "ShadowKV class only supports batch_size=1, please use ShadowKV_CPU class for batch_size > 1"

        per_head = (config.num_hidden_layers, batch_size, config.num_key_value_heads)

        self.selected_chunk_idx = torch.zeros(
            *per_head,
            self.sparse_budget // self.chunk_size,
            device=self.device,
            dtype=torch.long
        )

        self.v_cache_cpu = torch.zeros(
            *per_head,
            self.max_length,
            self.head_dim,
            device=self.device,
            dtype=self.dtype
        )

        buffer_len = self.sparse_budget + self._BUFFER_HEADROOM
        self.k_cache_buffer = torch.zeros(
            *per_head, buffer_len, self.head_dim,
            device=self.device, dtype=self.dtype
        )
        self.v_cache_buffer = torch.zeros(
            *per_head, buffer_len, self.head_dim,
            device=self.device, dtype=self.dtype
        )

        self.num_layers = config.num_hidden_layers
        self.kv_offset = 0
        self.prefill = 0
        self.gen_offset = 0
        self.prefill_local = 0

        self.k_landmark = None
        self.k_landmark_idx = None
        self.U = None
        self.SV = None

        # Create the side stream on the cache's own device, not on whatever
        # device happens to be current.
        self.copy_stream = torch.cuda.Stream(device=torch.device(self.device))

    def print_stats(self):
        """Print the cache configuration and the number of cached tokens."""
        print(f"ShadowKV | sparse budget {self.sparse_budget} | chunk size {self.chunk_size} |rank {self.rank} | cached {self.kv_offset} | local_chunk {self.local_chunk} | outlier_chunk {self.outlier_chunk}")

    def get_svd(self, new_k_cache, layer_idx):
        """Store a rank-``rank`` SVD of the (pre-RoPE) key cache of one layer.

        Args:
            new_k_cache: ``[bsz, kv_heads, prefill, head_dim]`` or
                ``[bsz, prefill, kv_heads * head_dim]``.
            layer_idx: layer whose factors are written; layer 0 (re)allocates
                ``self.U`` and ``self.SV``.
        """
        # Decide the layout from the rank of the tensor rather than from the
        # size of dim 1, which breaks for short 3-D inputs or many-head 4-D ones.
        if new_k_cache.dim() == 4:
            # [bsz, kv_heads, prefill, head_dim] --> [bsz, prefill, kv_heads * head_dim]
            k_cache = new_k_cache.transpose(1, 2).reshape(self.batch_size, -1, self.num_key_value_heads*self.head_dim)
        else:
            k_cache = new_k_cache

        if layer_idx == 0:
            # init U, SV
            self.U = torch.zeros(self.num_layers, self.batch_size, k_cache.shape[1], self.rank, device=self.device, dtype=self.dtype)
            self.SV = torch.zeros(self.num_layers, self.batch_size, self.num_key_value_heads, self.rank, self.head_dim, device=self.device, dtype=self.dtype)

        # torch.svd is deprecated; linalg.svd returns V already transposed.
        u, s, vh = torch.linalg.svd(k_cache.float(), full_matrices=False)

        # [bsz, prefill, hidden] --> U: [bsz, prefill, rank], S*Vh: [bsz, rank, hidden]
        self.U[layer_idx].copy_(u[:, :, :self.rank].to(self.dtype))
        # Scaling the rows of Vh equals diag(S) @ Vh without building the diagonal.
        sv = s[:, :self.rank].unsqueeze(-1) * vh[:, :self.rank]
        # [bsz, rank, hidden] --> [bsz, kv_heads, rank, head_dim]
        self.SV[layer_idx].copy_(sv.to(self.dtype).view(self.batch_size, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2))

    def register_k_landmark(self, k_landmark, k_landmark_idx, layer_idx):
        """Store the landmark keys and their chunk ids for one layer.

        Layer 0 (re)allocates the per-layer storage sized after ``k_landmark``.
        """
        num_landmarks = k_landmark.shape[-2]
        if layer_idx == 0:
            # init k_landmark, k_landmark_idx
            self.k_landmark = torch.zeros(self.num_layers, self.batch_size, self.num_key_value_heads, num_landmarks, self.head_dim, device=self.device, dtype=self.dtype)
            self.k_landmark_idx = torch.zeros(self.num_layers, self.batch_size, self.num_key_value_heads, num_landmarks, device=self.device, dtype=torch.long)

        self.k_landmark[layer_idx].copy_(k_landmark.contiguous())
        self.k_landmark_idx[layer_idx].copy_(k_landmark_idx.contiguous())

    def prefill_kv_cache(self,
        new_v_cache :torch.Tensor,
        layer_idx :int,
        key_states_roped: torch.Tensor,
        query: torch.Tensor=None
        ):
        """Populate the cache of one layer from a prefill pass.

        Args:
            new_v_cache: values, ``[bsz, kv_heads, incoming, head_dim]``.
            layer_idx: layer being filled.
            key_states_roped: post-RoPE keys with the same layout as the values.
            query: accepted for interface compatibility; not used here.
        """
        incoming = new_v_cache.shape[-2]
        self.prefill = incoming
        # Slice assignment already copies; no intermediate clone is needed.
        self.v_cache_cpu[layer_idx][:, :, :incoming].copy_(new_v_cache)

        # [x0, x1, ...., self.chunks*chunk_size, local_chunk, rest]
        self.chunks = incoming // self.chunk_size - self.local_chunk
        self.select_sets = self.sparse_budget // self.chunk_size

        assert self.select_sets * self.chunk_size == self.sparse_budget, f"({self.select_sets}) * {self.chunk_size} != {self.sparse_budget}"

        # Local window = the local chunks plus the tail that does not fill a chunk.
        ctx_len = self.chunks * self.chunk_size
        self.prefill_local = incoming - ctx_len
        self.k_cache_buffer[layer_idx][:, :, :self.prefill_local].copy_(key_states_roped[:, :, -self.prefill_local:])
        self.v_cache_buffer[layer_idx][:, :, :self.prefill_local].copy_(new_v_cache[:, :, -self.prefill_local:])

        chunked_shape = (self.batch_size, self.num_key_value_heads, self.chunks, self.chunk_size, self.head_dim)
        flat_outlier_shape = (self.batch_size, self.num_key_value_heads, self.outlier_chunk*self.chunk_size, self.head_dim)
        key_ctx = key_states_roped[:, :, :ctx_len].view(chunked_shape)
        value_ctx = new_v_cache[:, :, :ctx_len].view(chunked_shape)

        # Candidate landmark of a chunk = mean of its keys. [bsz, kv_heads, chunks, head_dim]
        landmark_candidates = key_ctx.mean(dim=-2)

        # Cosine similarity of every key to its chunk mean. [bsz, kv_heads, chunks, chunk_size]
        cos_sim = torch.nn.functional.cosine_similarity(
            landmark_candidates.unsqueeze(3).expand(-1, -1, -1, self.chunk_size, -1),
            key_ctx,
            dim=-1,
        )

        # Outlier chunks: those whose worst key is least similar to the mean.
        # [bsz, kv_heads, outlier_chunk]
        outlier_chunk_idx = cos_sim.min(dim=-1).values.topk(self.outlier_chunk, largest=False).indices
        gather_idx = outlier_chunk_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, self.chunk_size, self.head_dim)
        outlier_chunk_k_cache = key_ctx.gather(dim=2, index=gather_idx).view(flat_outlier_shape)
        outlier_chunk_v_cache = value_ctx.gather(dim=2, index=gather_idx).view(flat_outlier_shape)

        self.sparse_start = self.prefill_local + self.outlier_chunk*self.chunk_size
        self.sparse_end = self.sparse_start + self.sparse_budget

        # Outlier chunks are kept verbatim right after the local window.
        self.k_cache_buffer[layer_idx][:, :, self.prefill_local:self.sparse_start].copy_(outlier_chunk_k_cache)
        self.v_cache_buffer[layer_idx][:, :, self.prefill_local:self.sparse_start].copy_(outlier_chunk_v_cache)

        # Every non-outlier chunk becomes a landmark.
        # [bsz, kv_heads, chunks] --filter--> [bsz, kv_heads, chunks - outlier_chunk]
        all_idx = torch.arange(self.chunks, device=key_states_roped.device).unsqueeze(0).unsqueeze(0).expand(self.batch_size, self.num_key_value_heads, -1)
        mask = torch.ones_like(all_idx, dtype=torch.bool)
        mask.scatter_(dim=-1, index=outlier_chunk_idx, value=False)
        rest_idx = all_idx.masked_select(mask).view(self.batch_size, self.num_key_value_heads, -1)

        rest_landmarks = landmark_candidates.gather(
            dim=2, index=rest_idx.unsqueeze(-1).expand(-1, -1, -1, self.head_dim)
        ).view(self.batch_size, self.num_key_value_heads, -1, self.head_dim)
        self.register_k_landmark(rest_landmarks, rest_idx, layer_idx)

        if layer_idx == self.num_layers - 1:
            assert self.sparse_budget < incoming
            self.kv_offset += incoming

    def get_retrieval_position_ids(self, layer_idx, query_states):
        """Select the chunks to retrieve for the current query.

        Args:
            layer_idx: layer whose landmarks are scored.
            query_states: viewed as ``[bsz, kv_heads, groups, q_len, head_dim]``.

        Returns:
            Token positions of the selected chunks,
            ``[bsz, kv_heads, select_sets * chunk_size]``.
        """
        self.incoming_q_len = query_states.shape[-2]
        grouped_q = query_states.view(-1, self.num_key_value_heads, self.num_key_value_groups, self.incoming_q_len, self.head_dim)

        # [bsz, kv, g, q_len, d] * [bsz, kv, d, chunks] --> [bsz, kv, g, q_len, chunks]
        # (the group axis is dropped when there is a single group)
        chunk_attn = torch.einsum('bhgqd,bhdc->bhgqc', grouped_q, self.k_landmark[layer_idx].transpose(2, 3)).squeeze(2)
        # Scale by the actual head size instead of a hard-coded 128.
        chunk_attn = chunk_attn / math.sqrt(self.head_dim)
        chunk_attn = nn.functional.softmax(chunk_attn, dim=-1, dtype=torch.float32).to(self.dtype)
        chunk_attn = chunk_attn.sum(dim = -2)  # sum over the query positions
        if self.num_key_value_groups > 1:
            chunk_attn, _ = torch.max(chunk_attn, dim=-2)  # best score across the group
        merged_results = torch.topk(chunk_attn, k=self.select_sets, dim=-1).indices  # [bsz, kv, select_sets]

        # Landmark slot --> original chunk id.
        selected_chunks = self.k_landmark_idx[layer_idx].gather(dim=-1, index=merged_results)

        # this is chunk idx, which can be used to offload value cache and decide if the cache hits
        self.selected_chunk_idx[layer_idx].copy_(selected_chunks, non_blocking=True)

        # Chunk id --> the chunk_size token positions it covers.
        in_chunk = torch.arange(self.chunk_size, device=chunk_attn.device)
        position_ids = (selected_chunks.unsqueeze(-1) * self.chunk_size + in_chunk).view(self.batch_size, self.num_key_value_heads, -1)

        return position_ids

    def _visible_len(self, layer_idx):
        """Length of the buffer prefix that is valid for ``layer_idx``.

        ``update_kv_cache`` advances ``gen_offset`` only on the last layer, so
        for every other layer the tokens written in the current step are not
        counted yet and ``incoming_q_len`` is added explicitly.
        """
        if layer_idx == self.num_layers - 1:
            return self.sparse_end + self.gen_offset
        return self.sparse_end + self.gen_offset + self.incoming_q_len

    def get_value_cache(self, layer_idx, position_ids):
        """Gather the selected values into the buffer and return its valid prefix."""
        value_ = self.v_cache_cpu[layer_idx].gather(dim=-2, index=position_ids.unsqueeze(-1).expand(-1, -1, -1, self.head_dim))
        self.v_cache_buffer[layer_idx][:, :, self.sparse_start:self.sparse_end].copy_(value_, non_blocking=True)

        return self.v_cache_buffer[layer_idx][:, :, :self._visible_len(layer_idx)]

    def get_key_cache(self, layer_idx, position_ids, rope_func, cos_sin_cache):
        """Rebuild the selected keys from the SVD factors, apply RoPE, return the valid prefix.

        Args:
            layer_idx: layer to read.
            position_ids: ``[bsz, kv_heads, sparse_budget]`` token positions.
            rope_func: called as ``rope_func(keys, position_ids)``.
            cos_sin_cache: accepted for interface compatibility; not used here.
        """
        u = self.U[layer_idx]    # [bsz, prefill, rank]
        sv = self.SV[layer_idx]  # [bsz, kv_heads, rank, head_dim]

        # Pick the rows of U for the requested positions, per kv head.
        index_expanded = position_ids.unsqueeze(-1).expand(-1, -1, -1, u.size(-1))
        u_expand = u.unsqueeze(1).expand(-1, self.num_key_value_heads, -1, -1)
        U_head = torch.gather(u_expand, 2, index_expanded)  # [bsz, kv_heads, sparse_budget, rank]

        # [bsz, kv, budget, rank] x [bsz, kv, rank, head_dim] --> [bsz, kv, budget, head_dim]
        result = torch.einsum('bhrk,bhkd->bhrd', U_head, sv)

        # rope the key cache
        result = rope_func(result, position_ids)

        # send to buffer
        self.k_cache_buffer[layer_idx][:, :, self.sparse_start:self.sparse_end].copy_(result, non_blocking=True)

        return self.k_cache_buffer[layer_idx][:, :, :self._visible_len(layer_idx)]

    def update_kv_cache(self,
        new_k_cache :torch.Tensor,
        new_v_cache :torch.Tensor,
        layer_idx :int,
        ):
        """Append freshly generated keys/values after the sparse region.

        Raises:
            RuntimeError: if the buffer has no room left. Without this check a
                single-token write into the (then empty) slice broadcasts to
                nothing and the token would be dropped silently.
        """
        incoming = new_k_cache.shape[-2]
        start = self.sparse_end + self.gen_offset
        stop = start + incoming
        capacity = self.k_cache_buffer.shape[-2]
        if stop > capacity:
            raise RuntimeError(f"ShadowKV buffer overflow: need {stop} slots but the buffer holds {capacity}")

        self.v_cache_buffer[layer_idx][:, :, start:stop].copy_(new_v_cache, non_blocking=True)
        self.k_cache_buffer[layer_idx][:, :, start:stop].copy_(new_k_cache, non_blocking=True)

        if layer_idx == self.num_layers - 1:
            self.kv_offset += incoming
            self.gen_offset += incoming

    def clear(self):
        """Reset the cache to its empty state (buffers zeroed, factors dropped)."""
        self.k_cache_buffer.zero_()
        self.v_cache_buffer.zero_()
        self.selected_chunk_idx.zero_()
        self.k_landmark = None
        self.k_landmark_idx = None
        self.U = None
        self.SV = None

        self.kv_offset = 0
        self.prefill = 0
        self.gen_offset = 0
        self.prefill_local = 0

    def H2D(self):
        """No-op in this implementation."""
        pass

    def get_kv_len(self):
        """Return the number of tokens cached so far."""
        return self.kv_offset


# ---- umbrella/models/llama_shadowKV.py (new file, index 0000000..cc2b31f) ----
################################################################################
#
# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# Base LLM class

import torch
import torch.nn.functional as F
import time
import gc
from tqdm import tqdm

from flash_attn import flash_attn_with_kvcache
from transformers import LlamaForCausalLM, LlamaConfig, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
import vllm

from .shadow_kv_ops import sample_token, layer_norm, minference_prefill_kernel, apply_rotary_pos_emb_cuda
# from .kv_cache import KV_Cache, ShadowKVCache, ShadowKVCache_CPU
# NOTE(review): this patch adds umbrella/attn/cache_shadowKV.py; confirm that
# umbrella/attn/cache.py really exports ShadowKVCache.
from ..attn.cache import ShadowKVCache
from .base import LLMBase

# Prompt templates for a full request; ``{ctx}`` is replaced by the user text.
Templates = {
    'base': "{ctx}",
    'llama-3': "<|start_header_id|>system<|end_header_id|>You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>{ctx}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
    'yi': "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{ctx}<|im_end|>\n<|im_start|>assistant\n",
    'glm': "<|system|>\nYou are a helpful assistant\n<|user|> \n{ctx}<|assistant|>\n",
    'lwm': "You are a helpful assistant.\nUSER: {ctx}\nASSISTANT: Answer: ",
    'qwen': "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{ctx}<|im_end|>\n<|im_start|>assistant\n",
    'phi': "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\n{ctx}<|end|>\n<|assistant|>\n",
}

# Templates for a follow-up chat turn; ``{msg}`` is replaced by the message.
Chat_Templates = {
    'base': "{msg}",
    'llama-3': "<|start_header_id|>user<|end_header_id|>{msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
    'yi': "<|im_start|>user\n{msg}<|im_end|>\n<|im_start|>assistant\n",
    'glm': "<|user|>\n{msg}<|assistant|>\n",
    'lwm': "\nUSER: {msg}\nASSISTANT: ",
    'qwen': "<|im_start|>user\n{msg}<|im_end|>\n<|im_start|>assistant\n",
    'phi': "<|user|>\n{msg}<|end|>\n<|assistant|>\n",
}

# Templates that end with a canned assistant acknowledgement after ``{ctx}``.
Prefix_Templates = {
    'base': "{ctx}",
    'llama-3': "<|start_header_id|>system<|end_header_id|>You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>{ctx}<|eot_id|><|start_header_id|>assistant<|end_header_id|>OK! I will help you with that. Please ask me anything.<|eot_id|>",
    'yi': "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{ctx}<|im_end|>\n<|im_start|>assistant\nOK! I will help you with that. Please ask me anything.\n",
    'glm': "<|system|>\nYou are a helpful assistant\n<|user|> \n{ctx}<|assistant|>\nOK! I will help you with that. Please ask me anything.\n",
}


class LLMBase_ShadowKV(LLMBase):
    """Inference driver that runs a decoder stack on top of a ShadowKV cache.

    Subclasses supply the weights and the helpers used here
    (``pre_attention_compute``, ``post_attention_compute``,
    ``apply_rotary_pos_emb``, ``apply_rotary_pos_emb_single``, ...).
    """

    # Special-token strings / ids that end generation in ``generate``.
    _STOP_STRINGS = ("<|eot_id|>", "<|im_end|>", "<|endoftext|>", "<|end|>")  # llama-3, yi, glm, phi
    _GLM_STOP_IDS = [151329, 151336, 151338]

    def __str__(self) -> str:
        gpu_mem = f"{round(torch.cuda.memory_allocated(self.device) / 1024**3, 2)} GB / {round(torch.cuda.get_device_properties(self.device).total_memory / 1024**3, 2)} GB"
        return f"LLM: {self.model_name}, attn_mode: {self.attn_mode}, max_length: {self.max_length}, batch_size: {self.batch_size}, device: {self.device}, dtype: {self.dtype}, GPU mem: {gpu_mem}"

    def init_kv_cache(self, sparse_budget: int, rank: int, chunk_size: int, config):
        """Create ``self.kv_cache``.

        Only ``attn_mode == 'shadowkv'`` (case-insensitive) is supported here.

        Raises:
            ValueError: for any other attention mode.
        """
        if self.attn_mode.lower() == 'shadowkv':
            self.kv_cache = ShadowKVCache(config, max_length=self.max_length, device=self.device, dtype=self.dtype, batch_size=self.batch_size, sparse_budget=sparse_budget, rank=rank, chunk_size=chunk_size)
        else:
            raise ValueError(f"Invalid attention mode {self.attn_mode}")

    def alloc(self, **kwargs):
        raise NotImplementedError("Subclasses must implement the `alloc` method.")

    def print_kv_stats(self):
        self.kv_cache.print_stats()

    def get_ctx(self, input_ids: torch.LongTensor):
        """Return position ids continuing from the current cache length.

        The result has shape ``[input_ids.size(0), input_ids.size(1)]``.
        """
        input_len = input_ids.size(1)
        past_len = self.kv_cache.get_kv_len()
        position_ids = torch.arange(past_len, past_len + input_len, device=self.device, dtype=torch.long).unsqueeze(0).repeat(input_ids.size(0), 1)
        return position_ids

    @torch.inference_mode()
    def inference(self,
            input_ids: torch.LongTensor,
            position_ids: torch.LongTensor,
            attention_mask: torch.FloatTensor = None,
            storage_ids: torch.LongTensor = None):
        """Run all layers and return float32 logits.

        ``attention_mask`` and ``storage_ids`` are accepted but not used.
        When more than 16 positions are processed only the logits of the last
        position are computed.
        """
        hidden_states = F.embedding(input_ids, self.embed_tokens)

        for idx in range(self.num_layers):
            hidden_states = self.layer_compute(self.layers[idx], idx, hidden_states, position_ids)

        hidden_states = layer_norm(hidden_states, w=self.norm_weight, eps=self.norm_variance_epsilon)

        if hidden_states.shape[1] > 16:  # prefill
            hidden_states = hidden_states[:, -1:, :]
        logits = F.linear(hidden_states, self.lm_head).float()

        return logits

    @torch.inference_mode()
    def prefill(self, input_ids: torch.LongTensor):
        """Clear the cache, run the prompt and return the logits."""
        self.kv_cache.clear()
        logits = self.inference(input_ids=input_ids, position_ids=self.get_ctx(input_ids))

        assert self.kv_cache.get_kv_len() == input_ids.shape[-1], f"KV length mismatch, got {self.kv_cache.get_kv_len()}, expected {input_ids.shape[-1]}"
        return logits

    @torch.inference_mode()
    def prefill_cont(self, input_ids: torch.LongTensor):
        """Run ``input_ids`` on top of the existing cache (no clear)."""
        logits = self.inference(input_ids=input_ids, position_ids=self.get_ctx(input_ids))
        return logits

    def encode(self, text: str, template=None, truncation=False):
        """Apply the requested template and tokenize onto ``self.device``.

        ``template`` may be ``'chat'`` (no special tokens are added and the
        result must not contain BOS), ``'ctx'``, ``'prefix'`` or anything else
        (text is tokenized unchanged).
        """
        if template == 'chat':
            text = self.chat_template.format(msg=text)
            input_ids = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
            if self.tokenizer.bos_token_id is not None:
                assert self.tokenizer.bos_token_id not in input_ids, "bos_token_id found in input_ids"
            return input_ids
        if template == 'ctx':
            text = self.ctx_template.format(ctx=text)
        if template == 'prefix':
            text = self.prefix_template.format(ctx=text)
        input_ids = self.tokenizer(text, return_tensors="pt", truncation=truncation).input_ids.to(self.device)
        return input_ids

    @torch.inference_mode()
    def layer_compute(self,
            buffer,
            layer_idx :int,
            hidden_states: torch.FloatTensor,
            position_ids: torch.LongTensor,
            attention_mask: torch.FloatTensor=None,
            storage_ids: torch.LongTensor=None):
        """Run one decoder layer through the ShadowKV cache.

        More than 4096 query positions are treated as prefill (SVD + landmark
        construction, dense attention); anything shorter is treated as a
        decode step that retrieves a sparse KV set from the cache.
        ``attention_mask`` and ``storage_ids`` are accepted but not used.

        Raises:
            ValueError: if ``self.kv_cache`` is not a ``ShadowKVCache``.
        """
        # The original also branched on KV_Cache / ShadowKVCache_CPU, but their
        # import is commented out, so evaluating those names raised NameError
        # on every call. init_kv_cache can only build a ShadowKVCache.
        if not isinstance(self.kv_cache, ShadowKVCache):
            raise ValueError(f"Invalid attention mode {self.attn_mode}")

        residual = hidden_states
        bsz, q_len, _ = hidden_states.size()
        query_states, key_states, value_states = self.pre_attention_compute(
            hidden_states,
            buffer,
            self.num_heads,
            self.num_key_value_heads,
            self.head_dim
        )

        if q_len > 4*1024:  # prefill
            # svd unrope key and save
            self.kv_cache.get_svd(key_states, layer_idx=layer_idx)
            query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, position_ids)
            self.kv_cache.prefill_kv_cache(value_states, layer_idx, key_states, query_states[:, :, -1:])

            if self.minference:
                hidden_states = minference_prefill_kernel(query_states=query_states, key_states=key_states, value_states=value_states, minference_parttern=self.minference_parttern[layer_idx])
            else:
                hidden_states = flash_attn_with_kvcache(q=query_states.transpose(1, 2), k_cache=key_states.transpose(1, 2), v_cache=value_states.transpose(1, 2), causal=True)

        else:  # decode
            # rope query and key
            query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, position_ids)

            # update kv cache to buffer
            self.kv_cache.update_kv_cache(key_states, value_states, layer_idx)

            # get retrieval idx (kept separate from the RoPE position ids)
            retrieval_ids = self.kv_cache.get_retrieval_position_ids(layer_idx=layer_idx, query_states=query_states)

            # multi-stream: gather values on the side stream while keys are
            # rebuilt on the current one
            curr_stream = torch.cuda.current_stream()
            get_value_stream = self.kv_cache.copy_stream

            with torch.cuda.stream(get_value_stream):
                get_value_stream.wait_stream(curr_stream)
                value_states = self.kv_cache.get_value_cache(layer_idx, retrieval_ids)

            # gather key cache from GPU and RoPE it
            key_states = self.kv_cache.get_key_cache(layer_idx=layer_idx, position_ids=retrieval_ids, rope_func=self.apply_rotary_pos_emb_single, cos_sin_cache=self.cos_sin_cache)

            curr_stream.wait_stream(get_value_stream)

            # flash attention
            hidden_states = flash_attn_with_kvcache(q=query_states.transpose(1, 2), k_cache=key_states.transpose(1, 2), v_cache=value_states.transpose(1, 2), causal=True)

        hidden_states = hidden_states.reshape(bsz, q_len, self.hidden_size)

        if bsz*q_len > 64*1024:
            # Run the post-attention part in slices along the sequence axis.
            output = torch.empty_like(hidden_states)
            prop_iter = bsz * q_len // (8*1024)
            prefill_chunk_size = bsz * q_len // prop_iter
            prefill_iter = (q_len + prefill_chunk_size - 1) // prefill_chunk_size
            for i in range(prefill_iter):
                start = i*prefill_chunk_size
                end = (i+1)*prefill_chunk_size
                output[:, start:end] = self.post_attention_compute(hidden_states[:, start:end], residual[:, start:end], buffer)

            hidden_states = output

        else:
            hidden_states = self.post_attention_compute(hidden_states, residual, buffer)

        return hidden_states

    def decode(self, input_ids: torch.Tensor, skip_special_tokens: bool = False):
        return self.tokenizer.batch_decode(input_ids, skip_special_tokens=skip_special_tokens)

    def _is_stop_token(self, token) -> bool:
        """Return True when ``token`` (``next_token[0]``) should end generation."""
        if token == self.tokenizer.eos_token_id:
            return True
        if token in self._GLM_STOP_IDS:  # glm
            return True
        # Decode once instead of once per candidate string.
        return self.tokenizer.decode(token) in self._STOP_STRINGS

    @torch.inference_mode()
    def generate(self, input_ids: torch.Tensor, gen_len: int = 256, temperature: float = 0.0, top_p: float = 0.9, top_k :int = 50, verbose: bool = False, benchmark: bool = False, cont: bool = False):
        """accuracy eval usage, not for throughput eval

        Prefills ``input_ids`` (or continues the existing cache when ``cont``),
        then samples up to ``gen_len`` further tokens, stopping early on a stop
        token. Returns a one-element list with the decoded text.

        Raises:
            ValueError: if the input does not fit in ``max_length``.
        """
        assert isinstance(input_ids, torch.Tensor), f"input_ids must be a torch.Tensor, got {type(input_ids)}"

        # prefill
        if not cont:
            if input_ids.size(1) > self.max_length:
                raise ValueError(f"Input length must be less than {self.max_length}, but got {input_ids.size(1)}")
            logits = self.prefill(input_ids)
        else:
            if input_ids.size(1) + self.kv_cache.get_kv_len() >= self.max_length:
                raise ValueError(f"Input length must be less than {self.max_length}, but got {input_ids.size(1)}")
            logits = self.prefill_cont(input_ids)
        next_token = sample_token(logits[:, -1, :], temperature=temperature, top_p=top_p, top_k=top_k)

        n = 0
        pos = 0
        generated_ids = []
        generated_ids.extend(next_token[0].tolist())

        self.kv_cache.H2D()

        if benchmark:
            start = time.perf_counter()

        while n < gen_len:
            logits = self.inference(input_ids=next_token, position_ids=self.get_ctx(next_token))
            next_token = sample_token(logits[:, -1, :], temperature=temperature, top_p=top_p, top_k=top_k)

            n += 1
            generated_ids.extend(next_token[0].tolist())
            if verbose:
                # Stream whole words only; the last (possibly partial) word is held back.
                generated_text = (
                    self.tokenizer.decode(
                        generated_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                        spaces_between_special_tokens=False,
                    ).strip().split(" ")
                )
                now = len(generated_text) - 1
                if now > pos:
                    print(" ".join(generated_text[pos:now]), end=" ", flush=True)
                    pos = now

            if self._is_stop_token(next_token[0]):
                break

        if verbose and n != 0:
            print(" ".join(generated_text[pos:]), end=" ", flush=True)
        if benchmark:
            end = time.perf_counter()
            print(f"\nPrefill {input_ids.size(1)} tokens | Generate {n} tokens in {round(end - start, 2)}s, {round(n / (end - start), 2)} tokens/s | cached {self.kv_cache.get_kv_len()}\n")

        # feed new token to the model
        self.inference(input_ids=next_token, position_ids=self.get_ctx(next_token))

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        return [self.tokenizer.decode(generated_ids, skip_special_tokens=True)]

    @torch.inference_mode()
    def batch_prefill(self, input_ids: torch.Tensor, benchmark: bool = False):
        """Clear the cache and prefill the batch in slices of T requests.

        ``benchmark`` is accepted but not used. Returns logits of shape
        ``[batch, 1, vocab]``.
        """
        self.kv_cache.clear()
        batch_size = input_ids.size(0)

        assert batch_size == self.batch_size, f"batch_size mismatch, got {batch_size}, expected {self.batch_size}"

        if input_ids.size(1) > self.max_length:
            raise ValueError(f"Input length must be less than {self.max_length}, but got {input_ids.size(1)}")

        logits = torch.zeros(batch_size, 1, self.vocab_size, device=self.device, dtype=torch.float32)

        if 120*1024 < input_ids.shape[-1] < 200*1024:
            T = 8
        else:
            T = 4
        for bsz in tqdm(range(0, batch_size, T), desc=f"Prefilling (batch size={batch_size})"):
            req_input_ids = input_ids[bsz:bsz+T]
            logits[bsz:bsz+T].copy_(self.inference(input_ids=req_input_ids, position_ids=self.get_ctx(req_input_ids)))
        assert self.kv_cache.get_kv_len() == input_ids.shape[-1], f"KV length mismatch, got {self.kv_cache.get_kv_len()}, expected {input_ids.shape[-1]}"

        return logits

    @torch.inference_mode()
    def warmup(self):
        """Run 100 batched matmuls on the device before timing."""
        a = torch.randn(self.batch_size, 1024, 1024).to(self.dtype).to(self.device)
        b = torch.randn(self.batch_size, 1024, 1024).to(self.dtype).to(self.device)
        for _ in range(100):
            torch.bmm(a, b)
        del a, b

        print("Warmup done")

    @torch.inference_mode()
    def batch_generate(self, input_ids: torch.Tensor, gen_len: int = 256, temperature: float = 0.0, top_p: float = -1, top_k :int = 50, verbose: bool = False, benchmark: bool = False, cont: bool = False):
        """throughput eval usage

        Always generates ``gen_len`` steps (no early stop). Returns the decoded
        texts, or ``(texts, tokens_per_second)`` when ``benchmark`` is set.
        """
        assert isinstance(input_ids, torch.Tensor), f"input_ids must be a torch.Tensor, got {type(input_ids)}"

        # prefill
        if not cont:
            if input_ids.size(1) > self.max_length:
                raise ValueError(f"Input length must be less than {self.max_length}, but got {input_ids.size(1)}")
            logits = self.batch_prefill(input_ids)
        else:
            logits = self.prefill_cont(input_ids)

        next_token = sample_token(logits[:, -1, :], temperature=temperature, top_p=top_p, top_k=top_k)

        n = 0
        generated_ids = []
        generated_ids.append(next_token[:, -1].tolist())

        self.kv_cache.H2D()
        self.warmup()

        if benchmark:
            start = time.perf_counter()

        while n < gen_len:
            logits = self.inference(input_ids=next_token, position_ids=self.get_ctx(next_token))
            next_token = sample_token(logits[:, -1, :], temperature=temperature, top_p=top_p, top_k=top_k)

            n += 1
            generated_ids.append(next_token[:, -1].tolist())

        if benchmark:
            end = time.perf_counter()
            # gen_len <= 0 leaves n == 0; avoid dividing by it.
            latency_ms = (end - start)*1000 / n if n else 0.0
            print(f"\nPrefill {input_ids.size(1)} tokens | Generate {n} tokens in {round(end - start, 2)}s | Throughput: {round(self.batch_size * n / (end - start), 2)} tokens/s, Latency: {round(latency_ms, 2)} ms/step | cached {self.kv_cache.get_kv_len()}\n")

        # feed new token to the model
        self.inference(input_ids=next_token, position_ids=self.get_ctx(next_token))

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # [steps, batch] --> [batch, steps]
        generated_ids = torch.LongTensor(generated_ids).t().tolist()

        if benchmark:
            return self.decode(generated_ids, skip_special_tokens=True), self.batch_size * n / (end - start)

        return self.decode(generated_ids, skip_special_tokens=True)


class LlamaLayer:
    """Flat container for the weights of one Llama decoder layer."""

    def __init__(self, layer_idx) -> None:

        self.wqkv :torch.Tensor = None
        self.wo :torch.Tensor = None

        self.gate_up_proj :torch.Tensor = None
        self.down_proj :torch.Tensor = None

        self.input_layernorm_weight :torch.Tensor = None
        self.input_layernorm_variance_epsilon :float = 0.0

        self.post_attention_layernorm_weight :torch.Tensor = None
        self.post_attention_layernorm_variance_epsilon :float = 0.0

        self.layer_idx = layer_idx

    def init_parameters(self, hf_layer: LlamaDecoderLayer):
        """Copy (and fuse) the weights of a Hugging Face decoder layer.

        q/k/v projections are concatenated along dim 0 into ``wqkv`` and
        gate/up projections into ``gate_up_proj``; ``q_size`` / ``kv_size``
        record the row counts of the q and k projections.
        """
        attn = hf_layer.self_attn
        mlp = hf_layer.mlp

        self.wqkv :torch.Tensor= torch.cat((attn.q_proj.weight.detach(), attn.k_proj.weight.detach(), attn.v_proj.weight.detach()), dim=0)
        self.wo :torch.Tensor= attn.o_proj.weight.detach()
        self.q_size = attn.q_proj.weight.shape[0]
        self.kv_size = attn.k_proj.weight.shape[0]

        self.gate_up_proj = torch.cat((mlp.gate_proj.weight.detach(), mlp.up_proj.weight.detach()), dim=0)
        self.down_proj = mlp.down_proj.weight.detach()

        # Detached like every other weight above.
        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon

        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon

def init_gpu(self, device:str = 'cuda:0'): + + self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=True) + self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=True) + self.wqkv = self.wqkv.to(device, non_blocking=True) + self.wo = self.wo.to(device, non_blocking=True) + self.gate_up_proj = self.gate_up_proj.to(device, non_blocking=True) + self.down_proj = self.down_proj.to(device, non_blocking=True) + +class Llama(LLMBase_ShadowKV): + def __init__(self, + model_name: str = "gradientai/Llama-3-8B-Instruct-Gradient-1048k", + batch_size :int = 1, + max_length :int = 64*1024, + device :str = 'cuda:0', + dtype = torch.bfloat16, + attn_mode: str = 'full', + sparse_budget: int = 2048, + rank=160, + chunk_size=8, + minference=False) -> None: + + # assert batch_size == 1, "Batch size must be 1" + self.batch_size = batch_size + self.device = device + self.dtype = dtype + self.config = LlamaConfig.from_pretrained(model_name) + self.model_name = model_name + self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, legacy=False) + self.max_length = max_length + self.hidden_size = self.config.hidden_size + self.num_heads = self.config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + self.vocab_size = self.config.vocab_size + + self.init_parameters() + self.attn_mode = attn_mode + self.minference = minference + + if 'llama-3' in model_name.lower(): + self.ctx_template = Templates['llama-3'] + self.chat_template = Chat_Templates['llama-3'] + self.prefix_template = Prefix_Templates['llama-3'] + elif 'yi' in model_name.lower(): + self.ctx_template = Templates['yi'] + self.chat_template = Chat_Templates['yi'] + 
class Llama(LLMBase_ShadowKV):
    """Llama-family model built on ``LLMBase_ShadowKV`` that works on raw weight tensors.

    Weights are extracted from a Hugging Face ``LlamaForCausalLM`` into plain
    tensors / ``LlamaLayer`` objects by ``init_parameters``; the per-layer
    compute helpers defined here operate on those tensors directly.
    """

    def __init__(self,
        model_name: str = "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
        batch_size :int = 1,
        max_length :int = 64*1024,
        device :str = 'cuda:0',
        dtype = torch.bfloat16,
        attn_mode: str = 'full',
        sparse_budget: int = 2048,
        rank=160,
        chunk_size=8,
        minference=False) -> None:
        """Load config, tokenizer and weights, pick prompt templates and set up the KV cache.

        Args:
            model_name: Hugging Face model id; must contain ``llama-3`` or
                ``yi`` (case-insensitive) so a prompt template can be chosen.
            batch_size: stored on the instance.
            max_length: maximum sequence length; the RoPE table is sized
                ``max_length + 1024``.
            device: device that receives the weights.
            dtype: dtype used to load the weights and the RoPE table.
            attn_mode: stored on the instance.
            sparse_budget, rank, chunk_size: forwarded to ``init_kv_cache``.
            minference: when true, load the per-layer MInference head
                patterns from ``MODEL2PATH[model_name]``.

        Raises:
            ValueError: if no template matches ``model_name``.
        """
        # assert batch_size == 1, "Batch size must be 1"
        self.batch_size = batch_size
        self.device = device
        self.dtype = dtype
        self.config = LlamaConfig.from_pretrained(model_name)
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, legacy=False)
        self.max_length = max_length
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta
        self.vocab_size = self.config.vocab_size

        # Needs head_dim / max_length / device / dtype, all set above.
        self.init_parameters()
        self.attn_mode = attn_mode
        self.minference = minference

        lowered_name = model_name.lower()
        if 'llama-3' in lowered_name:
            template_key = 'llama-3'
        elif 'yi' in lowered_name:
            template_key = 'yi'
        else:
            raise ValueError(f"Invalid model name {model_name}")
        self.ctx_template = Templates[template_key]
        self.chat_template = Chat_Templates[template_key]
        self.prefix_template = Prefix_Templates[template_key]

        # init_kv_cache is not defined in this class (presumably inherited).
        self.init_kv_cache(sparse_budget, rank, chunk_size, self.config)

        if self.minference:
            import json
            # Parse the pattern file once and close it, instead of re-opening
            # and re-parsing it for every layer.
            with open(MODEL2PATH[self.model_name]) as pattern_file:
                layer_patterns = json.load(pattern_file)
            # JSON object keys are strings; head ids are looked up as ints.
            self.minference_parttern = [
                {int(head): cfg for head, cfg in layer_patterns[layer_idx].items()}
                for layer_idx in range(self.num_layers)
            ]

    def alloc(self, **kwargs):
        """No-op; accepts and ignores any keyword arguments."""
        pass

    def _set_cos_sin_cache(self, inv_freq: torch.Tensor):
        """Build cos/sin tables for positions ``0 .. max_length + 1023`` from ``inv_freq``.

        Returns:
            ``(cos, sin)``, each the outer product of positions and
            ``inv_freq`` duplicated along the last dim, cast to ``self.dtype``.
        """
        t = torch.arange(self.max_length + 1024, device=self.device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(self.dtype), emb.sin().to(self.dtype)

    @torch.inference_mode()
    def apply_rotary_pos_emb_single(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Apply RoPE to ``x`` via ``apply_rotary_pos_emb_cuda`` using ``self.cos_sin_cache``."""
        return apply_rotary_pos_emb_cuda(x, self.cos_sin_cache, position_ids)

    @torch.inference_mode()
    def apply_rotary_pos_emb(self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to ``q`` and ``k`` with the vLLM kernel and reshape them per head.

        Returns:
            ``(q, k)`` viewed as ``(bsz, -1, heads, head_dim)`` and transposed
            to put the head axis at dim 1.
        """
        # The kernel's return value is not used, so it presumably rotates q and
        # k in place. The head size comes from the model config instead of a
        # hard-coded 128 (identical for the 128-dim models this file targets).
        vllm._custom_ops.rotary_embedding(position_ids, q, k, self.head_dim, self.cos_sin_cache, True)
        bsz = q.shape[0]
        q = q.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        return q, k

    def init_parameters(self):
        """Load the Hugging Face model and extract embeddings, head, norm, RoPE table and layers."""
        hf_model = LlamaForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype)
        self.embed_tokens = hf_model.model.embed_tokens.weight.detach().to(self.device)
        self.lm_head = hf_model.lm_head.weight.detach().to(self.device)
        self.norm_weight = hf_model.model.norm.weight.detach().to(self.device)
        self.norm_variance_epsilon = hf_model.model.norm.variance_epsilon

        rotary_emb = hf_model.model.layers[0].self_attn.rotary_emb
        table_len = self.max_length + 1024
        try:
            # Use the tables cached on the rotary module when it exposes them.
            cos_cache = rotary_emb.cos_cached[:table_len].to(self.device).to(self.dtype)
            sin_cache = rotary_emb.sin_cached[:table_len].to(self.device).to(self.dtype)
        except Exception:
            # Best-effort fallback: rebuild the tables from inv_freq.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit propagate.
            cos_cache, sin_cache = self._set_cos_sin_cache(rotary_emb.inv_freq.to(self.device))

        # Both halves of each table are identical (see _set_cos_sin_cache), so
        # keep one half of cos and one half of sin side by side.
        half_dim = self.head_dim // 2
        self.cos_sin_cache = torch.cat((cos_cache[:, :half_dim], sin_cache[:, :half_dim]), dim=-1)

        del cos_cache, sin_cache

        self.layers :list[LlamaLayer] = []

        for idx, hf_layer in enumerate(hf_model.model.layers):
            layer = LlamaLayer(idx)
            layer.init_parameters(hf_layer=hf_layer)
            layer.init_gpu(self.device)
            self.layers.append(layer)
            # Drop the HF layer as soon as its weights are extracted.
            hf_model.model.layers[idx] = None
            gc.collect()

        self.num_layers = len(self.layers)

    def pre_attention_compute(
        self,
        hidden_states: torch.Tensor,
        buffer: LlamaLayer,
        num_heads:int,
        num_key_value_heads:int,
        head_dim:int
    ):
        """Run the input norm and fused QKV projection for one layer.

        NOTE: ``num_heads``, ``num_key_value_heads`` and ``head_dim`` are
        accepted but not used; the value reshape reads the instance attributes.

        Returns:
            ``(query_states, key_states, value_states)``; only
            ``value_states`` is reshaped to
            ``(bsz, num_key_value_heads, -1, head_dim)``.
        """
        hidden_states = layer_norm(hidden_states, buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight)
        qkv = F.linear(hidden_states, buffer.wqkv)
        query_states, key_states, value_states = qkv.split([buffer.q_size, buffer.kv_size, buffer.kv_size], dim=-1)

        return query_states, key_states, value_states.view(value_states.shape[0], -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    def clear(self):
        """Clear the KV cache held in ``self.kv_cache``."""
        self.kv_cache.clear()

    def post_attention_compute(
        self,
        attn_output: torch.Tensor,
        residual: torch.Tensor,
        buffer: LlamaLayer
    ):
        """Apply the output projection, residual adds, post-attention norm and gated MLP."""
        hidden_states = F.linear(attn_output, buffer.wo)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = layer_norm(hidden_states, buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight)

        # One matmul yields [gate, up] concatenated on the last dim; the fused
        # kernel writes its result into `out`, which is half as wide.
        hidden_states = F.linear(hidden_states, buffer.gate_up_proj)
        d = hidden_states.shape[-1] // 2
        output_shape = (hidden_states.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device)
        vllm._custom_ops.silu_and_mul(out, hidden_states)

        hidden_states = F.linear(out, buffer.down_proj)
        hidden_states = residual + hidden_states
        return hidden_states
def layer_norm(
    hidden_states: torch.Tensor,
    eps: float,
    w: torch.Tensor,
):
    """Apply flashinfer's ``rmsnorm`` over the last dimension of ``hidden_states``.

    The input is flattened to 2-D for the kernel call and the result is
    viewed back to the input's shape.
    """
    return rmsnorm(hidden_states.view(-1, hidden_states.size(-1)), w, eps).view_as(hidden_states)

# Pure-PyTorch reference for the kernel call above, kept for debugging:
# def layer_norm(
#     hidden_states: torch.Tensor,
#     eps: float,
#     w: torch.Tensor,
# ):
#     input_dtype = hidden_states.dtype
#     hidden_states = hidden_states.to(torch.float32)
#     variance = hidden_states.pow(2).mean(-1, keepdim=True)
#     hidden_states = hidden_states * torch.rsqrt(variance + eps)
#     hidden_states = w * hidden_states.to(input_dtype)
#     return hidden_states


# copy from https://github.com/microsoft/MInference/blob/main/minference/modules/minference_forward.py

last_q = 64
# Build the mask on the GPU when one exists (unchanged behaviour there), but
# fall back to the CPU so that importing this module does not require CUDA.
arange = torch.arange(last_q, device="cuda" if torch.cuda.is_available() else "cpu")
# Lower-triangular (causal) boolean mask of shape (1, 1, last_q, last_q).
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]

def sum_all_diagonal_matrix(mat: torch.Tensor):
    """Sum every diagonal of the trailing ``(n, m)`` matrices of ``mat``.

    Args:
        mat: tensor of shape ``(b, h, n, m)``.

    Returns:
        Tensor of shape ``(b, h, n + m - 1)`` whose entry ``k`` is the sum of
        the diagonal with offset ``k + 1 - n`` (bottom-left corner first,
        top-right corner last).
    """
    b, h, n, m = mat.shape
    # Allocate the padding directly on mat's device (no host round trip).
    zero_mat = torch.zeros((b, h, n, n), device=mat.device)
    mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)  # pad left and right
    row = 2 * n + m  # width of one padded row
    # A row stride of `row + 1` shifts each successive row one column to the
    # right, which lines the diagonals up as columns.  Batch and head strides
    # are the real ones, so inputs with b > 1 or h > 1 are handled as well.
    mat_strided = mat_padded.as_strided(
        (b, h, n, n + m), (h * n * row, n * row, row + 1, 1)
    )
    sum_diags = torch.sum(mat_strided, 2)  # column sums == diagonal sums
    # Column 0 only ever sees padding, so drop it.
    return sum_diags[:, :, 1:]
def minference_prefill_kernel(query_states, key_states, value_states, minference_parttern):
    """Run sparse prefill attention one query head at a time.

    Args:
        query_states: ``[bsz, heads, q_len, head_dim]``.
        key_states, value_states: indexed as ``[:, kv_head, :, :]``; each
            group of ``heads // kv_heads`` query heads shares one kv head.
        minference_parttern: per-head pattern mapping passed through to
            ``gather_last_q_vertical_slash_topk_v4``.

    Returns:
        Attention output reshaped to ``[bsz, q_len, heads * head_dim]``.
    """
    q_len = query_states.shape[2]  # [bsz, heads, q_len, head_dim]
    bsz = query_states.shape[0]
    gqa_groups = query_states.shape[1] // key_states.shape[1]

    assert q_len > 1
    output = torch.empty_like(query_states)
    for head in range(query_states.size(1)):
        kv_head = head // gqa_groups
        q = query_states[:, head, :, :].unsqueeze(1)
        k = key_states[:, kv_head, :, :].unsqueeze(1)
        v = value_states[:, kv_head, :, :].unsqueeze(1)
        attn_output = gather_last_q_vertical_slash_topk_v4(q, k, v, head, minference_parttern)
        output[:, head:head + 1] = attn_output

    return output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)


def gather_last_q_vertical_slash_topk_v4(q, k, v, head_id, minference_parttern):
    """Dispatch one head to the sparse-attention kernel named by its pattern.

    ``minference_parttern.get(head_id)`` must yield a 4-item sequence
    ``(type, vertical_size, slash_size, _)`` where ``type`` is one of
    ``"stream_llm"``, ``"vertical_and_slash"`` or ``"block_sparse"``.
    """
    def vertical_and_slash_kernel(q, k, v, vertical_size, slash_size):
        # Clamp both budgets to [30 / 50, q_len].
        vertical_size, slash_size = min(q_len, max(vertical_size, 30)), min(q_len, max(slash_size, 50))
        last_q = min(64, q_len)
        # Scores of the last `last_q` queries against every key.
        qk = torch.einsum('bhmk, bhnk -> bhmn', q[:,:,-last_q:,:], k)
        # Causal-mask the trailing last_q x last_q square.
        qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[...,-last_q:,-last_q:].to(q.device), qk[:, :, :, -last_q:], -torch.inf)
        qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
        vertical = qk.sum(-2, keepdim=True)
        # Force the first 30 key positions into the top-k.
        vertical[...,:30] = torch.inf
        vertical_topk = torch.topk(vertical, vertical_size, -1).indices

        slash = sum_all_diagonal_matrix(qk)[...,:-last_q + 1]
        # Force the last 100 diagonals into the top-k.
        slash[...,-100:] = torch.inf
        slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

        return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)

    def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
        # The size arguments exist only to match the dispatch signature.
        topk = 100
        return block_sparse_attention(q, k, v, topk)

    # Read by vertical_and_slash_kernel through its closure.
    q_len = q.shape[2]

    ty, vertical_size, slash_size, _ = minference_parttern.get(head_id)

    fc = {
        "stream_llm": streaming_forward,
        "vertical_and_slash": vertical_and_slash_kernel,
        "block_sparse": block_sparse_kernel,
    }[ty]
    return fc(q, k, v, vertical_size, slash_size)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Maps ``(batch, kv_heads, seq, dim)`` to ``(batch, kv_heads * n_rep, seq,
    dim)``; the input itself is returned when ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, dim)

def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` of ``x`` along the last dim."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embedding to ``q`` and ``k``.

    ``cos`` / ``sin`` are indexed by ``position_ids`` and given a new axis at
    dim 1 so they broadcast against ``q`` and ``k``.
    """
    cos_sel = torch.unsqueeze(cos[position_ids], 1)
    sin_sel = torch.unsqueeze(sin[position_ids], 1)

    def _rope(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rope(q), _rope(k)

def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to a single tensor.

    With 3-D ``position_ids`` the leading two axes are merged and the indexed
    tables are used as they are; otherwise a broadcast axis is inserted at
    ``unsqueeze_dim``.
    """
    if position_ids.dim() == 3:
        # (batch_size, num_heads, seq_len) -> (batch_size * num_heads, seq_len)
        flat_ids = position_ids.view(-1, position_ids.size(-1))
        cos_sel, sin_sel = cos[flat_ids], sin[flat_ids]
    else:
        cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos_sel) + (rotate_half(q) * sin_sel)


def apply_rotary_pos_emb_cuda(x, cos_sin, position_ids):
    """Call the ``shadowkv`` RoPE kernel on 4-D ``x`` and return a new tensor.

    The kernel receives explicit sizes and strides for ``x``, the row stride
    of ``cos_sin`` and the first three strides of ``position_ids``; it is
    given a freshly allocated tensor shaped like ``x``, which is returned.
    """
    batch_size, heads, seq_len, embed_dim = x.shape
    rotated = torch.empty_like(x)

    x_strides = [int(x.stride(axis)) for axis in range(4)]
    pos_strides = [int(position_ids.stride(axis)) for axis in range(3)]

    shadowkv.apply_rotary_pos_emb_new(
        x, cos_sin, position_ids, rotated,
        int(batch_size), int(heads), int(seq_len), int(embed_dim),
        *x_strides,
        int(cos_sin.stride(0)),
        *pos_strides,
        int(embed_dim // 2),
    )

    return rotated
def apply_rotary_pos_emb_cuda_push_cache(x, cos_sin, position_ids, chunk_size, cache, sparse_start, sparse_end, cnts):
    """Call the ``shadowkv`` "RoPE + push into cache" kernel and return ``cache``.

    The kernel variant is chosen from the width of ``cos_sin``: 128 selects
    ``apply_rotary_pos_emb_push_cache_opt`` and 64 selects the ``_glm``
    variant.  Both receive the same argument list.

    Raises:
        ValueError: if ``cos_sin.shape[-1]`` is neither 128 nor 64.
    """
    batch_size, heads, seq_len, embed_dim = x.shape
    half_dim = embed_dim // 2

    table_width = cos_sin.shape[-1]
    if table_width == 128:
        kernel = shadowkv.apply_rotary_pos_emb_push_cache_opt
    elif table_width == 64:
        kernel = shadowkv.apply_rotary_pos_emb_push_cache_opt_glm
    else:
        raise ValueError(f"Invalid cos_sin shape {cos_sin.shape}")

    x_strides = [int(x.stride(axis)) for axis in range(4)]
    pos_strides = [int(position_ids.stride(axis)) for axis in range(3)]
    cache_strides = [int(cache.stride(axis)) for axis in range(3)]

    kernel(
        x, cos_sin, position_ids, cache, cnts,
        int(batch_size), int(heads), int(seq_len), int(embed_dim),
        *x_strides,
        int(cos_sin.stride(0)),
        *pos_strides,
        *cache_strides,
        int(sparse_start), int(sparse_end),
        int(half_dim), int(chunk_size),
    )

    return cache

def batch_gather_gemm_rotary_pos_emb_cuda(
    a: torch.Tensor,
    b: torch.Tensor,
    cos_sin: torch.Tensor,
    position_ids: torch.Tensor,
    output: torch.Tensor,
    chunk_size: int,
    cache: torch.Tensor,
    sparse_start: int,
    sparse_end: int,
    cnts: torch.Tensor
):
    """Run ``shadowkv.batch_gather_gemm`` and then the RoPE push-cache step.

    Sizes are read from the inputs: ``a`` is ``(batch, seq_len, rank)``,
    ``b`` is 4-D with heads and head_dim at dims 1 and 2, ``cos_sin`` is
    2-D, and ``position_ids`` is 3-D with the chunk count last.  The gemm
    kernel is handed ``output`` (presumably its destination buffer), which is
    then passed to ``apply_rotary_pos_emb_cuda_push_cache``.

    Returns:
        ``cache``, as returned by the push-cache step.
    """
    batch_size, seq_len, rank = a.shape
    heads, head_dim = b.shape[1], b.shape[2]
    _, _, _, _ = b.shape  # b must be 4-D
    max_seq_len, _ = cos_sin.shape
    _, _, num_chunks = position_ids.shape
    sparse_budget = num_chunks * chunk_size
    # The kernels expect int32, contiguous position ids.
    position_ids = position_ids.to(torch.int32).contiguous()
    cos_sin_contig = cos_sin.contiguous()

    shadowkv.batch_gather_gemm(
        a.contiguous(),
        b.contiguous(),
        cos_sin_contig,
        cos_sin_contig,
        position_ids,
        output,
        batch_size,
        heads,
        seq_len,
        head_dim,
        rank,
        sparse_budget,
        max_seq_len,
        chunk_size,
        cnts,
    )

    return apply_rotary_pos_emb_cuda_push_cache(
        output, cos_sin, position_ids, chunk_size, cache, sparse_start, sparse_end, cnts
    )
# copy from https://github.com/LeeSinLiang/microGPT/blob/ed40cf9780dbeb180adfe94c227d4aa97e69250e/gpt.py
def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0):
    """Mask logits outside the top-k / nucleus (top-p) set with ``-inf``.

    NOTE: ``logits`` is modified in place and also returned.

    Args:
        logits (torch.Tensor): 2D tensor with shape (batch, vocab)
        top_k (int, optional): keep only the ``top_k`` largest logits per
            row; disabled when <= 0. Defaults to 0.
        top_p (float, optional): keep the smallest prefix of the sorted
            distribution whose cumulative probability exceeds ``top_p`` (the
            top-1 token is always kept); disabled when <= 0. Defaults to 0.0.

    Returns:
        torch.Tensor: the filtered logits (same tensor object as the input)
    """
    if top_k > 0:
        # k-th largest value of every row, kept as a column for broadcasting.
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1)))[0][:, [-1]]
        logits[logits < kth_largest] = float('-inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove_sorted = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and never remove the most likely token.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
        logits[indices_to_remove] = float('-inf')
    return logits

def norm_logits(logits : torch.Tensor, temperature=0.6, top_k=-1, top_p=0.9) -> torch.Tensor:
    """Turn logits into a filtered probability distribution.

    The caller's tensor is never modified.

    Args:
        logits (torch.Tensor): shape (batch, vocab)
        temperature (float): softmax temperature; must be non-zero
        top_k (int): top-k cutoff, disabled when <= 0
        top_p (float): nucleus cutoff, disabled when <= 0

    Returns:
        torch.Tensor: probabilities with shape (batch, vocab)
    """
    assert logits.dim() == 2
    if temperature != 1.0:
        logits = logits / temperature
    else:
        # The filter below works in place; without scaling there is no fresh
        # tensor yet, so copy to keep the caller's logits intact.
        logits = logits.clone()
    logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)

    probs = F.softmax(logits, dim=-1)
    return probs


def sample(probs : torch.Tensor, num_samples=1):
    """Draw ``num_samples`` token ids per row from ``probs`` (with replacement)."""
    idx_next = torch.multinomial(probs, num_samples=num_samples, replacement=True)
    return idx_next

def sample_token(logits: torch.Tensor, temperature=0, top_k=50, top_p=0.9):
    """Pick the next token: greedy argmax when ``temperature == 0``, else sampled.

    Returns:
        torch.Tensor: token ids with a trailing dimension of size 1.
    """
    if temperature == 0.0:
        token = logits.argmax(dim=-1, keepdim=True)
    else:
        token = sample(norm_logits(logits, temperature=temperature, top_p=top_p, top_k=top_k))

    return token