Spaces:
Sleeping
Sleeping
File size: 7,980 Bytes
9dd4b6c a9b39d2 cb60f6b ecf21a8 9dd4b6c ecf21a8 9dd4b6c 9347ad1 a9b39d2 ecf21a8 a9b39d2 9347ad1 a9b39d2 9dd4b6c 267da20 9dd4b6c a9b39d2 9347ad1 a9b39d2 f443975 2ccde67 f443975 9347ad1 2ccde67 f443975 2ccde67 9347ad1 2ccde67 9dd4b6c 2ccde67 ecf21a8 9dd4b6c ecf21a8 be938bc ecf21a8 9dd4b6c be938bc ecf21a8 9dd4b6c ecf21a8 be938bc ecf21a8 9dd4b6c ecf21a8 9dd4b6c 2ccde67 ecf21a8 9dd4b6c ecf21a8 9dd4b6c a9b39d2 9dd4b6c 2ccde67 9dd4b6c cad0763 9dd4b6c cad0763 9dd4b6c cad0763 1af689b 9dd4b6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import json
import re
import os
from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
from models.text_request import TextRequest
from llm.llm_api import LlmApi
from routes.summary import router as summary_router
# Load environment variables from the .env file
load_dotenv()

# Shared LLM client used by every route in this module.
llm_api = LlmApi().get_api()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    # BUG FIX: allow_headers was missing (defaults to an empty set), so CORS
    # preflight requests carrying headers such as Content-Type were rejected
    # even though origins and methods were wildcarded.
    allow_headers=["*"],
)
app.include_router(summary_router)
async def trim_prompt(prompt: str, system_prompt: str = None) -> str:
    """Trim ``prompt`` so that it fits the model's context window.

    Budget accounting: tokenized prompt + tokenized system prompt +
    ``n_predict`` generation tokens must fit into ``max_length``.
    On tokenization/detokenization failure, falls back to a raw character
    prefix of roughly a third of the context length.

    :param prompt: user prompt to trim.
    :param system_prompt: optional system prompt whose token count is
        subtracted from the available budget.
    :return: detokenized trimmed prompt, or a raw-character fallback prefix.
    """
    result = await llm_api.tokenize(prompt)
    result_system = None
    system_prompt_length = 0
    if system_prompt is not None:
        result_system = await llm_api.tokenize(system_prompt)
        if result_system is not None:
            system_prompt_length = len(result_system["result"])
    # On tokenization failure, return a safely sized prefix of the raw string.
    # BUG FIX: the original returned prompt[i] — a single character — instead
    # of the slice prompt[:i].
    if result["result"] is None or (system_prompt is not None and result_system is None):
        return prompt[:int(llm_api.params.context_length / 3)]
    # NOTE(review): part of the budget presumably goes to chat-template
    # formatting — confirm against the backend's template overhead.
    max_length = result["max_length"] - len(result["result"]) - system_prompt_length - llm_api.params.predict_params.n_predict
    detokenized_str = await llm_api.detokenize(result["result"][:max_length])
    # On detokenization failure, fall back the same way.
    # BUG FIX: the original used a float index without int() and without the
    # slice colon (prompt[context_length / 3]) — a guaranteed TypeError.
    if detokenized_str is None:
        return prompt[:int(llm_api.params.context_length / 3)]
    return detokenized_str
def validate_json_format(data):
    """Validate that parsed JSON matches the expected table-list schema.

    Expected shape: a list of objects of the form
    ``{"name": str, "data": {"columns": [str, ...], "rows": [list, ...]}}``.

    :param data: object produced by ``json.loads``.
    :return: a human-readable (Russian) error message, or ``None`` if valid.
    """
    if not isinstance(data, list):
        return "JSON должен быть списком объектов."
    # Checks are deliberately fine-grained for now — useful while debugging;
    # may be simplified later.
    for entry in data:
        if not isinstance(entry, dict):
            return "Элементы списка должны быть объектами (dict)."
        if not ("name" in entry and "data" in entry):
            return "Каждый объект должен содержать ключи 'name' и 'data'."
        if not isinstance(entry["name"], str):
            return "'name' должен быть строкой."
        payload = entry["data"]
        if not isinstance(payload, dict):
            return "'data' должен быть объектом (dict)."
        if "columns" not in payload or "rows" not in payload:
            return "'data' должен содержать 'columns' и 'rows'."
        columns = payload["columns"]
        if not (isinstance(columns, list) and all(isinstance(col, str) for col in columns)):
            return "'columns' должен быть списком строк."
        rows = payload["rows"]
        if not (isinstance(rows, list) and all(isinstance(row, list) for row in rows)):
            return "'rows' должен быть списком списков."
    return None  # no errors found
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
    """Extract tables from free text via the LLM.

    The model is prompted to answer with some prose, then the literal marker
    'JSON', then a JSON payload. Response dict:
    ``{"response": parsed tables or None, "error": message or None,
    "raw": prose prefix (or the full model output on failure)}``.
    """
    prompt = gettable.USER_PROMPT.format(query=request.text)
    system_prompt = gettable.SYSTEM_PROMPT
    prompt = await trim_prompt(prompt, system_prompt)
    response = await llm_api.predict(prompt, system_prompt=system_prompt)
    result = {"response": None, "error": None, "raw": response}
    if "JSON" not in response:
        result["error"] = "Строка не содержит 'JSON'"
        return result
    prefix, json_str = response.split("JSON", 1)
    json_str = json_str.strip()
    if not json_str:
        result["error"] = "После 'JSON' отсутствует JSON"
        return result
    try:
        result["response"] = json.loads(json_str)
        result["raw"] = prefix.strip()
        # BUG FIX: removed a dead else-branch that reassigned
        # result["response"] and result["raw"] to their own values.
        # validate_json_format returns None when the payload is valid.
        result["error"] = validate_json_format(result["response"])
    except json.JSONDecodeError as e:
        result["error"] = f"Ошибка декодирования JSON: {e}"
    return result
@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    status_payload = {"status": "ok"}
    return status_payload
async def generate_response(prompt):
    """Trim *prompt* to the model's context window, then run prediction."""
    trimmed = await trim_prompt(prompt)
    prediction = await llm_api.predict(trimmed)
    return prediction
@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
    """Clean up the submitted transcript text via the LLM."""
    prompt = prompts.CLEAN_TEXT.format(text=request.text)
    cleaned = await generate_response(prompt)
    return {"result": cleaned}
@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
    """Generate a follow-up for the submitted text via the LLM."""
    prompt = prompts.GET_FOLLOWUP.format(text=request.text)
    followup = await generate_response(prompt)
    return {"result": followup}
@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
    """Extract the meeting agenda from the submitted text via the LLM."""
    prompt = prompts.GET_AGENDA.format(text=request.text)
    agenda = await generate_response(prompt)
    return {"result": agenda}
@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
    """Extract highlights from the submitted text via the LLM."""
    prompt = prompts.GET_HIGHLIGHTS.format(text=request.text)
    highlights = await generate_response(prompt)
    return {"result": highlights}
@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
    """Collect general project info (names, then agenda) from the text."""
    sections = []
    for template in (prompts.GET_PROJECT_INFO_NAMES, prompts.GET_PROJECT_INFO_AGENDA):
        sections.append(await generate_response(template.format(text=request.text)))
    main_info = "\n\n".join(sections)
    # Strip the model's end-of-answer artifacts before returning.
    cleaned = main_info.strip().replace("Конец ответа", "").replace('</s>', '')
    return {"result": cleaned}
@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
    """Ask the LLM for candidate projects, then verify each candidate.

    Only candidates the verification prompt confirms with "Да: ..." are
    returned.
    """
    list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
    # Split the answer on the "Проект " marker; element 0 is the preamble
    # before the first project and must be discarded.
    # BUG FIX: the original filtered out empty fragments *before* applying
    # [1:], so when the answer started with "Проект " (empty preamble), the
    # first real project was silently dropped instead of the preamble.
    projects = [f"Проект {x}" for x in list_of_projects.split("Проект ")[1:] if x]
    projects = [p.replace("проект ", "").strip() for p in projects]
    real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
    real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
    return {"result": real_projects_list}
@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
    """Build a detailed per-project summary for every selected project.

    Raises HTTP 400 when the request carries no project selection.
    """
    if not request.projects:
        raise HTTPException(status_code=400, detail="Проекты не выбраны")
    # One prompt per detail section, issued in this fixed order.
    detail_templates = (
        prompts.GET_PROJECT_DETAILS_AIM,
        prompts.GET_PROJECT_DETAILS_VALUE,
        prompts.GET_PROJECT_DETAILS_BUDGET,
        prompts.GET_PROJECT_DETAILS_ECO_EFFECT,
        prompts.GET_PROJECT_DETAILS_DEADLINE,
        prompts.GET_PROJECT_DETAILS_NEW_PLAN,
        prompts.GET_PROJECT_DETAILS_CONCLUSION,
    )
    final = {}
    for project in request.projects:
        sections = []
        for template in detail_templates:
            sections.append(await generate_response(template.format(text=request.text, project=project)))
        combined = "\n\n".join(sections)
        # Remove the model's end-of-answer markers.
        final[project] = combined.replace("Конец ответа", "").replace('</s>', '').strip()
    return {"result": final}
if __name__ == "__main__":
    # Serve the app with uvicorn when run directly. Port 7860 — presumably
    # chosen for a Hugging Face Spaces deployment; confirm before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|