File size: 11,275 Bytes
9e7f97d |
1 |
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zBuCEnuUTMvC"
   },
   "source": [
    "# Phi-4-mini-instruct on LiteRT\n",
    "\n",
    "Runs the `phi4_q8_ekv1280.tflite` export of Phi-4-mini-instruct with the LiteRT Python interpreter: the prompt is prefilled into the KV cache, then tokens are decoded greedily one at a time and streamed to the output."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "39AMoCOa1ckc"
   },
   "source": [
    "# Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "43tAeO0AZ7zp"
   },
   "outputs": [],
   "source": [
    "%pip install -q ai-edge-litert transformers huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "i6PMkMVBPr1p"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "from collections.abc import Collection, Sequence\n",
    "\n",
    "import numpy as np\n",
    "from ai_edge_litert import interpreter as interpreter_lib\n",
    "from huggingface_hub import hf_hub_download\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K5okZCTgYpUd"
   },
   "source": [
    "# Download model files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3t47HAG2tvc3"
   },
   "outputs": [],
   "source": [
    "MODEL_REPO_ID = \"litert-community/Phi-4-mini-instruct\"\n",
    "MODEL_FILENAME = \"phi4_q8_ekv1280.tflite\"\n",
    "\n",
    "model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n5Xa4s6XhWqk"
   },
   "source": [
    "# Create LiteRT interpreter and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rvdn3EIZhaQn"
   },
   "outputs": [],
   "source": [
    "TOKENIZER_ID = \"microsoft/Phi-4-mini-instruct\"\n",
    "NUM_THREADS = 2\n",
    "\n",
    "interpreter = interpreter_lib.InterpreterWithCustomOps(\n",
    "    custom_op_registerers=[\"pywrap_genai_ops.GenAIOpsRegisterer\"],\n",
    "    model_path=model_path,\n",
    "    num_threads=NUM_THREADS,\n",
    "    experimental_default_delegate_latest_features=True,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AM6rDABTXt2F"
   },
   "source": [
    "# Create pipeline with LiteRT models\n",
    "\n",
    "`LiteRTLlmPipeline` runs the prompt through the smallest `prefill*` signature that fits it, then calls the `decode` signature once per generated token, feeding each step's KV cache into the next one. Pass `stop_token_ids` to `generate` to stop on tokens other than the tokenizer's EOS token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UBSGrHrM4ANm"
   },
   "outputs": [],
   "source": [
    "class LiteRTLlmPipeline:\n",
    "  \"\"\"Greedy text generation with a LiteRT LLM exported with prefill/decode signatures.\"\"\"\n",
    "\n",
    "  def __init__(self, interpreter, tokenizer):\n",
    "    \"\"\"Initializes the pipeline.\n",
    "\n",
    "    Args:\n",
    "      interpreter: LiteRT interpreter whose model exposes `prefill*` signatures\n",
    "        and a `decode` signature.\n",
    "      tokenizer: Hugging Face tokenizer that matches the model.\n",
    "    \"\"\"\n",
    "    self._interpreter = interpreter\n",
    "    self._tokenizer = tokenizer\n",
    "\n",
    "    # The prefill runner depends on the prompt length, so `generate` picks it\n",
    "    # per call via `_init_prefill_runner`.\n",
    "    self._prefill_runner = None\n",
    "    self._decode_runner = self._interpreter.get_signature_runner(\"decode\")\n",
    "\n",
    "  def _init_prefill_runner(self, num_input_tokens: int):\n",
    "    \"\"\"Selects the prefill runner and reads the sequence limits it implies.\n",
    "\n",
    "    This method initializes the following variables:\n",
    "      - self._prefill_runner: The smallest prefill runner that fits the input.\n",
    "      - self._max_seq_len: The number of tokens that prefill runner accepts.\n",
    "      - self._max_kv_cache_seq_len: The sequence length of the KV cache.\n",
    "\n",
    "    Args:\n",
    "      num_input_tokens: The number of input tokens.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If the interpreter is missing or no prefill signature is\n",
    "        long enough for `num_input_tokens`.\n",
    "    \"\"\"\n",
    "    if not self._interpreter:\n",
    "      raise ValueError(\"Interpreter is not initialized.\")\n",
    "\n",
    "    self._prefill_runner = self._get_prefill_runner(num_input_tokens)\n",
    "    input_details = self._prefill_runner.get_input_details()\n",
    "\n",
    "    # input_token_shape has shape (max_seq_len,) or (batch, max_seq_len).\n",
    "    input_token_shape = input_details[\"tokens\"][\"shape\"]\n",
    "    if len(input_token_shape) == 1:\n",
    "      self._max_seq_len = input_token_shape[0]\n",
    "    else:\n",
    "      self._max_seq_len = input_token_shape[1]\n",
    "\n",
    "    # kv cache input has shape [batch=1, seq_len, num_heads, dim].\n",
    "    kv_cache_shape = input_details[\"kv_cache_k_0\"][\"shape\"]\n",
    "    self._max_kv_cache_seq_len = kv_cache_shape[1]\n",
    "\n",
    "  def _init_kv_cache(self) -> dict[str, np.ndarray]:\n",
    "    \"\"\"Returns a zero-filled float32 array for every `kv_cache*` prefill input.\"\"\"\n",
    "    if self._prefill_runner is None:\n",
    "      raise ValueError(\"Prefill runner is not initialized.\")\n",
    "    input_details = self._prefill_runner.get_input_details()\n",
    "    return {\n",
    "        name: np.zeros(detail[\"shape\"], dtype=np.float32)\n",
    "        for name, detail in input_details.items()\n",
    "        if \"kv_cache\" in name\n",
    "    }\n",
    "\n",
    "  def _get_prefill_runner(self, num_input_tokens: int):\n",
    "    \"\"\"Gets the prefill runner with the best suitable input size.\n",
    "\n",
    "    Args:\n",
    "      num_input_tokens: The number of input tokens.\n",
    "\n",
    "    Returns:\n",
    "      The runner of the `prefill*` signature with the smallest input length\n",
    "      that is still >= `num_input_tokens`.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If every prefill signature is shorter than the input.\n",
    "    \"\"\"\n",
    "    best_signature = None\n",
    "    delta = sys.maxsize\n",
    "    max_prefill_len = -1\n",
    "    for key in self._interpreter.get_signature_list().keys():\n",
    "      if \"prefill\" not in key:\n",
    "        continue\n",
    "      input_pos = self._interpreter.get_signature_runner(key).get_input_details()[\n",
    "          \"input_pos\"\n",
    "      ]\n",
    "      # input_pos[\"shape\"] has shape (max_seq_len, )\n",
    "      seq_size = input_pos[\"shape\"][0]\n",
    "      max_prefill_len = max(max_prefill_len, seq_size)\n",
    "      if num_input_tokens <= seq_size and seq_size - num_input_tokens < delta:\n",
    "        delta = seq_size - num_input_tokens\n",
    "        best_signature = key\n",
    "    if best_signature is None:\n",
    "      raise ValueError(\n",
    "          \"The largest prefill length supported is %d, but we have %d number of input tokens\"\n",
    "          % (max_prefill_len, num_input_tokens)\n",
    "      )\n",
    "    return self._interpreter.get_signature_runner(best_signature)\n",
    "\n",
    "  def _run_prefill(\n",
    "      self, prefill_token_ids: Sequence[int],\n",
    "  ) -> dict[str, np.ndarray]:\n",
    "    \"\"\"Runs prefill and returns the kv cache.\n",
    "\n",
    "    Args:\n",
    "      prefill_token_ids: The token ids of the prefill input, at most\n",
    "        `self._max_seq_len` of them.\n",
    "\n",
    "    Returns:\n",
    "      The updated kv cache, i.e. the prefill outputs without `logits`.\n",
    "    \"\"\"\n",
    "    if self._prefill_runner is None:\n",
    "      raise ValueError(\"Prefill runner is not initialized.\")\n",
    "    prefill_token_length = len(prefill_token_ids)\n",
    "    if prefill_token_length == 0:\n",
    "      return self._init_kv_cache()\n",
    "\n",
    "    # Prepare the input to be [1, max_seq_len], zero-padded after the prompt.\n",
    "    input_token_ids = np.zeros((1, self._max_seq_len), dtype=np.int32)\n",
    "    input_token_ids[0, :prefill_token_length] = prefill_token_ids\n",
    "\n",
    "    # Prepare the input position to be [max_seq_len].\n",
    "    input_pos = np.zeros(self._max_seq_len, dtype=np.int32)\n",
    "    input_pos[:prefill_token_length] = np.arange(prefill_token_length)\n",
    "\n",
    "    prefill_inputs = self._init_kv_cache()\n",
    "    prefill_inputs.update({\n",
    "        \"tokens\": input_token_ids,\n",
    "        \"input_pos\": input_pos,\n",
    "    })\n",
    "    prefill_outputs = self._prefill_runner(**prefill_inputs)\n",
    "    # Prefill outputs include logits and kv cache. We only return kv cache.\n",
    "    prefill_outputs.pop(\"logits\", None)\n",
    "    return prefill_outputs\n",
    "\n",
    "  def _greedy_sampler(self, logits: np.ndarray) -> int:\n",
    "    \"\"\"Returns the index of the largest logit.\"\"\"\n",
    "    return int(np.argmax(logits))\n",
    "\n",
    "  def _decode_ids(self, token_ids: Sequence[int]) -> str:\n",
    "    \"\"\"Detokenizes `token_ids`, keeping special tokens.\n",
    "\n",
    "    Space clean-up is disabled so that appending tokens does not rewrite text\n",
    "    that was already streamed.\n",
    "    \"\"\"\n",
    "    return self._tokenizer.decode(\n",
    "        token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False\n",
    "    )\n",
    "\n",
    "  def _run_decode(\n",
    "      self,\n",
    "      start_pos: int,\n",
    "      start_token_id: int,\n",
    "      kv_cache: dict[str, np.ndarray],\n",
    "      max_decode_steps: int,\n",
    "      stop_token_ids: Collection[int] | None = None,\n",
    "  ) -> str:\n",
    "    \"\"\"Decodes greedily, streaming the text to stdout as it is produced.\n",
    "\n",
    "    Args:\n",
    "      start_pos: The position of the first token of the decode input.\n",
    "      start_token_id: The token id of the first token of the decode input.\n",
    "      kv_cache: The kv cache from the prefill. The dict itself is not modified.\n",
    "      max_decode_steps: The max decode steps.\n",
    "      stop_token_ids: Token ids that end generation; they are not emitted.\n",
    "        Defaults to the tokenizer's `eos_token_id`.\n",
    "\n",
    "    Returns:\n",
    "      The decoded text of the generated tokens, excluding the stop token.\n",
    "    \"\"\"\n",
    "    if stop_token_ids is None:\n",
    "      stop_token_ids = {self._tokenizer.eos_token_id}\n",
    "    next_pos = start_pos\n",
    "    next_token = start_token_id\n",
    "    generated_ids = []\n",
    "    printed_len = 0\n",
    "    # Shallow copy so update() below does not write into the caller's dict.\n",
    "    decode_inputs = dict(kv_cache)\n",
    "\n",
    "    for _ in range(max_decode_steps):\n",
    "      decode_inputs.update({\n",
    "          \"tokens\": np.array([[next_token]], dtype=np.int32),\n",
    "          \"input_pos\": np.array([next_pos], dtype=np.int32),\n",
    "      })\n",
    "      decode_outputs = self._decode_runner(**decode_inputs)\n",
    "      # Output logits has shape (batch=1, 1, vocab_size). We only take the first\n",
    "      # element.\n",
    "      logits = decode_outputs.pop(\"logits\")[0][0]\n",
    "      next_token = self._greedy_sampler(logits)\n",
    "      if next_token in stop_token_ids:\n",
    "        break\n",
    "      generated_ids.append(next_token)\n",
    "      # Decode the whole sequence instead of the single new token: a token may\n",
    "      # carry only part of a multi-byte character, which decodes on its own to\n",
    "      # U+FFFD. Hold output back while the text ends in such a fragment.\n",
    "      text = self._decode_ids(generated_ids)\n",
    "      if not text.endswith(\"\\ufffd\"):\n",
    "        print(text[printed_len:], end=\"\", flush=True)\n",
    "        printed_len = len(text)\n",
    "      # Decode outputs include logits and kv cache. We already popped out\n",
    "      # logits, so the rest is kv cache. We pass the updated kv cache as input\n",
    "      # to the next decode step.\n",
    "      decode_inputs = decode_outputs\n",
    "      next_pos += 1\n",
    "\n",
    "    text = self._decode_ids(generated_ids)\n",
    "    print(text[printed_len:])  # Flush any held-back text and end the line.\n",
    "    return text\n",
    "\n",
    "  def generate(\n",
    "      self,\n",
    "      prompt: str,\n",
    "      max_decode_steps: int | None = None,\n",
    "      stop_token_ids: Collection[int] | None = None,\n",
    "  ) -> str:\n",
    "    \"\"\"Generates a reply to a single user message.\n",
    "\n",
    "    Args:\n",
    "      prompt: The user message; it is wrapped with the tokenizer's chat\n",
    "        template before tokenization.\n",
    "      max_decode_steps: Optional cap on the number of decode steps. The room\n",
    "        left in the KV cache is always an upper bound.\n",
    "      stop_token_ids: Token ids that end generation, e.g. a chat template's\n",
    "        end-of-turn token. Defaults to the tokenizer's `eos_token_id`.\n",
    "\n",
    "    Returns:\n",
    "      The generated text, which is also streamed to stdout while decoding.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: If the prompt is longer than every prefill signature.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "    token_ids = self._tokenizer.apply_chat_template(\n",
    "        messages, tokenize=True, add_generation_prompt=True\n",
    "    )\n",
    "    # Initialize the prefill runner with the suitable input size.\n",
    "    self._init_prefill_runner(len(token_ids))\n",
    "\n",
    "    # Prefill up to the second to last token of the prompt, because the last\n",
    "    # token of the prompt will be used to bootstrap decode.\n",
    "    prefill_token_length = len(token_ids) - 1\n",
    "\n",
    "    print(\"Running prefill\")\n",
    "    kv_cache = self._run_prefill(token_ids[:prefill_token_length])\n",
    "\n",
    "    print(\"Running decode\")\n",
    "    # Decode writes positions prefill_token_length, prefill_token_length + 1,\n",
    "    # ..., which must stay inside the KV cache.\n",
    "    actual_max_decode_steps = self._max_kv_cache_seq_len - prefill_token_length - 1\n",
    "    if max_decode_steps is not None:\n",
    "      actual_max_decode_steps = min(actual_max_decode_steps, max_decode_steps)\n",
    "    return self._run_decode(\n",
    "        prefill_token_length,\n",
    "        token_ids[prefill_token_length],\n",
    "        kv_cache,\n",
    "        actual_max_decode_steps,\n",
    "        stop_token_ids,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dASKx_JtYXwe"
   },
   "source": [
    "# Generate text from model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AZhlDQWg61AL"
   },
   "outputs": [],
   "source": [
    "# Disclaimer: Model performance demonstrated with the Python API in this notebook is not representative of performance on a local device.\n",
    "pipeline = LiteRTLlmPipeline(interpreter, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wT9BIiATkjzL"
   },
   "outputs": [],
   "source": [
    "prompt = \"What is the capital of France?\"\n",
    "output = pipeline.generate(prompt, max_decode_steps=None)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": [
    {
     "file_id": "https://huggingface.co/litert-community/Phi-4-mini-instruct/blob/main/phi4_litert.ipynb",
     "timestamp": 1741214116376
    },
    {
     "file_id": "https://huggingface.co/litert-community/Phi-4-mini-instruct/blob/main/pphi4_litert.ipynb",
     "timestamp": 1741123967537
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}