import copy
import math
import os
import shutil
import threading
import time
import uuid
from threading import Thread

import gradio as gr
import spaces
import torch
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, InputFormat, PdfFormatOption
from langchain.schema.document import Document
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_docling import DoclingLoader
from langchain_docling.loader import ExportType
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, TextIteratorStreamer
from transformers.models.llama.modeling_llama import rotate_half

from utils import (
    calculate_tokens_suggest_compression_ratio,
    repeat_kv,
    update_retrieval_context,
)

# Initialize the model and tokenizer.
api_token = os.getenv("HF_TOKEN")
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval()
model.to(device)

embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": str(device)},
    encode_kwargs={"normalize_embeddings": True},
    query_instruction="",
)

# Create a chat template and split it into prefix and suffix.
content_system = ""
content_user = "######"
user_template = [
    {"role": "system", "content": content_system},
    {"role": "user", "content": content_user},
]
user = tokenizer.apply_chat_template(user_template, add_generation_prompt=True, tokenize=False)
prefix, suffix = user.split(content_user)
sink_tokens = max(4, len(tokenizer.encode(prefix)))

# Default prompt content.
default_task_description = (
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words."
)

default_few_shot = """Examples
question: Which case was brought to court first Miller v. California or Gates v. Collier ?
answer: Miller v. California
question: The actor that plays Phileas Fogg in "Around the World in 80 Days", co-starred with Gary Cooper in a 1939 Goldwyn Productions film based on a novel by what author?
answer: Charles L. Clifford
question: Prior to playing for Michigan State, Keith Nichol played football for a school located in what city?
answer: Norman
"""
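# Rough shape of the prompt that ultimately reaches the model at chat time
# (illustrative only; the actual assembly happens in run_naive_rag_query and
# chat_response_stream below):
#
#   prefix + "Retrieved context: \n" + retrieved_chunks + task + few_shot
#          + "\nquestion: " + user_question + suffix + "answer:"
#
# where `prefix`/`suffix` are the chat-template text before and after the
# "######" placeholder defined above. In the pure KV-compression path the
# prefix/task/few-shot pieces are already baked into the compressed cache,
# so they are passed as empty strings.
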
CHROMA_DB_DIR = "./chroma_db"
CACHE_DIR = "./cache_dir"
EXPIRATION_SECONDS = 3600


def background_cleanup():
    while True:
        current_time = int(time.time())
        # Clean Chroma collections
        if os.path.exists(CHROMA_DB_DIR):
            for dirname in os.listdir(CHROMA_DB_DIR):
                parts = dirname.split("_")
                if len(parts) >= 3 and parts[1].isdigit():
                    timestamp = int(parts[1])
                    if current_time - timestamp > EXPIRATION_SECONDS:
                        path = os.path.join(CHROMA_DB_DIR, dirname)
                        shutil.rmtree(path, ignore_errors=True)
                        print(f"[Cleanup] Deleted Chroma collection: {path}")
        # Clean cache files
        if os.path.exists(CACHE_DIR):
            for filename in os.listdir(CACHE_DIR):
                parts = filename.split("_")
                if len(parts) >= 3 and parts[1].isdigit():
                    timestamp = int(parts[1])
                    if current_time - timestamp > EXPIRATION_SECONDS:
                        path = os.path.join(CACHE_DIR, filename)
                        os.remove(path)
                        print(f"[Cleanup] Deleted cache file: {path}")
        time.sleep(600)


cleanup_thread = threading.Thread(target=background_cleanup, daemon=True)
cleanup_thread.start()


class FinchCache(DynamicCache):
    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
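
# Illustrative FinchCache round trip (not executed here; the real flow lives in
# get_compressed_kv_cache and chat_response_stream below). `indices` stands for a
# hypothetical (batch, kv_heads, k) tensor of positions to keep:
#
#   cache = FinchCache()
#   # ... prefill: model.model(input_ids=..., past_key_values=cache, use_cache=True) ...
#   for layer_idx in range(len(model.model.layers)):
#       cache.compress_cache(layer_idx, indices, model.model.rotary_emb.inv_freq)
#   cache.save("./cache_dir/example.pt")
#   restored = FinchCache.load("./cache_dir/example.pt", device=str(device))
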
data["value_cache"]] cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0 return cache def convert_to_markdown(file_objs, url, do_ocr, do_table_structure): file_path = file_objs if file_objs is not None else url pipeline_options = PdfPipelineOptions() pipeline_options.do_ocr = do_ocr pipeline_options.do_table_structure = do_table_structure pdf_format_options = PdfFormatOption( pipeline_options=pipeline_options, backend=PyPdfiumDocumentBackend, ) doc_converter = DocumentConverter( allowed_formats=[InputFormat.PDF], format_options={InputFormat.PDF: pdf_format_options} ) loader = DoclingLoader( file_path=file_path, export_type=ExportType.MARKDOWN, converter=doc_converter ) try: docs = loader.load() return docs[0].page_content except Exception as e: raise RuntimeError(f"Failed to convert document to markdown: {e}") def create_rag_index(collection_name, text_no_prefix): text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( tokenizer, chunk_size=256, chunk_overlap=0, add_start_index=True, strip_whitespace=True, separators=["\n\n", "\n", ".", " ", ""], ) docs = [Document(page_content=x) for x in text_splitter.split_text(text_no_prefix)] vectorstore = Chroma.from_documents(collection_name=collection_name, persist_directory="./chroma_db", documents=docs, embedding=embedding_model) return vectorstore @spaces.GPU def auto_convert(file_objs, url, do_ocr, do_table_structure): # When a new file/URL is loaded, disable chat (compression not done) chat_status = "Document not compressed yet. Please compress the document to enable chat." if file_objs is None and (url is None or url.strip() == ""): return ( gr.update(value=""), "Number of tokens before compression: ", gr.update(), "Number of tokens after compression: ", 0, gr.update(interactive=False), False, {}, chat_status ) print("Converting to markdown") try: markdown = convert_to_markdown(file_objs, url, do_ocr, do_table_structure) except RuntimeError as e: return ( gr.update(value=f"{str(e)} Please try uploading another document format."), "Number of tokens before compression: ", gr.update(), "Number of tokens after compression: ", 0, gr.update(interactive=False), False, {}, chat_status ) print("Done") combined_text = prefix + markdown print("Suggestioning Compression ratio") token_count, suggestions, _ = calculate_tokens_suggest_compression_ratio(combined_text, tokenizer, model) print("Done") min_ratio = min(suggestions) max_ratio = max(suggestions) default_ratio = 6 retrieval_tokens = int(token_count / default_ratio) token_count_str = f"Number of tokens before compression: {token_count}" retrieval_str = f"Number of tokens after compression: {retrieval_tokens}" slider_update = gr.update(value=default_ratio, minimum=min_ratio, maximum=max_ratio, step=1) if combined_text.startswith(prefix): rag_text = combined_text[len(prefix):] else: rag_text = combined_text current_timestamp = int(time.time()) collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}" rag_index = create_rag_index(collection_name, rag_text) state = {"rag_index": collection_name} print("Done") return ( combined_text, token_count_str, slider_update, retrieval_str, token_count, gr.update(interactive=True), # Enable compress button if conversion succeeds. 
def get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
    try:
        device = model.device
        dtype = model.dtype
        num_chunks = step_size
        context_ids = context_ids.to(device)
        context_attention_mask = context_attention_mask.to(device)
        question_ids = question_ids.to(device)
        question_attention_mask = question_attention_mask.to(device)
        question_len = question_ids.size(1)
        total_len = context_ids.size(1)
        max_context_tokens_allowed = model.config.max_position_embeddings - question_len
        if total_len > max_context_tokens_allowed:
            num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
        if total_len <= sink_tokens or num_chunks == 1:
            context_ids_list = [context_ids]
            context_attention_mask_list = [context_attention_mask]
        else:
            remainder_len = total_len - sink_tokens
            base = remainder_len // num_chunks
            leftover = remainder_len % num_chunks
            chunk_sizes = [sink_tokens + base]
            for _ in range(num_chunks - 2):
                chunk_sizes.append(base)
            if num_chunks > 1:
                chunk_sizes.append(base + leftover)
            context_ids_list = []
            context_attention_mask_list = []
            offset = 0
            for size in chunk_sizes:
                end = offset + size
                context_ids_list.append(context_ids[:, offset:end])
                context_attention_mask_list.append(context_attention_mask[:, offset:end])
                offset = end
        len_rest = max(total_len - sink_tokens, 1)
        compression_factor = len_rest // target_token_size
        if compression_factor < 1:
            compression_factor = 1
        tokenized_doc_chunks = []
        for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
            tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
        print("Number of chunks: ", len(tokenized_doc_chunks))

        rotary_emb = model.model.rotary_emb.to(device)
        inv_freq = rotary_emb.inv_freq
        batch_size = question_ids.size(0)
        ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)

        cache = FinchCache()
        past_cache_len = 0
        past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
        num_chunks = len(tokenized_doc_chunks)
        query_context_matrices = {}

        # Forward hook on q_proj: stash the query states of the question tokens for the current chunk.
        def query_hook_fn(module, input, output):
            layer_idx = getattr(module, "layer_idx", None)
            if layer_idx is not None:
                query_states = output.detach()
                bsz, seq_len, hidden_dim = query_states.size()
                num_query_heads = module.num_query_heads
                head_dim = hidden_dim // num_query_heads
                query_states = (
                    query_states.view(bsz, seq_len, num_query_heads, head_dim)
                    .transpose(1, 2)
                    .contiguous()
                )
                query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()

        hooks = []
        for i, layer in enumerate(model.model.layers):
            layer.self_attn.q_proj.layer_idx = i
            layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
            hooks.append(hook)

        for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
            current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
            _current_chunk_offset = current_seq_length
            query_context_matrices.clear()
            chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
            chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            ).contiguous()
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
            past_seen_tokens = cache.get_seq_length() if cache is not None else 0
            cache_position = torch.arange(
                past_seen_tokens + chunk_input_ids.shape[1],
                past_seen_tokens + current_input_ids.shape[1],
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=question_ids.size(1),
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()
            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values
            len_question = question_ids.size(1)
            for layer_idx in range(len(model.model.layers)):
                key_matrix = cache.key_cache[layer_idx]
                query_matrix = query_context_matrices[layer_idx]
                layer_cache_pos = torch.arange(
                    past_cache_len + current_seq_length,
                    past_cache_len + current_seq_length + len_question,
                    device=device,
                )
                position_ids = layer_cache_pos.unsqueeze(0)
                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
                num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
                key_matrix = repeat_kv(key_matrix, num_repeats)
                scaling = math.sqrt(model.config.head_dim)
                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
                tol = 1e-8
                binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts
                if j != num_chunks - 1:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
                else:
                    attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
                attention_matrix = torch.sum(attention_matrix, dim=-2)
                attention_matrix = attention_matrix.view(
                    attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
                ).sum(dim=2)
                full_context_size = attention_matrix.size(-1)
                attention_matrix[..., :sink_tokens] = float("inf")
                if j == num_chunks - 1:
                    attention_matrix[..., -len_question:] = float("inf")
                if j == 0:
                    k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
                    k = min(k + past_cache_len, full_context_size)
                elif j < num_chunks - 1:
                    to_keep_new = int(current_seq_length // compression_factor)
                    k = min(past_cache_len + to_keep_new, full_context_size)
                else:
                    desired_final = sink_tokens + target_token_size + len_question
                    k = desired_final if full_context_size >= desired_final else full_context_size
                    k = max(k, sink_tokens)
                selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
                selected_indices, _ = torch.sort(selected_indices, dim=-1)
                cache.compress_cache(layer_idx, selected_indices, inv_freq)
            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(1, past_cache_len, device=device)
        for hook in hooks:
            hook.remove()
        return cache
    except Exception as e:
        raise RuntimeError(f"Failed to compress KV cache: {e}")
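
# Illustrative call (the real invocation is in prepare_compression_and_rag below);
# target_token_size=1024 is an arbitrary example value, in the app it is derived
# from the compression-ratio slider:
#
#   cache = get_compressed_kv_cache(
#       sink_tokens, step_size=2, target_token_size=1024,
#       context_ids=context_ids, context_attention_mask=context_attention_mask,
#       question_ids=question_ids, question_attention_mask=question_attention_mask,
#   )
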
def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, few_shot_examples):
    k = max(1, rag_token_size // 256)
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embedding_model, collection_name=collection_name)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")
    formatted_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
    rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
    return rag_context


@spaces.GPU
def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
    progress(0, desc="Starting compression process")
    # percentage = int(global_local_value.replace('%', ''))
    percentage = 0 if global_local_value == "RAG" else 100
    progress(0.1, desc="Tokenizing text and preparing task")
    question_text = task_description + "\n" + few_shot
    context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
    question_encoding = tokenizer(question_text, return_tensors="pt").to(device)
    context_ids = context_encoding["input_ids"]
    context_attention_mask = context_encoding["attention_mask"]
    question_ids = question_encoding["input_ids"]
    question_attention_mask = question_encoding["attention_mask"]
    retrieval_context_length = int(context_ids.size(1) / retrieval_slider_value)
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    progress(0.2, desc=f"Token breakdown computed: {kv_tokens} KV tokens, {rag_tokens} RAG tokens")
    if percentage > 0:
        target_token_size = int(retrieval_context_length * (percentage / 100))
        progress(0.3, desc="Starting KV cache compression")
        step_size = 2
        try:
            past_key_values = copy.deepcopy(get_compressed_kv_cache(sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask))
        except Exception as e:
            progress(1, desc="Compression failed")
            print("Error during KV cache compression:", e)
            state["error"] = "Error during KV cache compression. Please try lowering the compression ratio and try again."
            # Keep the same output shape as the success path: (state, status message, compression_done).
            return state, state["error"], False
        compressed_length = past_key_values.get_seq_length()
        progress(0.6, desc="KV cache compression completed")
    else:
        target_token_size = 0
        past_key_values = FinchCache()
        compressed_length = past_key_values.get_seq_length()
        progress(0.3, desc="Skipping compression as percentage is 0")
    current_timestamp = int(time.time())
    cache_name = f"cache_{current_timestamp}_{uuid.uuid4().hex[:6]}.pt"
    save_dir = "./cache_dir"
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, cache_name)
    past_key_values.save(save_path)
    progress(0.8, desc="Cache saved successfully")
    collection_name = state.get("rag_index", None)
    if collection_name is None:
        print("Collection name not found; creating a new one.")
        if combined_text.startswith(prefix):
            rag_text = combined_text[len(prefix):]
        else:
            rag_text = combined_text
        current_timestamp = int(time.time())
        collection_name = f"default_{current_timestamp}_{uuid.uuid4().hex[:6]}"
        rag_index = create_rag_index(collection_name, rag_text)
    state.update({
        "compressed_cache": save_path,
        "rag_index": collection_name,
        "global_local": percentage,
        "task_description": task_description,
        "few_shot": few_shot,
        "retrieval_slider": retrieval_context_length,
    })
    progress(1, desc="Compression complete")
    return state, "Document compressed successfully. You can now chat.", True
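
# Note on the generation call in chat_response_stream below: the compressed KV cache
# already covers the first `compressed_length` positions, so the new prompt is
# left-padded with EOS placeholder ids (plus an all-ones attention mask) only to keep
# input_ids aligned with past_key_values; generate() then processes just the new tokens.
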
@spaces.GPU
def chat_response_stream(message: str, history: list, state: dict, compression_done: bool):
    # Check if the document is compressed before allowing chat
    if not compression_done or "compressed_cache" not in state:
        yield "Document not compressed yet. Please compress the document first to enable chat."
        return
    user_message = message
    save_path = state["compressed_cache"]
    past_key_values = FinchCache.load(save_path, device=model.device)
    compressed_length = past_key_values.get_seq_length()
    collection_name = state["rag_index"]
    retrieval_slider_value = state["retrieval_slider"]
    percentage = state["global_local"]
    rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
    print("RAG retrieval size: ", rag_retrieval_size)
    print("Compressed cache: ", compressed_length)
    if percentage == 0:
        rag_prefix = prefix
        rag_task = state["task_description"]
        rag_few_shot = state["few_shot"]
    else:
        rag_prefix = ""
        rag_task = ""
        rag_few_shot = ""
    print("user message: ", user_message)
    if rag_retrieval_size != 0:
        print("Running RAG query")
        rag_context = run_naive_rag_query(collection_name, user_message, rag_retrieval_size, rag_prefix, rag_task, rag_few_shot)
        new_input = rag_context + "\nquestion: " + user_message + suffix + "answer:"
    else:
        new_input = "\nquestion: " + user_message + suffix + "answer:"
    tokenized_new_input = tokenizer(new_input, return_tensors="pt").to(device)
    eos_block = torch.full((1, compressed_length), tokenizer.eos_token_id, device=device, dtype=torch.long)
    new_input_ids = torch.cat([eos_block, tokenized_new_input["input_ids"]], dim=-1)
    new_attention_mask = torch.cat([torch.ones((1, compressed_length), device=device), tokenized_new_input["attention_mask"]], dim=-1)
    print("New input is: ", new_input)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=new_input_ids,
        attention_mask=new_attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        use_cache=True,
        max_new_tokens=1024,
        num_beams=1,
        do_sample=False,
        temperature=1.0,
        top_p=1.0,
        top_k=None,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    full_output = ""
    for text in streamer:
        full_output += text
        time.sleep(0.05)
        yield full_output
    return full_output


def update_token_breakdown(token_count, retrieval_slider, global_local_value):
    retrieval_context_length = int(token_count / retrieval_slider)
    # percentage = int(global_local_value.replace('%', ''))
    percentage = 0 if global_local_value == "RAG" else 100
    rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
    kv_tokens = retrieval_context_length - rag_tokens
    return f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)"


##########################################################################
# Gradio Interface
##########################################################################
CSS = """
.main-container {
    display: flex;
    align-items: stretch;
}
.upload-section, .chatbot-container {
    display: flex;
    flex-direction: column;
    height: 100%;
    overflow-y: auto;
}
.upload-section {
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 0;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
"""


def reset_chat_state():
    return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False


with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:
    # gr.HTML("