muryshev's picture
update
f443975
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import json
import re
import os
from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
from models.text_request import TextRequest
from llm.llm_api import LlmApi
from routes.summary import router as summary_router
# Загрузка переменных окружения из файла .env
load_dotenv()
llm_api = LlmApi().get_api()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"]
)
app.include_router(summary_router)
async def trim_prompt(prompt: str, system_prompt: str = None):
result = await llm_api.tokenize(prompt)
result_system = None
system_prompt_length = 0
if system_prompt is not None:
result_system = await llm_api.tokenize(system_prompt)
if result_system is not None:
system_prompt_length = len(result_system["result"])
# в случае ошибки при токенизации, вернем исходную строку безопасной длины
if result["result"] is None or (system_prompt is not None and result_system is None):
return prompt[int(llm_api.params.context_length / 3)]
#вероятно, часть уходит на форматирование чата, надо проверить
max_length = result["max_length"] - len(result["result"]) - system_prompt_length - llm_api.params.predict_params.n_predict
detokenized_str = await llm_api.detokenize(result["result"][:max_length])
# в случае ошибки при детокенизации, вернем исходную строку безопасной длины
if detokenized_str is None:
return prompt[llm_api.params.context_length / 3]
return detokenized_str
def validate_json_format(data):
""" Проверяет, соответствует ли JSON ожидаемому формату. """
if not isinstance(data, list):
return "JSON должен быть списком объектов."
# Возможно, в дальнейшем стоит описать менее детально, пока так для отладки
for item in data:
if not isinstance(item, dict):
return "Элементы списка должны быть объектами (dict)."
if "name" not in item or "data" not in item:
return "Каждый объект должен содержать ключи 'name' и 'data'."
if not isinstance(item["name"], str):
return "'name' должен быть строкой."
if not isinstance(item["data"], dict):
return "'data' должен быть объектом (dict)."
if "columns" not in item["data"] or "rows" not in item["data"]:
return "'data' должен содержать 'columns' и 'rows'."
if not isinstance(item["data"]["columns"], list) or not all(isinstance(col, str) for col in item["data"]["columns"]):
return "'columns' должен быть списком строк."
if not isinstance(item["data"]["rows"], list) or not all(isinstance(row, list) for row in item["data"]["rows"]):
return "'rows' должен быть списком списков."
return None # Ошибок нет
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
prompt = gettable.USER_PROMPT.format(query=request.text)
system_prompt=gettable.SYSTEM_PROMPT
prompt = await trim_prompt(prompt, system_prompt)
response = await llm_api.predict(prompt, system_prompt=system_prompt)
result = {"response": None, "error": None, "raw": response}
if "JSON" not in response:
result["error"] = "Строка не содержит 'JSON'"
return result
prefix, json_str = response.split("JSON", 1)
json_str = json_str.strip()
if not json_str:
result["error"] = "После 'JSON' отсутствует JSON"
return result
try:
result["response"] = json.loads(json_str)
result["raw"] = prefix.strip()
validation_error = validate_json_format(result["response"])
if validation_error:
result["error"] = validation_error
else:
result["response"] = result["response"]
result["raw"] = prefix.strip()
except json.JSONDecodeError as e:
result["error"] = f"Ошибка декодирования JSON: {e}"
return result
@app.get("/health")
def health():
return {"status": "ok"}
async def generate_response(prompt):
prompt = await trim_prompt(prompt)
return await llm_api.predict(prompt)
@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
return {"result": await generate_response(prompts.CLEAN_TEXT.format(text=request.text))}
@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_FOLLOWUP.format(text=request.text))}
@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_AGENDA.format(text=request.text))}
@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_HIGHLIGHTS.format(text=request.text))}
@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
prompts_list = [
prompts.GET_PROJECT_INFO_NAMES.format(text=request.text),
prompts.GET_PROJECT_INFO_AGENDA.format(text=request.text)
]
main_info = "\n\n".join([await generate_response(p) for p in prompts_list])
return {"result": main_info.strip().replace("Конец ответа", "").replace('</s>', '')}
@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
projects = [f"Проект {x}" for x in list_of_projects.split("Проект ") if x][1:]
projects = [p.replace("проект ", "").strip() for p in projects]
real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
return {"result": real_projects_list}
@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
if not request.projects:
raise HTTPException(status_code=400, detail="Проекты не выбраны")
final = {}
for project in request.projects:
prompts_list = [
prompts.GET_PROJECT_DETAILS_AIM.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_VALUE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_BUDGET.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_ECO_EFFECT.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_DEADLINE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_NEW_PLAN.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_CONCLUSION.format(text=request.text, project=project),
]
final[project] = "\n\n".join([await generate_response(p) for p in prompts_list])
final[project] = final[project].replace("Конец ответа", "").replace('</s>', '').strip()
return {"result": final}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)