mominah commited on
Commit
4cf11fc
·
verified ·
1 Parent(s): eef2ed0

Update routes.py

Browse files
Files changed (1) hide show
  1. routes.py +51 -11
routes.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import shutil
3
  import tempfile
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
 
 
5
  from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
6
  from trainer_manager import get_trainer
7
  from config import CUSTOM_PROMPT
@@ -9,34 +11,60 @@ from prompt_templates import PromptTemplates
9
 
10
  router = APIRouter()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
13
- def initialize_bot(custom_prompt: str):
14
  """
15
  Initializes a new bot with a custom prompt template.
16
- The custom_prompt parameter should be provided (e.g., via navigation) and is used as the bot's prompt template.
17
  """
18
  trainer = get_trainer()
19
  try:
20
  bot_id = trainer.initialize_bot_id()
21
- trainer.set_custom_prompt_template(bot_id, custom_prompt)
 
22
  return InitializeBotResponse(bot_id=bot_id)
23
  except Exception as e:
24
  raise HTTPException(status_code=500, detail=str(e))
25
 
26
  @router.post("/upload_document")
27
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
 
 
 
28
  trainer = get_trainer()
29
  try:
30
- # Save the uploaded file to a temporary location
31
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
32
  contents = await file.read()
33
  tmp.write(contents)
34
  tmp_path = tmp.name
35
 
36
- # Add the document from the temporary file path
37
  trainer.add_document_from_path(tmp_path, bot_id)
38
-
39
- # Remove the temporary file
40
  os.remove(tmp_path)
41
  return {"message": "Document uploaded and added successfully."}
42
  except Exception as e:
@@ -44,6 +72,9 @@ async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...))
44
 
45
  @router.post("/create_bot/{bot_id}")
46
  def create_bot(bot_id: str):
 
 
 
47
  trainer = get_trainer()
48
  try:
49
  trainer.create_bot(bot_id)
@@ -53,6 +84,9 @@ def create_bot(bot_id: str):
53
 
54
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
55
  def new_chat(bot_id: str):
 
 
 
56
  trainer = get_trainer()
57
  try:
58
  chat_id = trainer.new_chat(bot_id)
@@ -62,6 +96,9 @@ def new_chat(bot_id: str):
62
 
63
  @router.post("/query", response_model=QueryResponse)
64
  def send_query(query_request: QueryRequest):
 
 
 
65
  trainer = get_trainer()
66
  try:
67
  response, web_sources = trainer.get_response(
@@ -73,6 +110,9 @@ def send_query(query_request: QueryRequest):
73
 
74
  @router.get("/list_chats/{bot_id}")
75
  def list_chats(bot_id: str):
 
 
 
76
  trainer = get_trainer()
77
  try:
78
  chats = trainer.list_chats(bot_id)
@@ -80,15 +120,15 @@ def list_chats(bot_id: str):
80
  except Exception as e:
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
83
- from fastapi.encoders import jsonable_encoder
84
- from bson import ObjectId
85
-
86
  @router.get("/chat_history/{chat_id}")
87
  def chat_history(chat_id: str, bot_id: str):
 
 
 
 
88
  trainer = get_trainer()
89
  try:
90
  history = trainer.get_chat_by_id(chat_id=chat_id)
91
- # Convert ObjectId instances to strings before returning the response.
92
  return jsonable_encoder(history, custom_encoder={ObjectId: str})
93
  except Exception as e:
94
  raise HTTPException(status_code=500, detail=str(e))
 
2
  import shutil
3
  import tempfile
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
+ from fastapi.encoders import jsonable_encoder
6
+ from bson import ObjectId
7
  from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
8
  from trainer_manager import get_trainer
9
  from config import CUSTOM_PROMPT
 
11
 
12
  router = APIRouter()
13
 
14
+ def select_prompt(prompt_type: str):
15
+ """Selects a prompt template based on the provided prompt_type."""
16
+ if prompt_type == "university":
17
+ return PromptTemplates.get_university_chatbot_prompt()
18
+ elif prompt_type == "quiz_solving":
19
+ return PromptTemplates.get_quiz_solving_prompt()
20
+ elif prompt_type == "assignment_solving":
21
+ return PromptTemplates.get_assignment_solving_prompt()
22
+ elif prompt_type == "paper_solving":
23
+ return PromptTemplates.get_paper_solving_prompt()
24
+ elif prompt_type == "quiz_creation":
25
+ return PromptTemplates.get_quiz_creation_prompt()
26
+ elif prompt_type == "assignment_creation":
27
+ return PromptTemplates.get_assignment_creation_prompt()
28
+ elif prompt_type == "paper_creation":
29
+ return PromptTemplates.get_paper_creation_prompt()
30
+ elif prompt_type == "check_quiz":
31
+ return PromptTemplates.get_check_quiz_prompt()
32
+ elif prompt_type == "check_assignment":
33
+ return PromptTemplates.get_check_assignment_prompt()
34
+ elif prompt_type == "check_paper":
35
+ return PromptTemplates.get_check_paper_prompt()
36
+ else:
37
+ # Fallback to a default prompt if no matching type is found
38
+ return PromptTemplates.get_quiz_solving_prompt()
39
+
40
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
41
+ def initialize_bot(prompt_type: str):
42
  """
43
  Initializes a new bot with a custom prompt template.
44
+ The prompt_type query parameter determines which prompt to use.
45
  """
46
  trainer = get_trainer()
47
  try:
48
  bot_id = trainer.initialize_bot_id()
49
+ prompt = select_prompt(prompt_type)
50
+ trainer.set_custom_prompt_template(bot_id, prompt)
51
  return InitializeBotResponse(bot_id=bot_id)
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
 
55
  @router.post("/upload_document")
56
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
57
+ """
58
+ Saves the uploaded file temporarily and adds it to the bot's knowledge base.
59
+ """
60
  trainer = get_trainer()
61
  try:
 
62
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
63
  contents = await file.read()
64
  tmp.write(contents)
65
  tmp_path = tmp.name
66
 
 
67
  trainer.add_document_from_path(tmp_path, bot_id)
 
 
68
  os.remove(tmp_path)
69
  return {"message": "Document uploaded and added successfully."}
70
  except Exception as e:
 
72
 
73
  @router.post("/create_bot/{bot_id}")
74
  def create_bot(bot_id: str):
75
+ """
76
+ Creates the bot (builds its index) for the given bot_id.
77
+ """
78
  trainer = get_trainer()
79
  try:
80
  trainer.create_bot(bot_id)
 
84
 
85
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
86
  def new_chat(bot_id: str):
87
+ """
88
+ Creates a new chat session for the bot.
89
+ """
90
  trainer = get_trainer()
91
  try:
92
  chat_id = trainer.new_chat(bot_id)
 
96
 
97
  @router.post("/query", response_model=QueryResponse)
98
  def send_query(query_request: QueryRequest):
99
+ """
100
+ Processes a user's query and returns the bot's response along with any web sources.
101
+ """
102
  trainer = get_trainer()
103
  try:
104
  response, web_sources = trainer.get_response(
 
110
 
111
  @router.get("/list_chats/{bot_id}")
112
  def list_chats(bot_id: str):
113
+ """
114
+ Returns a list of previous chat sessions for the specified bot.
115
+ """
116
  trainer = get_trainer()
117
  try:
118
  chats = trainer.list_chats(bot_id)
 
120
  except Exception as e:
121
  raise HTTPException(status_code=500, detail=str(e))
122
 
 
 
 
123
  @router.get("/chat_history/{chat_id}")
124
  def chat_history(chat_id: str, bot_id: str):
125
+ """
126
+ Returns the chat history for a given chat session.
127
+ ObjectId instances are converted to strings.
128
+ """
129
  trainer = get_trainer()
130
  try:
131
  history = trainer.get_chat_by_id(chat_id=chat_id)
 
132
  return jsonable_encoder(history, custom_encoder={ObjectId: str})
133
  except Exception as e:
134
  raise HTTPException(status_code=500, detail=str(e))