Spaces:
Sleeping
Sleeping
Update routes.py
Browse files
routes.py
CHANGED
@@ -2,6 +2,8 @@ import os
|
|
2 |
import shutil
|
3 |
import tempfile
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
|
|
|
|
5 |
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
|
6 |
from trainer_manager import get_trainer
|
7 |
from config import CUSTOM_PROMPT
|
@@ -9,34 +11,60 @@ from prompt_templates import PromptTemplates
|
|
9 |
|
10 |
router = APIRouter()
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
13 |
-
def initialize_bot(
|
14 |
"""
|
15 |
Initializes a new bot with a custom prompt template.
|
16 |
-
The
|
17 |
"""
|
18 |
trainer = get_trainer()
|
19 |
try:
|
20 |
bot_id = trainer.initialize_bot_id()
|
21 |
-
|
|
|
22 |
return InitializeBotResponse(bot_id=bot_id)
|
23 |
except Exception as e:
|
24 |
raise HTTPException(status_code=500, detail=str(e))
|
25 |
|
26 |
@router.post("/upload_document")
|
27 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
|
|
|
|
|
|
28 |
trainer = get_trainer()
|
29 |
try:
|
30 |
-
# Save the uploaded file to a temporary location
|
31 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
32 |
contents = await file.read()
|
33 |
tmp.write(contents)
|
34 |
tmp_path = tmp.name
|
35 |
|
36 |
-
# Add the document from the temporary file path
|
37 |
trainer.add_document_from_path(tmp_path, bot_id)
|
38 |
-
|
39 |
-
# Remove the temporary file
|
40 |
os.remove(tmp_path)
|
41 |
return {"message": "Document uploaded and added successfully."}
|
42 |
except Exception as e:
|
@@ -44,6 +72,9 @@ async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...))
|
|
44 |
|
45 |
@router.post("/create_bot/{bot_id}")
|
46 |
def create_bot(bot_id: str):
|
|
|
|
|
|
|
47 |
trainer = get_trainer()
|
48 |
try:
|
49 |
trainer.create_bot(bot_id)
|
@@ -53,6 +84,9 @@ def create_bot(bot_id: str):
|
|
53 |
|
54 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
55 |
def new_chat(bot_id: str):
|
|
|
|
|
|
|
56 |
trainer = get_trainer()
|
57 |
try:
|
58 |
chat_id = trainer.new_chat(bot_id)
|
@@ -62,6 +96,9 @@ def new_chat(bot_id: str):
|
|
62 |
|
63 |
@router.post("/query", response_model=QueryResponse)
|
64 |
def send_query(query_request: QueryRequest):
|
|
|
|
|
|
|
65 |
trainer = get_trainer()
|
66 |
try:
|
67 |
response, web_sources = trainer.get_response(
|
@@ -73,6 +110,9 @@ def send_query(query_request: QueryRequest):
|
|
73 |
|
74 |
@router.get("/list_chats/{bot_id}")
|
75 |
def list_chats(bot_id: str):
|
|
|
|
|
|
|
76 |
trainer = get_trainer()
|
77 |
try:
|
78 |
chats = trainer.list_chats(bot_id)
|
@@ -80,15 +120,15 @@ def list_chats(bot_id: str):
|
|
80 |
except Exception as e:
|
81 |
raise HTTPException(status_code=500, detail=str(e))
|
82 |
|
83 |
-
from fastapi.encoders import jsonable_encoder
|
84 |
-
from bson import ObjectId
|
85 |
-
|
86 |
@router.get("/chat_history/{chat_id}")
|
87 |
def chat_history(chat_id: str, bot_id: str):
|
|
|
|
|
|
|
|
|
88 |
trainer = get_trainer()
|
89 |
try:
|
90 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|
91 |
-
# Convert ObjectId instances to strings before returning the response.
|
92 |
return jsonable_encoder(history, custom_encoder={ObjectId: str})
|
93 |
except Exception as e:
|
94 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
2 |
import shutil
|
3 |
import tempfile
|
4 |
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
5 |
+
from fastapi.encoders import jsonable_encoder
|
6 |
+
from bson import ObjectId
|
7 |
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
|
8 |
from trainer_manager import get_trainer
|
9 |
from config import CUSTOM_PROMPT
|
|
|
11 |
|
12 |
router = APIRouter()
|
13 |
|
14 |
+
def select_prompt(prompt_type: str):
|
15 |
+
"""Selects a prompt template based on the provided prompt_type."""
|
16 |
+
if prompt_type == "university":
|
17 |
+
return PromptTemplates.get_university_chatbot_prompt()
|
18 |
+
elif prompt_type == "quiz_solving":
|
19 |
+
return PromptTemplates.get_quiz_solving_prompt()
|
20 |
+
elif prompt_type == "assignment_solving":
|
21 |
+
return PromptTemplates.get_assignment_solving_prompt()
|
22 |
+
elif prompt_type == "paper_solving":
|
23 |
+
return PromptTemplates.get_paper_solving_prompt()
|
24 |
+
elif prompt_type == "quiz_creation":
|
25 |
+
return PromptTemplates.get_quiz_creation_prompt()
|
26 |
+
elif prompt_type == "assignment_creation":
|
27 |
+
return PromptTemplates.get_assignment_creation_prompt()
|
28 |
+
elif prompt_type == "paper_creation":
|
29 |
+
return PromptTemplates.get_paper_creation_prompt()
|
30 |
+
elif prompt_type == "check_quiz":
|
31 |
+
return PromptTemplates.get_check_quiz_prompt()
|
32 |
+
elif prompt_type == "check_assignment":
|
33 |
+
return PromptTemplates.get_check_assignment_prompt()
|
34 |
+
elif prompt_type == "check_paper":
|
35 |
+
return PromptTemplates.get_check_paper_prompt()
|
36 |
+
else:
|
37 |
+
# Fallback to a default prompt if no matching type is found
|
38 |
+
return PromptTemplates.get_quiz_solving_prompt()
|
39 |
+
|
40 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
41 |
+
def initialize_bot(prompt_type: str):
|
42 |
"""
|
43 |
Initializes a new bot with a custom prompt template.
|
44 |
+
The prompt_type query parameter determines which prompt to use.
|
45 |
"""
|
46 |
trainer = get_trainer()
|
47 |
try:
|
48 |
bot_id = trainer.initialize_bot_id()
|
49 |
+
prompt = select_prompt(prompt_type)
|
50 |
+
trainer.set_custom_prompt_template(bot_id, prompt)
|
51 |
return InitializeBotResponse(bot_id=bot_id)
|
52 |
except Exception as e:
|
53 |
raise HTTPException(status_code=500, detail=str(e))
|
54 |
|
55 |
@router.post("/upload_document")
|
56 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
57 |
+
"""
|
58 |
+
Saves the uploaded file temporarily and adds it to the bot's knowledge base.
|
59 |
+
"""
|
60 |
trainer = get_trainer()
|
61 |
try:
|
|
|
62 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
63 |
contents = await file.read()
|
64 |
tmp.write(contents)
|
65 |
tmp_path = tmp.name
|
66 |
|
|
|
67 |
trainer.add_document_from_path(tmp_path, bot_id)
|
|
|
|
|
68 |
os.remove(tmp_path)
|
69 |
return {"message": "Document uploaded and added successfully."}
|
70 |
except Exception as e:
|
|
|
72 |
|
73 |
@router.post("/create_bot/{bot_id}")
|
74 |
def create_bot(bot_id: str):
|
75 |
+
"""
|
76 |
+
Creates the bot (builds its index) for the given bot_id.
|
77 |
+
"""
|
78 |
trainer = get_trainer()
|
79 |
try:
|
80 |
trainer.create_bot(bot_id)
|
|
|
84 |
|
85 |
@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
|
86 |
def new_chat(bot_id: str):
|
87 |
+
"""
|
88 |
+
Creates a new chat session for the bot.
|
89 |
+
"""
|
90 |
trainer = get_trainer()
|
91 |
try:
|
92 |
chat_id = trainer.new_chat(bot_id)
|
|
|
96 |
|
97 |
@router.post("/query", response_model=QueryResponse)
|
98 |
def send_query(query_request: QueryRequest):
|
99 |
+
"""
|
100 |
+
Processes a user's query and returns the bot's response along with any web sources.
|
101 |
+
"""
|
102 |
trainer = get_trainer()
|
103 |
try:
|
104 |
response, web_sources = trainer.get_response(
|
|
|
110 |
|
111 |
@router.get("/list_chats/{bot_id}")
|
112 |
def list_chats(bot_id: str):
|
113 |
+
"""
|
114 |
+
Returns a list of previous chat sessions for the specified bot.
|
115 |
+
"""
|
116 |
trainer = get_trainer()
|
117 |
try:
|
118 |
chats = trainer.list_chats(bot_id)
|
|
|
120 |
except Exception as e:
|
121 |
raise HTTPException(status_code=500, detail=str(e))
|
122 |
|
|
|
|
|
|
|
123 |
@router.get("/chat_history/{chat_id}")
|
124 |
def chat_history(chat_id: str, bot_id: str):
|
125 |
+
"""
|
126 |
+
Returns the chat history for a given chat session.
|
127 |
+
ObjectId instances are converted to strings.
|
128 |
+
"""
|
129 |
trainer = get_trainer()
|
130 |
try:
|
131 |
history = trainer.get_chat_by_id(chat_id=chat_id)
|
|
|
132 |
return jsonable_encoder(history, custom_encoder={ObjectId: str})
|
133 |
except Exception as e:
|
134 |
raise HTTPException(status_code=500, detail=str(e))
|