File size: 7,980 Bytes
9dd4b6c
 
a9b39d2
cb60f6b
ecf21a8
 
 
 
9dd4b6c
ecf21a8
9dd4b6c
9347ad1
 
 
 
a9b39d2
ecf21a8
 
a9b39d2
9347ad1
a9b39d2
9dd4b6c
267da20
9dd4b6c
 
 
 
 
 
a9b39d2
9347ad1
 
a9b39d2
f443975
2ccde67
f443975
9347ad1
 
 
 
 
 
 
 
 
2ccde67
 
f443975
 
2ccde67
 
9347ad1
2ccde67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dd4b6c
 
 
2ccde67
 
 
ecf21a8
9dd4b6c
ecf21a8
be938bc
 
ecf21a8
9dd4b6c
be938bc
ecf21a8
9dd4b6c
ecf21a8
be938bc
ecf21a8
9dd4b6c
ecf21a8
 
9dd4b6c
2ccde67
 
 
 
 
 
ecf21a8
 
 
9dd4b6c
ecf21a8
9dd4b6c
a9b39d2
9dd4b6c
 
 
2ccde67
 
9dd4b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cad0763
9dd4b6c
cad0763
9dd4b6c
 
 
 
cad0763
1af689b
9dd4b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import json
import re
import os
from llm.common import LlmParams, LlmPredictParams
from llm.deepinfra_api import DeepInfraApi
from llm import prompts
from prompts import gettable
from dotenv import load_dotenv
import uvicorn
from models.text_request import TextRequest
from llm.llm_api import LlmApi
from routes.summary import router as summary_router


# Загрузка переменных окружения из файла .env
load_dotenv()

llm_api = LlmApi().get_api()

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"]
)

app.include_router(summary_router)


async def trim_prompt(prompt: str, system_prompt: str = None):
    
    result = await llm_api.tokenize(prompt)
    
    result_system = None
    system_prompt_length = 0
    if system_prompt is not None:
        result_system = await llm_api.tokenize(system_prompt)
        
        if result_system is not None:
            system_prompt_length = len(result_system["result"])
        
    
    # в случае ошибки при токенизации, вернем исходную строку безопасной длины
    if result["result"] is None or (system_prompt is not None and result_system is None):
        return prompt[int(llm_api.params.context_length / 3)]
    
    #вероятно, часть уходит на форматирование чата, надо проверить
    max_length = result["max_length"] - len(result["result"]) - system_prompt_length - llm_api.params.predict_params.n_predict

    detokenized_str = await llm_api.detokenize(result["result"][:max_length])    
    
    # в случае ошибки при детокенизации, вернем исходную строку безопасной длины
    if detokenized_str is None:
        return prompt[llm_api.params.context_length / 3]

    return detokenized_str
    

def validate_json_format(data):
    """ Проверяет, соответствует ли JSON ожидаемому формату. """
    if not isinstance(data, list):
        return "JSON должен быть списком объектов."

    # Возможно, в дальнейшем стоит описать менее детально, пока так для отладки
    for item in data:
        if not isinstance(item, dict):
            return "Элементы списка должны быть объектами (dict)."
        if "name" not in item or "data" not in item:
            return "Каждый объект должен содержать ключи 'name' и 'data'."
        if not isinstance(item["name"], str):
            return "'name' должен быть строкой."
        if not isinstance(item["data"], dict):
            return "'data' должен быть объектом (dict)."
        if "columns" not in item["data"] or "rows" not in item["data"]:
            return "'data' должен содержать 'columns' и 'rows'."
        if not isinstance(item["data"]["columns"], list) or not all(isinstance(col, str) for col in item["data"]["columns"]):
            return "'columns' должен быть списком строк."
        if not isinstance(item["data"]["rows"], list) or not all(isinstance(row, list) for row in item["data"]["rows"]):
            return "'rows' должен быть списком списков."

    return None  # Ошибок нет        
        
    
@app.post("/extracttable")
async def extracttable_route(request: TextRequest):
    prompt = gettable.USER_PROMPT.format(query=request.text)
    system_prompt=gettable.SYSTEM_PROMPT
    prompt = await trim_prompt(prompt, system_prompt)
    response = await llm_api.predict(prompt, system_prompt=system_prompt)
    
    result = {"response": None, "error": None, "raw": response}
    
    if "JSON" not in response:
        result["error"] = "Строка не содержит 'JSON'"
        return result
    
    prefix, json_str = response.split("JSON", 1)
    json_str = json_str.strip()
    
    if not json_str:
        result["error"] = "После 'JSON' отсутствует JSON"
        return result
    
    try:
        result["response"] = json.loads(json_str)
        result["raw"] = prefix.strip()
        validation_error = validate_json_format(result["response"])
        if validation_error:
            result["error"] = validation_error
        else:
            result["response"] = result["response"]
            result["raw"] = prefix.strip()
    except json.JSONDecodeError as e:
        result["error"] = f"Ошибка декодирования JSON: {e}"
        
    return result

@app.get("/health")
def health():
    return {"status": "ok"}

async def generate_response(prompt):
    prompt = await trim_prompt(prompt)
    return await llm_api.predict(prompt)

@app.post("/cleantext")
async def cleantext_route(request: TextRequest):
    return {"result": await generate_response(prompts.CLEAN_TEXT.format(text=request.text))}

@app.post("/getfollowup")
async def getfollowup_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_FOLLOWUP.format(text=request.text))}

@app.post("/getagenda")
async def getagenda_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_AGENDA.format(text=request.text))}

@app.post("/gethighlights")
async def gethighlights_route(request: TextRequest):
    return {"result": await generate_response(prompts.GET_HIGHLIGHTS.format(text=request.text))}

@app.post("/getprojectinfo")
async def getprojectinfo_route(request: TextRequest):
    prompts_list = [
        prompts.GET_PROJECT_INFO_NAMES.format(text=request.text),
        prompts.GET_PROJECT_INFO_AGENDA.format(text=request.text)
    ]
    main_info = "\n\n".join([await generate_response(p) for p in prompts_list])
    return {"result": main_info.strip().replace("Конец ответа", "").replace('</s>', '')}

@app.post("/getprojectlist")
async def getprojectlist_route(request: TextRequest):
    list_of_projects = await generate_response(prompts.GET_PROJECT_LIST.format(text=request.text))
    projects = [f"Проект {x}" for x in list_of_projects.split("Проект ") if x][1:]
    projects = [p.replace("проект ", "").strip() for p in projects]
    real_projects = await generate_response(prompts.GET_PROJECT_LIST_CHECK_PROJECT.format(text=request.text, projects=projects))
    real_projects_list = re.findall(r'Да:\s*(.*?)\s*(?:\n\n|$)', real_projects)
    return {"result": real_projects_list}

@app.post("/getprojectdetails")
async def getprojectdetails_route(request: TextRequest):
    if not request.projects:
        raise HTTPException(status_code=400, detail="Проекты не выбраны")
    
    final = {}
    for project in request.projects:
        prompts_list = [
            prompts.GET_PROJECT_DETAILS_AIM.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_VALUE.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_BUDGET.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_ECO_EFFECT.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_DEADLINE.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_NEW_PLAN.format(text=request.text, project=project),
            prompts.GET_PROJECT_DETAILS_CONCLUSION.format(text=request.text, project=project),
        ]
        final[project] = "\n\n".join([await generate_response(p) for p in prompts_list])
        final[project] = final[project].replace("Конец ответа", "").replace('</s>', '').strip()
    
    return {"result": final}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)