File size: 11,275 Bytes
9e7f97d
1
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "https://huggingface.co/litert-community/Phi-4-mini-instruct/blob/main/phi4_litert.ipynb",
     "timestamp": 1741214116376
    },
    {
     "file_id": "https://huggingface.co/litert-community/Phi-4-mini-instruct/blob/main/pphi4_litert.ipynb",
     "timestamp": 1741123967537
    }
   ],
   "gpuType": "T4"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Phi-4-mini-instruct with LiteRT\n",
    "\n",
    "This notebook downloads `phi4_q8_ekv1280.tflite` from the `litert-community/Phi-4-mini-instruct` Hugging Face repository, loads it with the LiteRT Python interpreter, and generates a greedy-decoded chat response using the `microsoft/Phi-4-mini-instruct` tokenizer."
   ],
   "metadata": {
    "id": "zBuCEnuUTMvC"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Install dependencies"
   ],
   "metadata": {
    "id": "39AMoCOa1ckc"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "%pip install ai-edge-litert"
   ],
   "metadata": {
    "id": "43tAeO0AZ7zp"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import sys\n",
    "from collections.abc import Sequence\n",
    "\n",
    "import numpy as np\n",
    "from ai_edge_litert import interpreter as interpreter_lib\n",
    "from huggingface_hub import hf_hub_download\n",
    "from transformers import AutoTokenizer"
   ],
   "metadata": {
    "id": "i6PMkMVBPr1p"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Download model files"
   ],
   "metadata": {
    "id": "K5okZCTgYpUd"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "model_path = hf_hub_download(\n",
    "    repo_id=\"litert-community/Phi-4-mini-instruct\",\n",
    "    filename=\"phi4_q8_ekv1280.tflite\",\n",
    ")"
   ],
   "metadata": {
    "id": "3t47HAG2tvc3"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create the LiteRT interpreter and tokenizer"
   ],
   "metadata": {
    "id": "n5Xa4s6XhWqk"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "interpreter = interpreter_lib.InterpreterWithCustomOps(\n",
    "    custom_op_registerers=[\"pywrap_genai_ops.GenAIOpsRegisterer\"],\n",
    "    model_path=model_path,\n",
    "    num_threads=2,\n",
    "    experimental_default_delegate_latest_features=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-4-mini-instruct\")"
   ],
   "metadata": {
    "id": "Rvdn3EIZhaQn"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create a pipeline with the LiteRT model\n",
    "\n",
    "`LiteRTLlmPipeline` fills the KV cache from the prompt using the smallest `prefill` signature that fits, then runs the `decode` signature one token at a time, always picking the highest-scoring token (greedy decoding). The response is streamed to the output as it is generated."
   ],
   "metadata": {
    "id": "AM6rDABTXt2F"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "class LiteRTLlmPipeline:\n",
    "  \"\"\"Runs greedy text generation with a LiteRT LLM.\n",
    "\n",
    "  The model is expected to expose a \"decode\" signature and one or more\n",
    "  signatures whose names contain \"prefill\".\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, interpreter, tokenizer):\n",
    "    \"\"\"Initializes the pipeline.\n",
    "\n",
    "    Args:\n",
    "      interpreter: The LiteRT interpreter holding the model.\n",
    "      tokenizer: The Hugging Face tokenizer used to build the chat prompt and\n",
    "        decode generated token ids.\n",
    "    \"\"\"\n",
    "    self._interpreter = interpreter\n",
    "    self._tokenizer = tokenizer\n",
    "\n",
    "    # The prefill runner depends on the prompt length, so it is selected per\n",
    "    # prompt in `_init_prefill_runner`.\n",
    "    self._prefill_runner = None\n",
    "    self._decode_runner = self._interpreter.get_signature_runner(\"decode\")\n",
    "\n",
    "  def _init_prefill_runner(self, num_input_tokens: int):\n",
    "    \"\"\"Initializes all the variables related to the prefill runner.\n",
    "\n",
    "    This method initializes the following variables:\n",
    "      - self._prefill_runner: The prefill runner based on the input size.\n",
    "      - self._max_seq_len: The maximum sequence length supported by the model.\n",
    "      - self._max_kv_cache_seq_len: The maximum sequence length supported by the\n",
    "        KV cache.\n",
    "\n",
    "    Args:\n",
    "      num_input_tokens: The number of input tokens.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If the interpreter is not set or no prefill signature can\n",
    "        hold `num_input_tokens` tokens.\n",
    "    \"\"\"\n",
    "    if not self._interpreter:\n",
    "      raise ValueError(\"Interpreter is not initialized.\")\n",
    "\n",
    "    # Called by `generate` for every prompt, because the best-fitting prefill\n",
    "    # signature depends on the prompt length.\n",
    "    self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n",
    "    input_details = self._prefill_runner.get_input_details()\n",
    "    # input_token_shape has shape (batch, max_seq_len) or (max_seq_len,).\n",
    "    input_token_shape = input_details[\"tokens\"][\"shape\"]\n",
    "    if len(input_token_shape) == 1:\n",
    "      self._max_seq_len = input_token_shape[0]\n",
    "    else:\n",
    "      self._max_seq_len = input_token_shape[1]\n",
    "\n",
    "    # kv cache input has shape [batch=1, seq_len, num_heads, dim].\n",
    "    kv_cache_shape = input_details[\"kv_cache_k_0\"][\"shape\"]\n",
    "    self._max_kv_cache_seq_len = kv_cache_shape[1]\n",
    "\n",
    "  def _init_kv_cache(self) -> dict[str, np.ndarray]:\n",
    "    \"\"\"Returns zero-filled arrays for every kv cache input of the prefill runner.\n",
    "\n",
    "    Each array uses the shape and dtype that the prefill signature declares for\n",
    "    that input, so it can be passed to the runner as-is.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If the prefill runner has not been initialized.\n",
    "    \"\"\"\n",
    "    if self._prefill_runner is None:\n",
    "      raise ValueError(\"Prefill runner is not initialized.\")\n",
    "    return {\n",
    "        name: np.zeros(details[\"shape\"], dtype=details[\"dtype\"])\n",
    "        for name, details in self._prefill_runner.get_input_details().items()\n",
    "        if \"kv_cache\" in name\n",
    "    }\n",
    "\n",
    "  def _get_prefill_runner(self, num_input_tokens: int):\n",
    "    \"\"\"Gets the prefill runner with the best suitable input size.\n",
    "\n",
    "    Args:\n",
    "      num_input_tokens: The number of input tokens.\n",
    "\n",
    "    Returns:\n",
    "      The runner of the prefill signature with the smallest input length that\n",
    "      still holds `num_input_tokens` tokens.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If every prefill signature is shorter than `num_input_tokens`.\n",
    "    \"\"\"\n",
    "    best_signature = None\n",
    "    delta = sys.maxsize\n",
    "    max_prefill_len = -1\n",
    "    for key in self._interpreter.get_signature_list().keys():\n",
    "      if \"prefill\" not in key:\n",
    "        continue\n",
    "      input_pos = self._interpreter.get_signature_runner(key).get_input_details()[\n",
    "          \"input_pos\"\n",
    "      ]\n",
    "      # input_pos[\"shape\"] has shape (max_seq_len, )\n",
    "      seq_size = input_pos[\"shape\"][0]\n",
    "      max_prefill_len = max(max_prefill_len, seq_size)\n",
    "      if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:\n",
    "        delta = seq_size - num_input_tokens\n",
    "        best_signature = key\n",
    "    if best_signature is None:\n",
    "      raise ValueError(\n",
    "          \"The largest prefill length supported is %d, but we have %d number of input tokens\"\n",
    "          % (max_prefill_len, num_input_tokens)\n",
    "      )\n",
    "    return self._interpreter.get_signature_runner(best_signature)\n",
    "\n",
    "  def _run_prefill(\n",
    "      self, prefill_token_ids: Sequence[int],\n",
    "  ) -> dict[str, np.ndarray]:\n",
    "    \"\"\"Runs prefill and returns the kv cache.\n",
    "\n",
    "    Args:\n",
    "      prefill_token_ids: The token ids of the prefill input.\n",
    "\n",
    "    Returns:\n",
    "      The updated kv cache, or a zero-filled one if `prefill_token_ids` is\n",
    "      empty.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If the prefill runner has not been initialized.\n",
    "    \"\"\"\n",
    "    if self._prefill_runner is None:\n",
    "      raise ValueError(\"Prefill runner is not initialized.\")\n",
    "    prefill_token_length = len(prefill_token_ids)\n",
    "    if prefill_token_length == 0:\n",
    "      return self._init_kv_cache()\n",
    "\n",
    "    # Zero-pad the tokens to [1, max_seq_len] and the positions to\n",
    "    # [max_seq_len].\n",
    "    input_token_ids = np.zeros((1, self._max_seq_len), dtype=np.int32)\n",
    "    input_token_ids[0, :prefill_token_length] = prefill_token_ids\n",
    "    input_pos = np.zeros(self._max_seq_len, dtype=np.int32)\n",
    "    input_pos[:prefill_token_length] = np.arange(prefill_token_length)\n",
    "\n",
    "    # Start from an empty kv cache.\n",
    "    prefill_inputs = self._init_kv_cache()\n",
    "    prefill_inputs.update({\n",
    "        \"tokens\": input_token_ids,\n",
    "        \"input_pos\": input_pos,\n",
    "    })\n",
    "    prefill_outputs = self._prefill_runner(**prefill_inputs)\n",
    "    # Prefill outputs may include logits besides the kv cache. Only the kv\n",
    "    # cache is returned.\n",
    "    prefill_outputs.pop(\"logits\", None)\n",
    "    return prefill_outputs\n",
    "\n",
    "  def _greedy_sampler(self, logits: np.ndarray) -> int:\n",
    "    \"\"\"Returns the id of the highest-scoring token.\"\"\"\n",
    "    return int(np.argmax(logits))\n",
    "\n",
    "  def _run_decode(\n",
    "      self,\n",
    "      start_pos: int,\n",
    "      start_token_id: int,\n",
    "      kv_cache: dict[str, np.ndarray],\n",
    "      max_decode_steps: int,\n",
    "  ) -> str:\n",
    "    \"\"\"Runs greedy decode, streams the text to stdout, and returns it.\n",
    "\n",
    "    Args:\n",
    "      start_pos: The position of the first token of the decode input.\n",
    "      start_token_id: The token id of the first token of the decode input.\n",
    "      kv_cache: The kv cache from the prefill. The dict is not modified.\n",
    "      max_decode_steps: The max decode steps.\n",
    "\n",
    "    Returns:\n",
    "      The text decoded from the generated token ids. Decoding stops at the\n",
    "      tokenizer's EOS token, which is not included.\n",
    "    \"\"\"\n",
    "    next_pos = start_pos\n",
    "    next_token = start_token_id\n",
    "    generated_ids = []\n",
    "    printed_len = 0\n",
    "    # Copy so that the caller's kv cache dict is not mutated below.\n",
    "    decode_inputs = dict(kv_cache)\n",
    "\n",
    "    for _ in range(max_decode_steps):\n",
    "      decode_inputs.update({\n",
    "          \"tokens\": np.array([[next_token]], dtype=np.int32),\n",
    "          \"input_pos\": np.array([next_pos], dtype=np.int32),\n",
    "      })\n",
    "      decode_outputs = self._decode_runner(**decode_inputs)\n",
    "      # Output logits has shape (batch=1, 1, vocab_size). We only take the first\n",
    "      # element.\n",
    "      logits = decode_outputs.pop(\"logits\")[0][0]\n",
    "      next_token = self._greedy_sampler(logits)\n",
    "      if next_token == self._tokenizer.eos_token_id:\n",
    "        break\n",
    "      generated_ids.append(next_token)\n",
    "      # Decode the whole sequence instead of one token at a time: with a\n",
    "      # byte-level tokenizer a single token may hold only part of a multi-byte\n",
    "      # UTF-8 character. While the text ends in U+FFFD (an incomplete\n",
    "      # character), hold the tail back.\n",
    "      text = self._tokenizer.decode(generated_ids, skip_special_tokens=False)\n",
    "      if not text.endswith(\"\\ufffd\"):\n",
    "        print(text[printed_len:], end='', flush=True)\n",
    "        printed_len = len(text)\n",
    "      # Decode outputs include logits and kv cache. We already popped the\n",
    "      # logits, so the rest is kv cache, which is the input to the next step.\n",
    "      decode_inputs = decode_outputs\n",
    "      next_pos += 1\n",
    "\n",
    "    text = self._tokenizer.decode(generated_ids, skip_special_tokens=False)\n",
    "    # Print any held-back text and end the line.\n",
    "    print(text[printed_len:])\n",
    "    return text\n",
    "\n",
    "  def generate(self, prompt: str, max_decode_steps: int | None = None) -> str:\n",
    "    \"\"\"Generates a response to `prompt` with greedy decoding.\n",
    "\n",
    "    The prompt is sent as a single user turn through the tokenizer's chat\n",
    "    template, and the response is streamed to stdout while it is decoded.\n",
    "\n",
    "    Args:\n",
    "      prompt: The user message.\n",
    "      max_decode_steps: Optional cap on the number of decode steps. Decoding is\n",
    "        always limited by the remaining KV cache capacity as well.\n",
    "\n",
    "    Returns:\n",
    "      The generated text.\n",
    "    \"\"\"\n",
    "    messages = [{'role': 'user', 'content': prompt}]\n",
    "    token_ids = self._tokenizer.apply_chat_template(\n",
    "        messages, tokenize=True, add_generation_prompt=True\n",
    "    )\n",
    "    # Prefill up to the second to last token of the prompt, because the last\n",
    "    # token of the prompt will be used to bootstrap decode.\n",
    "    prefill_token_length = len(token_ids) - 1\n",
    "    # Only the prefilled tokens have to fit into the prefill signature.\n",
    "    self._init_prefill_runner(prefill_token_length)\n",
    "\n",
    "    print('Running prefill')\n",
    "    kv_cache = self._run_prefill(token_ids[:prefill_token_length])\n",
    "\n",
    "    print('Running decode')\n",
    "    actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1\n",
    "    if max_decode_steps is not None:\n",
    "      actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n",
    "    return self._run_decode(\n",
    "        prefill_token_length,\n",
    "        token_ids[prefill_token_length],\n",
    "        kv_cache,\n",
    "        actual_max_decode_steps,\n",
    "    )"
   ],
   "metadata": {
    "id": "UBSGrHrM4ANm"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generate text from the model\n",
    "\n",
    "Disclaimer: model performance demonstrated with the Python API in this notebook is not representative of performance on a local device."
   ],
   "metadata": {
    "id": "dASKx_JtYXwe"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "pipeline = LiteRTLlmPipeline(interpreter, tokenizer)"
   ],
   "metadata": {
    "id": "AZhlDQWg61AL"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "prompt = \"What is the capital of France?\"\n",
    "output = pipeline.generate(prompt, max_decode_steps=None)"
   ],
   "metadata": {
    "id": "wT9BIiATkjzL"
   },
   "execution_count": null,
   "outputs": []
  }
 ]
}