Spaces:
Running
Running
Update routes.py
Browse files
routes.py
CHANGED
@@ -4,86 +4,82 @@ import tempfile
|
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
5 |
from fastapi.encoders import jsonable_encoder
|
6 |
from bson import ObjectId
|
7 |
-
from models import
|
|
|
|
|
|
|
8 |
from trainer_manager import get_trainer
|
9 |
-
from config import CUSTOM_PROMPT
|
10 |
from prompt_templates import PromptTemplates
|
11 |
|
12 |
router = APIRouter()
|
13 |
trainer = get_trainer()
|
|
|
|
|
14 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
15 |
-
def initialize_bot(
|
16 |
"""
|
17 |
-
Initializes a new bot
|
18 |
-
The
|
19 |
"""
|
20 |
-
|
21 |
try:
|
22 |
bot_id = trainer.initialize_bot_id()
|
23 |
-
|
24 |
-
# Assign the prompt template based on the prompt type.
|
25 |
-
if prompt_type == "university":
|
26 |
-
prompt_template = PromptTemplates.get_university_chatbot_prompt()
|
27 |
-
elif prompt_type == "quiz_solving":
|
28 |
-
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
29 |
-
elif prompt_type == "assignment_solving":
|
30 |
-
prompt_template = PromptTemplates.get_assignment_solving_prompt()
|
31 |
-
elif prompt_type == "paper_solving":
|
32 |
-
prompt_template = PromptTemplates.get_paper_solving_prompt()
|
33 |
-
elif prompt_type == "quiz_creation":
|
34 |
-
prompt_template = PromptTemplates.get_quiz_creation_prompt()
|
35 |
-
elif prompt_type == "assignment_creation":
|
36 |
-
prompt_template = PromptTemplates.get_assignment_creation_prompt()
|
37 |
-
elif prompt_type == "paper_creation":
|
38 |
-
prompt_template = PromptTemplates.get_paper_creation_prompt()
|
39 |
-
elif prompt_type == "check_quiz":
|
40 |
-
prompt_template = PromptTemplates.get_check_quiz_prompt()
|
41 |
-
elif prompt_type == "check_assignment":
|
42 |
-
prompt_template = PromptTemplates.get_check_assignment_prompt()
|
43 |
-
elif prompt_type == "check_paper":
|
44 |
-
prompt_template = PromptTemplates.get_check_paper_prompt()
|
45 |
-
else:
|
46 |
-
# Default prompt template if prompt_type is not recognized.
|
47 |
-
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
48 |
-
|
49 |
-
# Pass the assigned prompt template to the bot.
|
50 |
-
trainer.set_custom_prompt_template(bot_id, prompt_template)
|
51 |
return InitializeBotResponse(bot_id=bot_id)
|
52 |
except Exception as e:
|
53 |
raise HTTPException(status_code=500, detail=str(e))
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
@router.post("/upload_document")
|
57 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
58 |
"""
|
59 |
Saves the uploaded file temporarily and adds it to the bot's knowledge base.
|
60 |
"""
|
61 |
try:
|
62 |
-
# Save the file to a temporary location
|
63 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
64 |
contents = await file.read()
|
65 |
tmp.write(contents)
|
66 |
tmp_path = tmp.name
|
67 |
|
68 |
-
# Add the document using the temporary file path
|
69 |
trainer.add_document_from_path(tmp_path, bot_id)
|
70 |
-
|
71 |
-
# Remove the temporary file
|
72 |
os.remove(tmp_path)
|
|
|
73 |
return {"message": "Document uploaded and added successfully."}
|
74 |
except Exception as e:
|
75 |
raise HTTPException(status_code=500, detail=str(e))
|
76 |
|
77 |
-
@router.post("/create_bot/{bot_id}")
|
78 |
-
def create_bot(bot_id: str):
|
79 |
-
"""
|
80 |
-
Creates (builds) the bot (e.g., builds its index) for the given bot_id.
|
81 |
-
"""
|
82 |
-
try:
|
83 |
-
trainer.create_bot(bot_id)
|
84 |
-
return {"message": f"Bot {bot_id} created successfully."}
|
85 |
-
except Exception as e:
|
86 |
-
raise HTTPException(status_code=500, detail=str(e))
|
87 |
|
88 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
89 |
def new_chat(bot_id: str):
|
@@ -96,47 +92,40 @@ def new_chat(bot_id: str):
|
|
96 |
except Exception as e:
|
97 |
raise HTTPException(status_code=500, detail=str(e))
|
98 |
|
|
|
99 |
@router.post("/query", response_model=QueryResponse)
|
100 |
def send_query(query_request: QueryRequest):
|
101 |
"""
|
102 |
Processes a query and returns the bot's response along with any web sources.
|
103 |
"""
|
104 |
-
# trainer = get_trainer()
|
105 |
-
# trainer.load_bot(query_request.bot_id)
|
106 |
-
prompt_type = query_request.prompt_type
|
107 |
try:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
prompt_template = PromptTemplates.get_check_quiz_prompt()
|
125 |
-
elif prompt_type == "check_assignment":
|
126 |
-
prompt_template = PromptTemplates.get_check_assignment_prompt()
|
127 |
-
elif prompt_type == "check_paper":
|
128 |
-
prompt_template = PromptTemplates.get_check_paper_prompt()
|
129 |
-
else:
|
130 |
-
# Default prompt template if prompt_type is not recognized.
|
131 |
-
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
132 |
trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
|
|
|
133 |
response, web_sources = trainer.get_response(
|
134 |
query_request.query, query_request.bot_id, query_request.chat_id
|
135 |
)
|
|
|
136 |
return QueryResponse(response=response, web_sources=web_sources)
|
137 |
except Exception as e:
|
138 |
raise HTTPException(status_code=500, detail=str(e))
|
139 |
|
|
|
140 |
@router.get("/list_chats/{bot_id}")
|
141 |
def list_chats(bot_id: str):
|
142 |
"""
|
@@ -148,11 +137,11 @@ def list_chats(bot_id: str):
|
|
148 |
except Exception as e:
|
149 |
raise HTTPException(status_code=500, detail=str(e))
|
150 |
|
|
|
151 |
@router.get("/chat_history/{chat_id}")
|
152 |
def chat_history(chat_id: str, bot_id: str):
|
153 |
"""
|
154 |
Returns the chat history for a given chat session.
|
155 |
-
ObjectId instances in the history are converted to strings.
|
156 |
"""
|
157 |
try:
|
158 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|
|
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
5 |
from fastapi.encoders import jsonable_encoder
|
6 |
from bson import ObjectId
|
7 |
+
from models import (
|
8 |
+
InitializeBotResponse, DocumentPath, NewChatResponse,
|
9 |
+
QueryRequest, QueryResponse, CreateBotRequest, CreateBotResponse
|
10 |
+
)
|
11 |
from trainer_manager import get_trainer
|
|
|
12 |
from prompt_templates import PromptTemplates
|
13 |
|
14 |
router = APIRouter()
|
15 |
trainer = get_trainer()
|
16 |
+
|
17 |
+
|
18 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
19 |
+
def initialize_bot():
|
20 |
"""
|
21 |
+
Initializes a new bot and returns the bot_id.
|
22 |
+
The prompt type is now passed during bot creation, not initialization.
|
23 |
"""
|
|
|
24 |
try:
|
25 |
bot_id = trainer.initialize_bot_id()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
return InitializeBotResponse(bot_id=bot_id)
|
27 |
except Exception as e:
|
28 |
raise HTTPException(status_code=500, detail=str(e))
|
29 |
|
30 |
|
31 |
+
@router.post("/create_bot", response_model=CreateBotResponse)
|
32 |
+
def create_bot(data: CreateBotRequest):
|
33 |
+
"""
|
34 |
+
Finalizes bot creation by assigning a prompt template based on the prompt type.
|
35 |
+
"""
|
36 |
+
try:
|
37 |
+
# Assign the appropriate prompt template based on prompt_type
|
38 |
+
prompt_type = data.prompt_type.lower()
|
39 |
+
prompt_template_mapping = {
|
40 |
+
"university": PromptTemplates.get_university_chatbot_prompt,
|
41 |
+
"quiz_solving": PromptTemplates.get_quiz_solving_prompt,
|
42 |
+
"assignment_solving": PromptTemplates.get_assignment_solving_prompt,
|
43 |
+
"paper_solving": PromptTemplates.get_paper_solving_prompt,
|
44 |
+
"quiz_creation": PromptTemplates.get_quiz_creation_prompt,
|
45 |
+
"assignment_creation": PromptTemplates.get_assignment_creation_prompt,
|
46 |
+
"paper_creation": PromptTemplates.get_paper_creation_prompt,
|
47 |
+
"check_quiz": PromptTemplates.get_check_quiz_prompt,
|
48 |
+
"check_assignment": PromptTemplates.get_check_assignment_prompt,
|
49 |
+
"check_paper": PromptTemplates.get_check_paper_prompt
|
50 |
+
}
|
51 |
+
|
52 |
+
prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
|
53 |
+
|
54 |
+
# Apply the selected prompt template to the bot
|
55 |
+
trainer.set_custom_prompt_template(data.bot_id, prompt_template)
|
56 |
+
|
57 |
+
# Complete bot creation
|
58 |
+
trainer.create_bot(data.bot_id)
|
59 |
+
|
60 |
+
return CreateBotResponse(bot_id=data.bot_id)
|
61 |
+
except Exception as e:
|
62 |
+
raise HTTPException(status_code=500, detail=str(e))
|
63 |
+
|
64 |
+
|
65 |
@router.post("/upload_document")
|
66 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
67 |
"""
|
68 |
Saves the uploaded file temporarily and adds it to the bot's knowledge base.
|
69 |
"""
|
70 |
try:
|
|
|
71 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
72 |
contents = await file.read()
|
73 |
tmp.write(contents)
|
74 |
tmp_path = tmp.name
|
75 |
|
|
|
76 |
trainer.add_document_from_path(tmp_path, bot_id)
|
|
|
|
|
77 |
os.remove(tmp_path)
|
78 |
+
|
79 |
return {"message": "Document uploaded and added successfully."}
|
80 |
except Exception as e:
|
81 |
raise HTTPException(status_code=500, detail=str(e))
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
85 |
def new_chat(bot_id: str):
|
|
|
92 |
except Exception as e:
|
93 |
raise HTTPException(status_code=500, detail=str(e))
|
94 |
|
95 |
+
|
96 |
@router.post("/query", response_model=QueryResponse)
|
97 |
def send_query(query_request: QueryRequest):
|
98 |
"""
|
99 |
Processes a query and returns the bot's response along with any web sources.
|
100 |
"""
|
|
|
|
|
|
|
101 |
try:
|
102 |
+
prompt_type = query_request.prompt_type.lower()
|
103 |
+
prompt_template_mapping = {
|
104 |
+
"university": PromptTemplates.get_university_chatbot_prompt,
|
105 |
+
"quiz_solving": PromptTemplates.get_quiz_solving_prompt,
|
106 |
+
"assignment_solving": PromptTemplates.get_assignment_solving_prompt,
|
107 |
+
"paper_solving": PromptTemplates.get_paper_solving_prompt,
|
108 |
+
"quiz_creation": PromptTemplates.get_quiz_creation_prompt,
|
109 |
+
"assignment_creation": PromptTemplates.get_assignment_creation_prompt,
|
110 |
+
"paper_creation": PromptTemplates.get_paper_creation_prompt,
|
111 |
+
"check_quiz": PromptTemplates.get_check_quiz_prompt,
|
112 |
+
"check_assignment": PromptTemplates.get_check_assignment_prompt,
|
113 |
+
"check_paper": PromptTemplates.get_check_paper_prompt
|
114 |
+
}
|
115 |
+
|
116 |
+
prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
|
117 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
|
119 |
+
|
120 |
response, web_sources = trainer.get_response(
|
121 |
query_request.query, query_request.bot_id, query_request.chat_id
|
122 |
)
|
123 |
+
|
124 |
return QueryResponse(response=response, web_sources=web_sources)
|
125 |
except Exception as e:
|
126 |
raise HTTPException(status_code=500, detail=str(e))
|
127 |
|
128 |
+
|
129 |
@router.get("/list_chats/{bot_id}")
|
130 |
def list_chats(bot_id: str):
|
131 |
"""
|
|
|
137 |
except Exception as e:
|
138 |
raise HTTPException(status_code=500, detail=str(e))
|
139 |
|
140 |
+
|
141 |
@router.get("/chat_history/{chat_id}")
|
142 |
def chat_history(chat_id: str, bot_id: str):
|
143 |
"""
|
144 |
Returns the chat history for a given chat session.
|
|
|
145 |
"""
|
146 |
try:
|
147 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|