mominah commited on
Commit
57c8644
·
verified ·
1 Parent(s): 4d15b46

Update routes.py

Browse files
Files changed (1) hide show
  1. routes.py +65 -76
routes.py CHANGED
@@ -4,86 +4,82 @@ import tempfile
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
  from fastapi.encoders import jsonable_encoder
6
  from bson import ObjectId
7
- from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
 
 
 
8
  from trainer_manager import get_trainer
9
- from config import CUSTOM_PROMPT
10
  from prompt_templates import PromptTemplates
11
 
12
  router = APIRouter()
13
  trainer = get_trainer()
 
 
14
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
15
- def initialize_bot(prompt_type: str):
16
  """
17
- Initializes a new bot with a custom prompt template.
18
- The custom_prompt parameter should be provided by the frontend (for example, via navigation).
19
  """
20
-
21
  try:
22
  bot_id = trainer.initialize_bot_id()
23
-
24
- # Assign the prompt template based on the prompt type.
25
- if prompt_type == "university":
26
- prompt_template = PromptTemplates.get_university_chatbot_prompt()
27
- elif prompt_type == "quiz_solving":
28
- prompt_template = PromptTemplates.get_quiz_solving_prompt()
29
- elif prompt_type == "assignment_solving":
30
- prompt_template = PromptTemplates.get_assignment_solving_prompt()
31
- elif prompt_type == "paper_solving":
32
- prompt_template = PromptTemplates.get_paper_solving_prompt()
33
- elif prompt_type == "quiz_creation":
34
- prompt_template = PromptTemplates.get_quiz_creation_prompt()
35
- elif prompt_type == "assignment_creation":
36
- prompt_template = PromptTemplates.get_assignment_creation_prompt()
37
- elif prompt_type == "paper_creation":
38
- prompt_template = PromptTemplates.get_paper_creation_prompt()
39
- elif prompt_type == "check_quiz":
40
- prompt_template = PromptTemplates.get_check_quiz_prompt()
41
- elif prompt_type == "check_assignment":
42
- prompt_template = PromptTemplates.get_check_assignment_prompt()
43
- elif prompt_type == "check_paper":
44
- prompt_template = PromptTemplates.get_check_paper_prompt()
45
- else:
46
- # Default prompt template if prompt_type is not recognized.
47
- prompt_template = PromptTemplates.get_quiz_solving_prompt()
48
-
49
- # Pass the assigned prompt template to the bot.
50
- trainer.set_custom_prompt_template(bot_id, prompt_template)
51
  return InitializeBotResponse(bot_id=bot_id)
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @router.post("/upload_document")
57
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
58
  """
59
  Saves the uploaded file temporarily and adds it to the bot's knowledge base.
60
  """
61
  try:
62
- # Save the file to a temporary location
63
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
64
  contents = await file.read()
65
  tmp.write(contents)
66
  tmp_path = tmp.name
67
 
68
- # Add the document using the temporary file path
69
  trainer.add_document_from_path(tmp_path, bot_id)
70
-
71
- # Remove the temporary file
72
  os.remove(tmp_path)
 
73
  return {"message": "Document uploaded and added successfully."}
74
  except Exception as e:
75
  raise HTTPException(status_code=500, detail=str(e))
76
 
77
- @router.post("/create_bot/{bot_id}")
78
- def create_bot(bot_id: str):
79
- """
80
- Creates (builds) the bot (e.g., builds its index) for the given bot_id.
81
- """
82
- try:
83
- trainer.create_bot(bot_id)
84
- return {"message": f"Bot {bot_id} created successfully."}
85
- except Exception as e:
86
- raise HTTPException(status_code=500, detail=str(e))
87
 
88
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
89
  def new_chat(bot_id: str):
@@ -96,47 +92,40 @@ def new_chat(bot_id: str):
96
  except Exception as e:
97
  raise HTTPException(status_code=500, detail=str(e))
98
 
 
99
  @router.post("/query", response_model=QueryResponse)
100
  def send_query(query_request: QueryRequest):
101
  """
102
  Processes a query and returns the bot's response along with any web sources.
103
  """
104
- # trainer = get_trainer()
105
- # trainer.load_bot(query_request.bot_id)
106
- prompt_type = query_request.prompt_type
107
  try:
108
- # Assign the prompt template based on the prompt type.
109
- if prompt_type == "university":
110
- prompt_template = PromptTemplates.get_university_chatbot_prompt()
111
- elif prompt_type == "quiz_solving":
112
- prompt_template = PromptTemplates.get_quiz_solving_prompt()
113
- elif prompt_type == "assignment_solving":
114
- prompt_template = PromptTemplates.get_assignment_solving_prompt()
115
- elif prompt_type == "paper_solving":
116
- prompt_template = PromptTemplates.get_paper_solving_prompt()
117
- elif prompt_type == "quiz_creation":
118
- prompt_template = PromptTemplates.get_quiz_creation_prompt()
119
- elif prompt_type == "assignment_creation":
120
- prompt_template = PromptTemplates.get_assignment_creation_prompt()
121
- elif prompt_type == "paper_creation":
122
- prompt_template = PromptTemplates.get_paper_creation_prompt()
123
- elif prompt_type == "check_quiz":
124
- prompt_template = PromptTemplates.get_check_quiz_prompt()
125
- elif prompt_type == "check_assignment":
126
- prompt_template = PromptTemplates.get_check_assignment_prompt()
127
- elif prompt_type == "check_paper":
128
- prompt_template = PromptTemplates.get_check_paper_prompt()
129
- else:
130
- # Default prompt template if prompt_type is not recognized.
131
- prompt_template = PromptTemplates.get_quiz_solving_prompt()
132
  trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
 
133
  response, web_sources = trainer.get_response(
134
  query_request.query, query_request.bot_id, query_request.chat_id
135
  )
 
136
  return QueryResponse(response=response, web_sources=web_sources)
137
  except Exception as e:
138
  raise HTTPException(status_code=500, detail=str(e))
139
 
 
140
  @router.get("/list_chats/{bot_id}")
141
  def list_chats(bot_id: str):
142
  """
@@ -148,11 +137,11 @@ def list_chats(bot_id: str):
148
  except Exception as e:
149
  raise HTTPException(status_code=500, detail=str(e))
150
 
 
151
  @router.get("/chat_history/{chat_id}")
152
  def chat_history(chat_id: str, bot_id: str):
153
  """
154
  Returns the chat history for a given chat session.
155
- ObjectId instances in the history are converted to strings.
156
  """
157
  try:
158
  history = trainer.get_chat_by_id(chat_id=chat_id)
 
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
  from fastapi.encoders import jsonable_encoder
6
  from bson import ObjectId
7
+ from models import (
8
+ InitializeBotResponse, DocumentPath, NewChatResponse,
9
+ QueryRequest, QueryResponse, CreateBotRequest, CreateBotResponse
10
+ )
11
  from trainer_manager import get_trainer
 
12
  from prompt_templates import PromptTemplates
13
 
14
  router = APIRouter()
15
  trainer = get_trainer()
16
+
17
+
18
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
19
+ def initialize_bot():
20
  """
21
+ Initializes a new bot and returns the bot_id.
22
+ The prompt type is now passed during bot creation, not initialization.
23
  """
 
24
  try:
25
  bot_id = trainer.initialize_bot_id()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  return InitializeBotResponse(bot_id=bot_id)
27
  except Exception as e:
28
  raise HTTPException(status_code=500, detail=str(e))
29
 
30
 
31
+ @router.post("/create_bot", response_model=CreateBotResponse)
32
+ def create_bot(data: CreateBotRequest):
33
+ """
34
+ Finalizes bot creation by assigning a prompt template based on the prompt type.
35
+ """
36
+ try:
37
+ # Assign the appropriate prompt template based on prompt_type
38
+ prompt_type = data.prompt_type.lower()
39
+ prompt_template_mapping = {
40
+ "university": PromptTemplates.get_university_chatbot_prompt,
41
+ "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
42
+ "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
43
+ "paper_solving": PromptTemplates.get_paper_solving_prompt,
44
+ "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
45
+ "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
46
+ "paper_creation": PromptTemplates.get_paper_creation_prompt,
47
+ "check_quiz": PromptTemplates.get_check_quiz_prompt,
48
+ "check_assignment": PromptTemplates.get_check_assignment_prompt,
49
+ "check_paper": PromptTemplates.get_check_paper_prompt
50
+ }
51
+
52
+ prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
53
+
54
+ # Apply the selected prompt template to the bot
55
+ trainer.set_custom_prompt_template(data.bot_id, prompt_template)
56
+
57
+ # Complete bot creation
58
+ trainer.create_bot(data.bot_id)
59
+
60
+ return CreateBotResponse(bot_id=data.bot_id)
61
+ except Exception as e:
62
+ raise HTTPException(status_code=500, detail=str(e))
63
+
64
+
65
  @router.post("/upload_document")
66
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
67
  """
68
  Saves the uploaded file temporarily and adds it to the bot's knowledge base.
69
  """
70
  try:
 
71
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
72
  contents = await file.read()
73
  tmp.write(contents)
74
  tmp_path = tmp.name
75
 
 
76
  trainer.add_document_from_path(tmp_path, bot_id)
 
 
77
  os.remove(tmp_path)
78
+
79
  return {"message": "Document uploaded and added successfully."}
80
  except Exception as e:
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
85
  def new_chat(bot_id: str):
 
92
  except Exception as e:
93
  raise HTTPException(status_code=500, detail=str(e))
94
 
95
+
96
  @router.post("/query", response_model=QueryResponse)
97
  def send_query(query_request: QueryRequest):
98
  """
99
  Processes a query and returns the bot's response along with any web sources.
100
  """
 
 
 
101
  try:
102
+ prompt_type = query_request.prompt_type.lower()
103
+ prompt_template_mapping = {
104
+ "university": PromptTemplates.get_university_chatbot_prompt,
105
+ "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
106
+ "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
107
+ "paper_solving": PromptTemplates.get_paper_solving_prompt,
108
+ "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
109
+ "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
110
+ "paper_creation": PromptTemplates.get_paper_creation_prompt,
111
+ "check_quiz": PromptTemplates.get_check_quiz_prompt,
112
+ "check_assignment": PromptTemplates.get_check_assignment_prompt,
113
+ "check_paper": PromptTemplates.get_check_paper_prompt
114
+ }
115
+
116
+ prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
117
+
 
 
 
 
 
 
 
 
118
  trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
119
+
120
  response, web_sources = trainer.get_response(
121
  query_request.query, query_request.bot_id, query_request.chat_id
122
  )
123
+
124
  return QueryResponse(response=response, web_sources=web_sources)
125
  except Exception as e:
126
  raise HTTPException(status_code=500, detail=str(e))
127
 
128
+
129
  @router.get("/list_chats/{bot_id}")
130
  def list_chats(bot_id: str):
131
  """
 
137
  except Exception as e:
138
  raise HTTPException(status_code=500, detail=str(e))
139
 
140
+
141
  @router.get("/chat_history/{chat_id}")
142
  def chat_history(chat_id: str, bot_id: str):
143
  """
144
  Returns the chat history for a given chat session.
 
145
  """
146
  try:
147
  history = trainer.get_chat_by_id(chat_id=chat_id)