from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import json
import re
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
from models.text_request import TextRequest
from llm.llm_api import LlmApi
from routes.summary import router as summary_router

# Load environment variables from the .env file
load_dotenv()

llm_api = LlmApi().get_api()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
)
app.include_router(summary_router)

async def trim_prompt(prompt: str, system_prompt: str = None):
    result = await llm_api.tokenize(prompt)
    result_system = None
    system_prompt_length = 0
    if system_prompt is not None:
        result_system = await llm_api.tokenize(system_prompt)
        if result_system is not None:
            system_prompt_length = len(result_system["result"])
    # If tokenization failed, fall back to a raw-string prefix of a safe length
    if result is None or result["result"] is None or (system_prompt is not None and result_system is None):
        return prompt[:int(llm_api.params.context_length / 3)]
    # Part of the context window likely goes to chat formatting; this needs verification.
    # Token budget for the prompt: the context maximum minus the system prompt
    # and the reserve for generation (n_predict).
    max_length = result["max_length"] - system_prompt_length - llm_api.params.predict_params.n_predict
    detokenized_str = await llm_api.detokenize(result["result"][:max_length])
    # If detokenization failed, fall back to a raw-string prefix of a safe length
    if detokenized_str is None:
        return prompt[:int(llm_api.params.context_length / 3)]
    return detokenized_str
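
# Usage sketch (illustrative, not part of the original module): trim a long
# transcript so that the prompt, the system prompt, and the generation budget
# together fit the model context; gettable.SYSTEM_PROMPT stands in for any
# system prompt here.
#
#   trimmed = await trim_prompt(long_text, system_prompt=gettable.SYSTEM_PROMPT)
#   answer = await llm_api.predict(trimmed, system_prompt=gettable.SYSTEM_PROMPT)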

def validate_json_format(data):
    """Checks whether the JSON matches the expected format."""
    if not isinstance(data, list):
        return "JSON должен быть списком объектов."
    # This may deserve a less detailed check later; kept verbose for debugging
    for item in data:
        if not isinstance(item, dict):
            return "Элементы списка должны быть объектами (dict)."
        if "name" not in item or "data" not in item:
            return "Каждый объект должен содержать ключи 'name' и 'data'."
        if not isinstance(item["name"], str):
            return "'name' должен быть строкой."
        if not isinstance(item["data"], dict):
            return "'data' должен быть объектом (dict)."
        if "columns" not in item["data"] or "rows" not in item["data"]:
            return "'data' должен содержать 'columns' и 'rows'."
        if not isinstance(item["data"]["columns"], list) or not all(isinstance(col, str) for col in item["data"]["columns"]):
            return "'columns' должен быть списком строк."
        if not isinstance(item["data"]["rows"], list) or not all(isinstance(row, list) for row in item["data"]["rows"]):
            return "'rows' должен быть списком списков."
    return None  # No errors
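
# Example of a payload that passes validate_json_format (illustrative data):
#
#   [{"name": "Бюджет", "data": {"columns": ["Статья", "Сумма"],
#                                "rows": [["Аренда", "1000"], ["Командировки", "250"]]}}]
#
# For this input the function returns None; any structural deviation yields
# a human-readable error string instead.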

# Route path assumed from the handler name; the original decorator was lost
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
    prompt = gettable.USER_PROMPT.format(query=request.text)
    system_prompt = gettable.SYSTEM_PROMPT
    prompt = await trim_prompt(prompt, system_prompt)
    response = await llm_api.predict(prompt, system_prompt=system_prompt)
    result = {"response": None, "error": None, "raw": response}
    # The model is expected to emit a literal 'JSON' marker before the table payload
    if "JSON" not in response:
        result["error"] = "Строка не содержит 'JSON'"
        return result
    prefix, json_str = response.split("JSON", 1)
    json_str = json_str.strip()
    if not json_str:
        result["error"] = "После 'JSON' отсутствует JSON"
        return result
    try:
        result["response"] = json.loads(json_str)
        result["raw"] = prefix.strip()
        validation_error = validate_json_format(result["response"])
        if validation_error:
            result["error"] = validation_error
    except json.JSONDecodeError as e:
        result["error"] = f"Ошибка декодирования JSON: {e}"
    return result
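
# Shape of a successful extracttable response (illustrative): the model's
# free-text preamble ends up in "raw", the parsed table list in "response":
#
#   {"response": [{"name": "...", "data": {"columns": [...], "rows": [[...]]}}],
#    "error": None,
#    "raw": "<text the model produced before the 'JSON' marker>"}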

# Route path assumed; a simple liveness probe
@app.get("/health")
def health():
    return {"status": "ok"}

# Trims the prompt to the model context, then queries the LLM
async def generate_response(prompt):
    prompt = await trim_prompt(prompt)
    return await llm_api.predict(prompt)

# Route paths below are assumed from the handler names
@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
    return {"result": await generate_response(prompts.CLEAN_TEXT.format(text=request.text))}

@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_FOLLOWUP.format(text=request.text))}

@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_AGENDA.format(text=request.text))}

@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_HIGHLIGHTS.format(text=request.text))}

@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
    prompts_list = [
        prompts.GET_PROJECT_INFO_NAMES.format(text=request.text),
        prompts.GET_PROJECT_INFO_AGENDA.format(text=request.text),
    ]
    main_info = "\n\n".join([await generate_response(p) for p in prompts_list])
    # Strip the model's end-of-answer marker and stop token
    return {"result": main_info.strip().replace("Конец ответа", "").replace('</s>', '')}

@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
    list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
    # Split the answer into "Проект ..." chunks, dropping the preamble before the first project
    projects = [f"Проект {x}" for x in list_of_projects.split("Проект ") if x][1:]
    projects = [p.replace("проект ", "").strip() for p in projects]
    # Ask the model to verify each candidate and keep only those confirmed with "Да:"
    real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
    real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
    return {"result": real_projects_list}

@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
    if not request.projects:
        raise HTTPException(status_code=400, detail="Проекты не выбраны")
    final = {}
    for project in request.projects:
        prompts_list = [
            prompts.GET_PROJECT_DETAILS_AIM.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_VALUE.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_BUDGET.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_ECO_EFFECT.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_DEADLINE.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_NEW_PLAN.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_CONCLUSION.format(text=request.text, project=project),
        ]
        final[project] = "\n\n".join([await generate_response(p) for p in prompts_list])
        # Strip the model's end-of-answer marker and stop token
        final[project] = final[project].replace("Конец ответа", "").replace('</s>', '').strip()
    return {"result": final}
if __name__ == "__main__": | |
uvicorn.run(app, host="0.0.0.0", port=7860) | |