{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "#Install dependencies" ], "metadata": { "id": "39AMoCOa1ckc" } }, { "cell_type": "code", "source": [ "!pip install ai-edge-litert-nightly" ], "metadata": { "id": "43tAeO0AZ7zp" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from ai_edge_litert import interpreter as interpreter_lib\n", "from transformers import AutoTokenizer\n", "import numpy as np\n", "from collections.abc import Sequence\n", "import sys" ], "metadata": { "id": "i6PMkMVBPr1p" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Download model files" ], "metadata": { "id": "K5okZCTgYpUd" } }, { "cell_type": "code", "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "model_path = hf_hub_download(repo_id=\"litert-community/DeepSeek-R1-Distill-Qwen-1.5B\", filename=\"deepseek_q8_seq128_ekv1280.tflite\")" ], "metadata": { "id": "3t47HAG2tvc3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Create LiteRT interpreter and tokenizer" ], "metadata": { "id": "n5Xa4s6XhWqk" } }, { "cell_type": "code", "source": [ "interpreter = interpreter_lib.InterpreterWithCustomOps(\n", " custom_op_registerers=[\"pywrap_genai_ops.GenAIOpsRegisterer\"],\n", " model_path=model_path,\n", " num_threads=2,\n", " experimental_default_delegate_latest_features=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")" ], "metadata": { "id": "Rvdn3EIZhaQn" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Create pipeline with LiteRT models" ], "metadata": { "id": "AM6rDABTXt2F" } }, { "cell_type": "code", "source": [ "\n", "class LiteRTLlmPipeline:\n", "\n", " def __init__(self, interpreter, tokenizer):\n", " \"\"\"Initializes the pipeline.\"\"\"\n", " self._interpreter = interpreter\n", " self._tokenizer = tokenizer\n", "\n", " self._prefill_runner = None\n", " self._decode_runner = self._interpreter.get_signature_runner(\"decode\")\n", "\n", "\n", " def _init_prefill_runner(self, num_input_tokens: int):\n", " \"\"\"Initializes all the variables related to the prefill runner.\n", "\n", " This method initializes the following variables:\n", " - self._prefill_runner: The prefill runner based on the input size.\n", " - self._max_seq_len: The maximum sequence length supported by the model.\n", " - self._max_kv_cache_seq_len: The maximum sequence length supported by the\n", " KV cache.\n", " - self._num_layers: The number of layers in the model.\n", "\n", " Args:\n", " num_input_tokens: The number of input tokens.\n", " \"\"\"\n", "\n", " self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n", " # input_token_shape has shape (batch, max_seq_len)\n", " input_token_shape = self._prefill_runner.get_input_details()[\"tokens\"][\n", " \"shape\"\n", " ]\n", " if len(input_token_shape) == 1:\n", " self._max_seq_len = input_token_shape[0]\n", " else:\n", " self._max_seq_len = input_token_shape[1]\n", "\n", " # kv cache input has shape [batch=1, seq_len, num_heads, dim].\n", " kv_cache_shape = self._prefill_runner.get_input_details()[\"kv_cache_k_0\"][\n", " \"shape\"\n", " ]\n", " self._max_kv_cache_seq_len = kv_cache_shape[1]\n", "\n", " # The two arguments excluded are `tokens` and `input_pos`. Dividing by 2\n", " # because each layer has key and value caches.\n", " self._num_layers = (\n", " len(self._prefill_runner.get_input_details().keys()) - 2\n", " ) // 2\n", "\n", "\n", " def _init_kv_cache(self) -> dict[str, np.ndarray]:\n", " if self._prefill_runner is None:\n", " raise ValueError(\"Prefill runner is not initialized.\")\n", " kv_cache = {}\n", " for i in range(self._num_layers):\n", " kv_cache[f\"kv_cache_k_{i}\"] = np.zeros(\n", " self._prefill_runner.get_input_details()[f\"kv_cache_k_{i}\"][\"shape\"],\n", " dtype=np.float32,\n", " )\n", " kv_cache[f\"kv_cache_v_{i}\"] = np.zeros(\n", " self._prefill_runner.get_input_details()[f\"kv_cache_v_{i}\"][\"shape\"],\n", " dtype=np.float32,\n", " )\n", " return kv_cache\n", "\n", " def _get_prefill_runner(self, num_input_tokens: int) :\n", " \"\"\"Gets the prefill runner with the best suitable input size.\n", "\n", " Args:\n", " num_input_tokens: The number of input tokens.\n", "\n", " Returns:\n", " The prefill runner with the smallest input size.\n", " \"\"\"\n", " best_signature = None\n", " delta = sys.maxsize\n", " max_prefill_len = -1\n", " for key in self._interpreter.get_signature_list().keys():\n", " if \"prefill\" not in key:\n", " continue\n", " input_pos = self._interpreter.get_signature_runner(key).get_input_details()[\n", " \"input_pos\"\n", " ]\n", " # input_pos[\"shape\"] has shape (max_seq_len, )\n", " seq_size = input_pos[\"shape\"][0]\n", " max_prefill_len = max(max_prefill_len, seq_size)\n", " if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:\n", " delta = seq_size - num_input_tokens\n", " best_signature = key\n", " if best_signature is None:\n", " raise ValueError(\n", " \"The largest prefill length supported is %d, but we have %d number of input tokens\"\n", " %(max_prefill_len, num_input_tokens)\n", " )\n", " return self._interpreter.get_signature_runner(best_signature)\n", "\n", " def _greedy_sampler(self, logits: np.ndarray) -> int:\n", " return int(np.argmax(logits))\n", "\n", " def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n", " messages=[{ 'role': 'user', 'content': prompt}]\n", " token_ids = self._tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)\n", " # Initialize the prefill runner with the suitable input size.\n", " self._init_prefill_runner(len(token_ids))\n", "\n", " actual_max_decode_steps = self._max_kv_cache_seq_len - len(token_ids)\n", " if max_decode_steps is not None:\n", " actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n", "\n", " input_token_ids = [0] * self._max_seq_len\n", " input_token_ids[:len(token_ids)] = token_ids\n", " model_inputs = self._init_kv_cache()\n", " model_inputs.update({\n", " \"tokens\": np.asarray([input_token_ids], dtype=np.int32),\n", " \"input_pos\": np.arange(self._max_seq_len, dtype=np.int32),\n", " })\n", " decode_text = []\n", " decode_step = 0\n", " print('Running prefill')\n", " for step in range(actual_max_decode_steps+1):\n", " signature_runner = self._prefill_runner if step == 0 else self._decode_runner\n", " model_outputs = signature_runner(**model_inputs)\n", " # At prefill stage, output logits has shape (batch=1, seq_size, vocab_size)\n", " # At decode stage, output logits has shape (batch=1, 1, vocab_size).\n", " selected_logit = len(token_ids)-1 if step == 0 else 0\n", " logits = model_outputs.pop(\"logits\")[0][selected_logit]\n", "\n", " if step == 0:\n", " print('Running decode')\n", "\n", " # Decode text output.\n", " next_token = self._greedy_sampler(logits)\n", " if next_token == self._tokenizer.eos_token_id:\n", " break\n", " decode_text.append(self._tokenizer.decode(next_token, skip_special_tokens=False))\n", " print(decode_text[-1], end='', flush=True)\n", " # The rest of the outputs is the updated kv cache.\n", " model_inputs = model_outputs\n", " model_inputs.update({\n", " \"tokens\": np.array([[next_token]], dtype=np.int32),\n", " \"input_pos\": np.array([decode_step + len(token_ids)], dtype=np.int32),})\n", " decode_step += 1\n", "\n", "\n", "\n", " print() # print a new line at the end.\n", " return ''.join(decode_text)\n" ], "metadata": { "id": "UBSGrHrM4ANm" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Generate text from model" ], "metadata": { "id": "dASKx_JtYXwe" } }, { "cell_type": "code", "source": [ "# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.\n", "pipeline = LiteRTLlmPipeline(interpreter, tokenizer)" ], "metadata": { "id": "AZhlDQWg61AL" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "prompt = \"what is 8 mod 5\"\n", "output = pipeline.generate(prompt, max_decode_steps = None)" ], "metadata": { "id": "wT9BIiATkjzL" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "GNzDBxDFEuAJ" }, "execution_count": null, "outputs": [] } ] }