Spaces:
Running
Running
Update routes.py
Browse files
routes.py
CHANGED
@@ -6,15 +6,47 @@ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
|
6 |
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
|
7 |
from trainer_manager import get_trainer
|
8 |
from config import CUSTOM_PROMPT
|
|
|
9 |
|
10 |
router = APIRouter()
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
13 |
-
def initialize_bot():
|
|
|
|
|
|
|
|
|
14 |
trainer = get_trainer()
|
15 |
try:
|
16 |
bot_id = trainer.initialize_bot_id()
|
17 |
-
|
|
|
18 |
return InitializeBotResponse(bot_id=bot_id)
|
19 |
except Exception as e:
|
20 |
raise HTTPException(status_code=500, detail=str(e))
|
@@ -32,7 +64,7 @@ def add_document(document: DocumentPath):
|
|
32 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
33 |
trainer = get_trainer()
|
34 |
try:
|
35 |
-
# Save the uploaded file to a temporary
|
36 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
37 |
contents = await file.read()
|
38 |
tmp.write(contents)
|
|
|
6 |
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
|
7 |
from trainer_manager import get_trainer
|
8 |
from config import CUSTOM_PROMPT
|
9 |
+
from prompt_templates import PromptTemplates
|
10 |
|
11 |
router = APIRouter()
|
12 |
|
13 |
+
def select_prompt(prompt_type: str):
|
14 |
+
"""Selects a prompt template based on the provided prompt_type."""
|
15 |
+
if prompt_type == "university":
|
16 |
+
return PromptTemplates.get_university_chatbot_prompt()
|
17 |
+
elif prompt_type == "quiz_solving":
|
18 |
+
return PromptTemplates.get_quiz_solving_prompt()
|
19 |
+
elif prompt_type == "assignment_solving":
|
20 |
+
return PromptTemplates.get_assignment_solving_prompt()
|
21 |
+
elif prompt_type == "paper_solving":
|
22 |
+
return PromptTemplates.get_paper_solving_prompt()
|
23 |
+
elif prompt_type == "quiz_creation":
|
24 |
+
return PromptTemplates.get_quiz_creation_prompt()
|
25 |
+
elif prompt_type == "assignment_creation":
|
26 |
+
return PromptTemplates.get_assignment_creation_prompt()
|
27 |
+
elif prompt_type == "paper_creation":
|
28 |
+
return PromptTemplates.get_paper_creation_prompt()
|
29 |
+
elif prompt_type == "check_quiz":
|
30 |
+
return PromptTemplates.get_check_quiz_prompt()
|
31 |
+
elif prompt_type == "check_assignment":
|
32 |
+
return PromptTemplates.get_check_assignment_prompt()
|
33 |
+
elif prompt_type == "check_paper":
|
34 |
+
return PromptTemplates.get_check_paper_prompt()
|
35 |
+
else:
|
36 |
+
# Fallback to a default prompt if no matching type is found
|
37 |
+
return PromptTemplates.get_quiz_solving_prompt()
|
38 |
+
|
39 |
@router.post("/initialize_bot", response_model=InitializeBotResponse)
|
40 |
+
def initialize_bot(prompt_type: str = "quiz_solving"):
|
41 |
+
"""
|
42 |
+
Initializes a new bot with a custom prompt template.
|
43 |
+
The prompt_type query parameter determines which prompt to use.
|
44 |
+
"""
|
45 |
trainer = get_trainer()
|
46 |
try:
|
47 |
bot_id = trainer.initialize_bot_id()
|
48 |
+
selected_prompt = select_prompt(prompt_type)
|
49 |
+
trainer.set_custom_prompt_template(bot_id, selected_prompt)
|
50 |
return InitializeBotResponse(bot_id=bot_id)
|
51 |
except Exception as e:
|
52 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
64 |
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
|
65 |
trainer = get_trainer()
|
66 |
try:
|
67 |
+
# Save the uploaded file to a temporary location
|
68 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
|
69 |
contents = await file.read()
|
70 |
tmp.write(contents)
|