Spaces:
Sleeping
Sleeping
Update routes.py
Browse files
routes.py
CHANGED
@@ -4,22 +4,18 @@ import tempfile
|
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
5 |
from fastapi.encoders import jsonable_encoder
|
6 |
from bson import ObjectId
|
7 |
-
from models import
|
8 |
-
InitializeBotResponse, DocumentPath, NewChatResponse,
|
9 |
-
QueryRequest, QueryResponse, CreateBotRequest, CreateBotResponse
|
10 |
-
)
|
11 |
from trainer_manager import get_trainer
|
|
|
12 |
from prompt_templates import PromptTemplates
|
13 |
|
14 |
router = APIRouter()
|
15 |
trainer = get_trainer()
|
16 |
|
17 |
-
|
18 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
19 |
def initialize_bot():
|
20 |
"""
|
21 |
-
Initializes a new bot and returns the bot_id.
|
22 |
-
The prompt type is now passed during bot creation, not initialization.
|
23 |
"""
|
24 |
try:
|
25 |
bot_id = trainer.initialize_bot_id()
|
@@ -28,59 +24,68 @@ def initialize_bot():
|
|
28 |
raise HTTPException(status_code=500, detail=str(e))
|
29 |
|
30 |
|
31 |
-
@router.post("/create_bot", response_model=CreateBotResponse)
|
32 |
-
def create_bot(data: CreateBotRequest):
|
33 |
-
"""
|
34 |
-
Finalizes bot creation by assigning a prompt template based on the prompt type.
|
35 |
-
"""
|
36 |
-
try:
|
37 |
-
# Assign the appropriate prompt template based on prompt_type
|
38 |
-
prompt_type = data.prompt_type.lower()
|
39 |
-
prompt_template_mapping = {
|
40 |
-
"university": PromptTemplates.get_university_chatbot_prompt,
|
41 |
-
"quiz_solving": PromptTemplates.get_quiz_solving_prompt,
|
42 |
-
"assignment_solving": PromptTemplates.get_assignment_solving_prompt,
|
43 |
-
"paper_solving": PromptTemplates.get_paper_solving_prompt,
|
44 |
-
"quiz_creation": PromptTemplates.get_quiz_creation_prompt,
|
45 |
-
"assignment_creation": PromptTemplates.get_assignment_creation_prompt,
|
46 |
-
"paper_creation": PromptTemplates.get_paper_creation_prompt,
|
47 |
-
"check_quiz": PromptTemplates.get_check_quiz_prompt,
|
48 |
-
"check_assignment": PromptTemplates.get_check_assignment_prompt,
|
49 |
-
"check_paper": PromptTemplates.get_check_paper_prompt
|
50 |
-
}
|
51 |
-
|
52 |
-
prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
|
53 |
-
|
54 |
-
# Apply the selected prompt template to the bot
|
55 |
-
trainer.set_custom_prompt_template(data.bot_id, prompt_template)
|
56 |
-
|
57 |
-
# Complete bot creation
|
58 |
-
trainer.create_bot(data.bot_id)
|
59 |
-
|
60 |
-
return CreateBotResponse(bot_id=data.bot_id)
|
61 |
-
except Exception as e:
|
62 |
-
raise HTTPException(status_code=500, detail=str(e))
|
63 |
-
|
64 |
-
|
65 |
@router.post("/upload_document")
|
66 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
67 |
"""
|
68 |
Saves the uploaded file temporarily and adds it to the bot's knowledge base.
|
69 |
"""
|
70 |
try:
|
|
|
71 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
72 |
contents = await file.read()
|
73 |
tmp.write(contents)
|
74 |
tmp_path = tmp.name
|
75 |
|
|
|
76 |
trainer.add_document_from_path(tmp_path, bot_id)
|
77 |
-
os.remove(tmp_path)
|
78 |
|
|
|
|
|
79 |
return {"message": "Document uploaded and added successfully."}
|
80 |
except Exception as e:
|
81 |
raise HTTPException(status_code=500, detail=str(e))
|
82 |
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
85 |
def new_chat(bot_id: str):
|
86 |
"""
|
@@ -98,17 +103,42 @@ def send_query(query_request: QueryRequest):
|
|
98 |
"""
|
99 |
Processes a query and returns the bot's response along with any web sources.
|
100 |
"""
|
|
|
101 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
response, web_sources = trainer.get_response(
|
103 |
query_request.query, query_request.bot_id, query_request.chat_id
|
104 |
)
|
105 |
-
|
106 |
return QueryResponse(response=response, web_sources=web_sources)
|
107 |
except Exception as e:
|
108 |
raise HTTPException(status_code=500, detail=str(e))
|
109 |
|
110 |
|
111 |
-
|
112 |
@router.get("/list_chats/{bot_id}")
|
113 |
def list_chats(bot_id: str):
|
114 |
"""
|
@@ -125,6 +155,7 @@ def list_chats(bot_id: str):
|
|
125 |
def chat_history(chat_id: str, bot_id: str):
|
126 |
"""
|
127 |
Returns the chat history for a given chat session.
|
|
|
128 |
"""
|
129 |
try:
|
130 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|
|
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
5 |
from fastapi.encoders import jsonable_encoder
|
6 |
from bson import ObjectId
|
7 |
+
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
|
|
|
|
|
|
|
8 |
from trainer_manager import get_trainer
|
9 |
+
from config import CUSTOM_PROMPT
|
10 |
from prompt_templates import PromptTemplates
|
11 |
|
12 |
router = APIRouter()
|
13 |
trainer = get_trainer()
|
14 |
|
|
|
15 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
16 |
def initialize_bot():
|
17 |
"""
|
18 |
+
Initializes a new bot and returns the generated bot_id.
|
|
|
19 |
"""
|
20 |
try:
|
21 |
bot_id = trainer.initialize_bot_id()
|
|
|
24 |
raise HTTPException(status_code=500, detail=str(e))
|
25 |
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
@router.post("/upload_document")
|
28 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
29 |
"""
|
30 |
Saves the uploaded file temporarily and adds it to the bot's knowledge base.
|
31 |
"""
|
32 |
try:
|
33 |
+
# Save the file to a temporary location.
|
34 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
35 |
contents = await file.read()
|
36 |
tmp.write(contents)
|
37 |
tmp_path = tmp.name
|
38 |
|
39 |
+
# Add the document using the temporary file path.
|
40 |
trainer.add_document_from_path(tmp_path, bot_id)
|
|
|
41 |
|
42 |
+
# Remove the temporary file.
|
43 |
+
os.remove(tmp_path)
|
44 |
return {"message": "Document uploaded and added successfully."}
|
45 |
except Exception as e:
|
46 |
raise HTTPException(status_code=500, detail=str(e))
|
47 |
|
48 |
|
49 |
+
@router.post("/create_bot/{bot_id}")
|
50 |
+
def create_bot(bot_id: str, prompt_type: str = Form(...)):
|
51 |
+
"""
|
52 |
+
Creates (builds) the bot for the given bot_id using the specified prompt type.
|
53 |
+
This endpoint assigns the appropriate custom prompt template before finalizing the bot creation.
|
54 |
+
"""
|
55 |
+
try:
|
56 |
+
# Determine the appropriate prompt template based on the provided prompt_type.
|
57 |
+
if prompt_type == "university":
|
58 |
+
prompt_template = PromptTemplates.get_university_chatbot_prompt()
|
59 |
+
elif prompt_type == "quiz_solving":
|
60 |
+
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
61 |
+
elif prompt_type == "assignment_solving":
|
62 |
+
prompt_template = PromptTemplates.get_assignment_solving_prompt()
|
63 |
+
elif prompt_type == "paper_solving":
|
64 |
+
prompt_template = PromptTemplates.get_paper_solving_prompt()
|
65 |
+
elif prompt_type == "quiz_creation":
|
66 |
+
prompt_template = PromptTemplates.get_quiz_creation_prompt()
|
67 |
+
elif prompt_type == "assignment_creation":
|
68 |
+
prompt_template = PromptTemplates.get_assignment_creation_prompt()
|
69 |
+
elif prompt_type == "paper_creation":
|
70 |
+
prompt_template = PromptTemplates.get_paper_creation_prompt()
|
71 |
+
elif prompt_type == "check_quiz":
|
72 |
+
prompt_template = PromptTemplates.get_check_quiz_prompt()
|
73 |
+
elif prompt_type == "check_assignment":
|
74 |
+
prompt_template = PromptTemplates.get_check_assignment_prompt()
|
75 |
+
elif prompt_type == "check_paper":
|
76 |
+
prompt_template = PromptTemplates.get_check_paper_prompt()
|
77 |
+
else:
|
78 |
+
# Default prompt template if prompt_type is not recognized.
|
79 |
+
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
80 |
+
|
81 |
+
# Set the custom prompt template and create the bot.
|
82 |
+
trainer.set_custom_prompt_template(bot_id, prompt_template)
|
83 |
+
trainer.create_bot(bot_id)
|
84 |
+
return {"message": f"Bot {bot_id} created successfully."}
|
85 |
+
except Exception as e:
|
86 |
+
raise HTTPException(status_code=500, detail=str(e))
|
87 |
+
|
88 |
+
|
89 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
90 |
def new_chat(bot_id: str):
|
91 |
"""
|
|
|
103 |
"""
|
104 |
Processes a query and returns the bot's response along with any web sources.
|
105 |
"""
|
106 |
+
prompt_type = query_request.prompt_type
|
107 |
try:
|
108 |
+
# Assign the prompt template based on the provided prompt_type.
|
109 |
+
if prompt_type == "university":
|
110 |
+
prompt_template = PromptTemplates.get_university_chatbot_prompt()
|
111 |
+
elif prompt_type == "quiz_solving":
|
112 |
+
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
113 |
+
elif prompt_type == "assignment_solving":
|
114 |
+
prompt_template = PromptTemplates.get_assignment_solving_prompt()
|
115 |
+
elif prompt_type == "paper_solving":
|
116 |
+
prompt_template = PromptTemplates.get_paper_solving_prompt()
|
117 |
+
elif prompt_type == "quiz_creation":
|
118 |
+
prompt_template = PromptTemplates.get_quiz_creation_prompt()
|
119 |
+
elif prompt_type == "assignment_creation":
|
120 |
+
prompt_template = PromptTemplates.get_assignment_creation_prompt()
|
121 |
+
elif prompt_type == "paper_creation":
|
122 |
+
prompt_template = PromptTemplates.get_paper_creation_prompt()
|
123 |
+
elif prompt_type == "check_quiz":
|
124 |
+
prompt_template = PromptTemplates.get_check_quiz_prompt()
|
125 |
+
elif prompt_type == "check_assignment":
|
126 |
+
prompt_template = PromptTemplates.get_check_assignment_prompt()
|
127 |
+
elif prompt_type == "check_paper":
|
128 |
+
prompt_template = PromptTemplates.get_check_paper_prompt()
|
129 |
+
else:
|
130 |
+
# Default prompt template if prompt_type is not recognized.
|
131 |
+
prompt_template = PromptTemplates.get_quiz_solving_prompt()
|
132 |
+
|
133 |
+
trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
|
134 |
response, web_sources = trainer.get_response(
|
135 |
query_request.query, query_request.bot_id, query_request.chat_id
|
136 |
)
|
|
|
137 |
return QueryResponse(response=response, web_sources=web_sources)
|
138 |
except Exception as e:
|
139 |
raise HTTPException(status_code=500, detail=str(e))
|
140 |
|
141 |
|
|
|
142 |
@router.get("/list_chats/{bot_id}")
|
143 |
def list_chats(bot_id: str):
|
144 |
"""
|
|
|
155 |
def chat_history(chat_id: str, bot_id: str):
|
156 |
"""
|
157 |
Returns the chat history for a given chat session.
|
158 |
+
ObjectId instances in the history are converted to strings.
|
159 |
"""
|
160 |
try:
|
161 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|