muryshev's picture
update
9347ad1
raw
history blame
7.95 kB
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import json
import re
import os
from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
from models.text_request import TextRequest
from llm.llm_api import LlmApi
from routes.summary import router as summary_router
# Загрузка переменных окружения из файла .env
load_dotenv()
llm_api = LlmApi().get_api()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"]
)
app.include_router(summary_router)
async def trim_prompt(prompt: str, system_prompt: str):
result = await llm_api.tokenize(prompt+system_prompt)
result_system = None
system_prompt_length = 0
if system_prompt is not None:
result_system = await llm_api.tokenize(system_prompt)
if result_system is not None:
system_prompt_length = len(result_system["result"])
# в случае ошибки при токенизации, вернем исходную строку безопасной длины
if result["result"] is None or result_system is None:
return prompt[llm_api.params.context_length / 3]
#вероятно, часть уходит на форматирование чата, надо проверить
max_length = result["max_length"] - len(result["result"]) - system_prompt_length - llm_api.params.predict_params.n_predict
detokenized_str = await llm_api.detokenize(result["result"][:max_length])
# в случае ошибки при детокенизации, вернем исходную строку безопасной длины
if detokenized_str is None:
return prompt[llm_api.params.context_length / 3]
return detokenized_str
def validate_json_format(data):
""" Проверяет, соответствует ли JSON ожидаемому формату. """
if not isinstance(data, list):
return "JSON должен быть списком объектов."
# Возможно, в дальнейшем стоит описать менее детально, пока так для отладки
for item in data:
if not isinstance(item, dict):
return "Элементы списка должны быть объектами (dict)."
if "name" not in item or "data" not in item:
return "Каждый объект должен содержать ключи 'name' и 'data'."
if not isinstance(item["name"], str):
return "'name' должен быть строкой."
if not isinstance(item["data"], dict):
return "'data' должен быть объектом (dict)."
if "columns" not in item["data"] or "rows" not in item["data"]:
return "'data' должен содержать 'columns' и 'rows'."
if not isinstance(item["data"]["columns"], list) or not all(isinstance(col, str) for col in item["data"]["columns"]):
return "'columns' должен быть списком строк."
if not isinstance(item["data"]["rows"], list) or not all(isinstance(row, list) for row in item["data"]["rows"]):
return "'rows' должен быть списком списков."
return None # Ошибок нет
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
prompt = gettable.USER_PROMPT.format(query=request.text)
system_prompt=gettable.SYSTEM_PROMPT
prompt = await trim_prompt(prompt, system_prompt)
response = await llm_api.predict(prompt, system_prompt=system_prompt)
result = {"response": None, "error": None, "raw": response}
if "JSON" not in response:
result["error"] = "Строка не содержит 'JSON'"
return result
prefix, json_str = response.split("JSON", 1)
json_str = json_str.strip()
if not json_str:
result["error"] = "После 'JSON' отсутствует JSON"
return result
try:
result["response"] = json.loads(json_str)
result["raw"] = prefix.strip()
validation_error = validate_json_format(result["response"])
if validation_error:
result["error"] = validation_error
else:
result["response"] = result["response"]
result["raw"] = prefix.strip()
except json.JSONDecodeError as e:
result["error"] = f"Ошибка декодирования JSON: {e}"
return result
@app.get("/health")
def health():
return {"status": "ok"}
async def generate_response(prompt):
prompt = await trim_prompt(prompt)
return await llm_api.predict(prompt)
@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
return {"result": await generate_response(prompts.CLEAN_TEXT.format(text=request.text))}
@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_FOLLOWUP.format(text=request.text))}
@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_AGENDA.format(text=request.text))}
@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_HIGHLIGHTS.format(text=request.text))}
@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
prompts_list = [
prompts.GET_PROJECT_INFO_NAMES.format(text=request.text),
prompts.GET_PROJECT_INFO_AGENDA.format(text=request.text)
]
main_info = "\n\n".join([await generate_response(p) for p in prompts_list])
return {"result": main_info.strip().replace("Конец ответа", "").replace('</s>', '')}
@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
projects = [f"Проект {x}" for x in list_of_projects.split("Проект ") if x][1:]
projects = [p.replace("проект ", "").strip() for p in projects]
real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
return {"result": real_projects_list}
@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
if not request.projects:
raise HTTPException(status_code=400, detail="Проекты не выбраны")
final = {}
for project in request.projects:
prompts_list = [
prompts.GET_PROJECT_DETAILS_AIM.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_VALUE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_BUDGET.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_ECO_EFFECT.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_DEADLINE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_NEW_PLAN.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_CONCLUSION.format(text=request.text, project=project),
]
final[project] = "\n\n".join([await generate_response(p) for p in prompts_list])
final[project] = final[project].replace("Конец ответа", "").replace('</s>', '').strip()
return {"result": final}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)