muryshev's picture
Добавлен токенизатор для корректной обрезки запроса.
2ccde67
raw
history blame
8.47 kB
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import json
import re
import os
from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
# Загрузка переменных окружения из файла .env
load_dotenv()
LLM_API_URL = os.getenv("LLM_API_URL", "https://api.deepinfra.com")
LLM_API_KEY = os.getenv("DEEPINFRA_API_KEY", "")
LLM_NAME = os.getenv("LLM_NAME", "meta-llama/Llama-3.3-70B-Instruct-Turbo")
TOKENIZER_NAME = os.getenv("TOKENIZER_NAME", "unsloth/Llama-3.3-70B-Instruct")
default_llm_params = LlmParams(
url=LLM_API_URL,
api_key=LLM_API_KEY,
model=LLM_NAME,
tokenizer=TOKENIZER_NAME,
context_length=130000,
predict_params=LlmPredictParams(
temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
repetition_penalty=1.2, presence_penalty=1.1, n_predict=6000
)
)
llm_api = DeepInfraApi(default_llm_params)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"]
)
class TextRequest(BaseModel):
text: str
projects: list[str] = []
async def trim_prompt(prompt: str, system_prompt: str):
result = await llm_api.tokenize(prompt+system_prompt)
result_system = await llm_api.tokenize(system_prompt)
# в случае ошибки при токенизации, вернем исходную строку безопасной длины
if result["result"] is None or result_system is None:
return prompt[llm_api.params.context_length / 3]
#вероятно, часть уходит на форматирование чата, надо проверить
max_length = result["max_length"] - len(result["result"]) - len(result_system["result"]) - llm_api.params.predict_params.n_predict
detokenized_str = await llm_api.detokenize(result["result"][:max_length])
# в случае ошибки при детокенизации, вернем исходную строку безопасной длины
if detokenized_str is None:
return prompt[llm_api.params.context_length / 3]
return detokenized_str
def validate_json_format(data):
""" Проверяет, соответствует ли JSON ожидаемому формату. """
if not isinstance(data, list):
return "JSON должен быть списком объектов."
# Возможно, в дальнейшем стоит описать менее детально, пока так для отладки
for item in data:
if not isinstance(item, dict):
return "Элементы списка должны быть объектами (dict)."
if "name" not in item or "data" not in item:
return "Каждый объект должен содержать ключи 'name' и 'data'."
if not isinstance(item["name"], str):
return "'name' должен быть строкой."
if not isinstance(item["data"], dict):
return "'data' должен быть объектом (dict)."
if "columns" not in item["data"] or "rows" not in item["data"]:
return "'data' должен содержать 'columns' и 'rows'."
if not isinstance(item["data"]["columns"], list) or not all(isinstance(col, str) for col in item["data"]["columns"]):
return "'columns' должен быть списком строк."
if not isinstance(item["data"]["rows"], list) or not all(isinstance(row, list) for row in item["data"]["rows"]):
return "'rows' должен быть списком списков."
return None # Ошибок нет
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
prompt = gettable.USER_PROMPT.format(query=request.text)
system_prompt=gettable.SYSTEM_PROMPT
prompt = await trim_prompt(prompt, system_prompt)
response = await llm_api.predict(prompt, system_prompt=system_prompt)
result = {"response": None, "error": None, "raw": response}
if "JSON" not in response:
result["error"] = "Строка не содержит 'JSON'"
return result
prefix, json_str = response.split("JSON", 1)
json_str = json_str.strip()
if not json_str:
result["error"] = "После 'JSON' отсутствует JSON"
return result
try:
result["response"] = json.loads(json_str)
result["raw"] = prefix.strip()
validation_error = validate_json_format(result["response"])
if validation_error:
result["error"] = validation_error
else:
result["response"] = result["response"]
result["raw"] = prefix.strip()
except json.JSONDecodeError as e:
result["error"] = f"Ошибка декодирования JSON: {e}"
return result
@app.get("/health")
def health():
return {"status": "ok"}
async def generate_response(prompt):
prompt = await trim_prompt(prompt)
return await llm_api.predict(prompt)
@app.post("/getsummary")
async def getsummary_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_SUMMARY.format(text=request.text))}
@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
return {"result": await generate_response(prompts.CLEAN_TEXT.format(text=request.text))}
@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_FOLLOWUP.format(text=request.text))}
@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_AGENDA.format(text=request.text))}
@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
return {"result": await generate_response(prompts.GET_HIGHLIGHTS.format(text=request.text))}
@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
prompts_list = [
prompts.GET_PROJECT_INFO_NAMES.format(text=request.text),
prompts.GET_PROJECT_INFO_AGENDA.format(text=request.text)
]
main_info = "\n\n".join([await generate_response(p) for p in prompts_list])
return {"result": main_info.strip().replace("Конец ответа", "").replace('</s>', '')}
@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
projects = [f"Проект {x}" for x in list_of_projects.split("Проект ") if x][1:]
projects = [p.replace("проект ", "").strip() for p in projects]
real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
return {"result": real_projects_list}
@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
if not request.projects:
raise HTTPException(status_code=400, detail="Проекты не выбраны")
final = {}
for project in request.projects:
prompts_list = [
prompts.GET_PROJECT_DETAILS_AIM.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_VALUE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_BUDGET.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_ECO_EFFECT.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_DEADLINE.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_NEW_PLAN.format(text=request.text, project=project),
prompts.GET_PROJECT_DETAILS_CONCLUSION.format(text=request.text, project=project),
]
final[project] = "\n\n".join([await generate_response(p) for p in prompts_list])
final[project] = final[project].replace("Конец ответа", "").replace('</s>', '').strip()
return {"result": final}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)