import json from typing import Optional, List import httpx import logging from transformers import AutoTokenizer from llm.common import LlmParams, LlmApi logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(message)s", ) class DeepInfraApi(LlmApi): """ Класс для работы с API vllm. """ def __init__(self, params: LlmParams): super().__init__() super().set_params(params) print('Tokenizer initialization.') self.tokenizer = AutoTokenizer.from_pretrained(params.tokenizer if params.tokenizer is not None else params.model) print(f"Tokenizer initialized for model {params.model}.") async def get_models(self) -> List[str]: """ Выполняет GET-запрос к API для получения списка доступных моделей. Возвращает: list[str]: Список идентификаторов моделей. Если произошла ошибка или данные недоступны, возвращается пустой список. Исключения: Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше. """ try: async with httpx.AsyncClient() as client: response = await client.get(f"{self.params.url}/v1/openai/models", headers=super().create_headers()) if response.status_code == 200: json_data = response.json() return [item['id'] for item in json_data.get('data', [])] except httpx.RequestError as error: print('Error fetching models:', error) return [] def create_messages(self, prompt: str, system_prompt: str = None) -> List[dict]: """ Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан). Args: prompt (str): Пользовательский промпт. Returns: list[dict]: Список сообщений с ролями и содержимым. """ actual_prompt = self.apply_llm_template_to_prompt(prompt) messages = [] if system_prompt is not None: messages.append({"role": "system", "content": system_prompt}) else: if self.params.predict_params and self.params.predict_params.system_prompt: messages.append({"role": "system", "content": self.params.predict_params.system_prompt}) messages.append({"role": "user", "content": actual_prompt}) return messages def apply_llm_template_to_prompt(self, prompt: str) -> str: """ Применяет шаблон LLM к переданному промпту, если он задан. Args: prompt (str): Пользовательский промпт. Returns: str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует). """ actual_prompt = prompt if self.params.template is not None: actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt) return actual_prompt async def tokenize(self, prompt: str) -> Optional[dict]: """ Токенизирует входной текстовый промпт. Args: prompt (str): Текст, который нужно токенизировать. Returns: dict: Словарь с токенами и их количеством или None в случае ошибки. """ try: tokens = self.tokenizer.encode(prompt, add_special_tokens=True) return {"result": tokens, "num_tokens": len(tokens), "max_length": self.params.context_length} except Exception as e: print(f"Tokenization error: {e}") return None async def detokenize(self, tokens: List[int]) -> Optional[str]: """ Детокенизирует список токенов обратно в строку. Args: tokens (List[int]): Список токенов, который нужно преобразовать в текст. Returns: str: Восстановленный текст или None в случае ошибки. """ try: text = self.tokenizer.decode(tokens, skip_special_tokens=True) return text except Exception as e: print(f"Detokenization error: {e}") return None async def create_request(self, prompt: str, system_prompt: str = None) -> dict: """ Создает запрос для предсказания на основе параметров LLM. Args: prompt (str): Промпт для запроса. Returns: dict: Словарь с параметрами для выполнения запроса. """ request = { "stream": False, "model": self.params.model, } predict_params = self.params.predict_params if predict_params: if predict_params.stop: non_empty_stop = list(filter(lambda o: o != "", predict_params.stop)) if non_empty_stop: request["stop"] = non_empty_stop if predict_params.n_predict is not None: request["max_tokens"] = int(predict_params.n_predict or 0) request["temperature"] = float(predict_params.temperature or 0) if predict_params.top_k is not None: request["top_k"] = int(predict_params.top_k) if predict_params.top_p is not None: request["top_p"] = float(predict_params.top_p) if predict_params.min_p is not None: request["min_p"] = float(predict_params.min_p) if predict_params.seed is not None: request["seed"] = int(predict_params.seed) if predict_params.n_keep is not None: request["n_keep"] = int(predict_params.n_keep) if predict_params.cache_prompt is not None: request["cache_prompt"] = bool(predict_params.cache_prompt) if predict_params.repeat_penalty is not None: request["repetition_penalty"] = float(predict_params.repeat_penalty) if predict_params.repeat_last_n is not None: request["repeat_last_n"] = int(predict_params.repeat_last_n) if predict_params.presence_penalty is not None: request["presence_penalty"] = float(predict_params.presence_penalty) if predict_params.frequency_penalty is not None: request["frequency_penalty"] = float(predict_params.frequency_penalty) request["messages"] = self.create_messages(prompt, system_prompt) return request async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict: raise NotImplementedError("This function is not supported.") async def predict(self, prompt: str, system_prompt: str = None) -> str: """ Выполняет запрос к API и возвращает результат. Args: prompt (str): Входной текст для предсказания. Returns: str: Сгенерированный текст. """ async with httpx.AsyncClient() as client: request = await self.create_request(prompt, system_prompt) response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=super().create_headers(), json=request, timeout=httpx.Timeout(connect=5.0, read=60.0, write=180, pool=10)) if response.status_code == 200: return response.json()["choices"][0]["message"]["content"] else: logging.info(f"Request {prompt} failed: status code {response.status_code}") logging.info(response.text) async def trim_prompt(self, prompt: str, system_prompt: str = None): result = await self.tokenize(prompt) result_system = None system_prompt_length = 0 if system_prompt is not None: result_system = await self.tokenize(system_prompt) if result_system is not None: system_prompt_length = len(result_system["result"]) # в случае ошибки при токенизации, вернем исходную строку безопасной длины if result["result"] is None or (system_prompt is not None and result_system is None): return prompt[int(self.params.context_length / 3)] #вероятно, часть уходит на форматирование чата, надо проверить max_length = result["max_length"] - len(result["result"]) - system_prompt_length - self.params.predict_params.n_predict detokenized_str = await self.detokenize(result["result"][:max_length]) # в случае ошибки при детокенизации, вернем исходную строку безопасной длины if detokenized_str is None: return prompt[self.params.context_length / 3] return detokenized_str