mominah commited on
Commit
0240527
·
verified ·
1 Parent(s): 071ae0f

Update routes.py

Browse files
Files changed (1) hide show
  1. routes.py +35 -3
routes.py CHANGED
@@ -6,15 +6,47 @@ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
6
  from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
7
  from trainer_manager import get_trainer
8
  from config import CUSTOM_PROMPT
 
9
 
10
  router = APIRouter()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
13
- def initialize_bot():
 
 
 
 
14
  trainer = get_trainer()
15
  try:
16
  bot_id = trainer.initialize_bot_id()
17
- trainer.set_custom_prompt_template(bot_id, CUSTOM_PROMPT)
 
18
  return InitializeBotResponse(bot_id=bot_id)
19
  except Exception as e:
20
  raise HTTPException(status_code=500, detail=str(e))
@@ -32,7 +64,7 @@ def add_document(document: DocumentPath):
32
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
33
  trainer = get_trainer()
34
  try:
35
- # Save the uploaded file to a temporary directory
36
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
37
  contents = await file.read()
38
  tmp.write(contents)
 
6
  from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
7
  from trainer_manager import get_trainer
8
  from config import CUSTOM_PROMPT
9
+ from prompt_templates import PromptTemplates
10
 
11
  router = APIRouter()
12
 
13
+ def select_prompt(prompt_type: str):
14
+ """Selects a prompt template based on the provided prompt_type."""
15
+ if prompt_type == "university":
16
+ return PromptTemplates.get_university_chatbot_prompt()
17
+ elif prompt_type == "quiz_solving":
18
+ return PromptTemplates.get_quiz_solving_prompt()
19
+ elif prompt_type == "assignment_solving":
20
+ return PromptTemplates.get_assignment_solving_prompt()
21
+ elif prompt_type == "paper_solving":
22
+ return PromptTemplates.get_paper_solving_prompt()
23
+ elif prompt_type == "quiz_creation":
24
+ return PromptTemplates.get_quiz_creation_prompt()
25
+ elif prompt_type == "assignment_creation":
26
+ return PromptTemplates.get_assignment_creation_prompt()
27
+ elif prompt_type == "paper_creation":
28
+ return PromptTemplates.get_paper_creation_prompt()
29
+ elif prompt_type == "check_quiz":
30
+ return PromptTemplates.get_check_quiz_prompt()
31
+ elif prompt_type == "check_assignment":
32
+ return PromptTemplates.get_check_assignment_prompt()
33
+ elif prompt_type == "check_paper":
34
+ return PromptTemplates.get_check_paper_prompt()
35
+ else:
36
+ # Fallback to a default prompt if no matching type is found
37
+ return PromptTemplates.get_quiz_solving_prompt()
38
+
39
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
40
+ def initialize_bot(prompt_type: str = "quiz_solving"):
41
+ """
42
+ Initializes a new bot with a custom prompt template.
43
+ The prompt_type query parameter determines which prompt to use.
44
+ """
45
  trainer = get_trainer()
46
  try:
47
  bot_id = trainer.initialize_bot_id()
48
+ selected_prompt = select_prompt(prompt_type)
49
+ trainer.set_custom_prompt_template(bot_id, selected_prompt)
50
  return InitializeBotResponse(bot_id=bot_id)
51
  except Exception as e:
52
  raise HTTPException(status_code=500, detail=str(e))
 
64
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
65
  trainer = get_trainer()
66
  try:
67
+ # Save the uploaded file to a temporary location
68
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
69
  contents = await file.read()
70
  tmp.write(contents)