mominah commited on
Commit
4876c75
·
verified ·
1 Parent(s): 9caf3b3

Update routes.py

Browse files
Files changed (1) hide show
  1. routes.py +75 -44
routes.py CHANGED
@@ -4,22 +4,18 @@ import tempfile
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
  from fastapi.encoders import jsonable_encoder
6
  from bson import ObjectId
7
- from models import (
8
- InitializeBotResponse, DocumentPath, NewChatResponse,
9
- QueryRequest, QueryResponse, CreateBotRequest, CreateBotResponse
10
- )
11
  from trainer_manager import get_trainer
 
12
  from prompt_templates import PromptTemplates
13
 
14
  router = APIRouter()
15
  trainer = get_trainer()
16
 
17
-
18
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
19
  def initialize_bot():
20
  """
21
- Initializes a new bot and returns the bot_id.
22
- The prompt type is now passed during bot creation, not initialization.
23
  """
24
  try:
25
  bot_id = trainer.initialize_bot_id()
@@ -28,59 +24,68 @@ def initialize_bot():
28
  raise HTTPException(status_code=500, detail=str(e))
29
 
30
 
31
- @router.post("/create_bot", response_model=CreateBotResponse)
32
- def create_bot(data: CreateBotRequest):
33
- """
34
- Finalizes bot creation by assigning a prompt template based on the prompt type.
35
- """
36
- try:
37
- # Assign the appropriate prompt template based on prompt_type
38
- prompt_type = data.prompt_type.lower()
39
- prompt_template_mapping = {
40
- "university": PromptTemplates.get_university_chatbot_prompt,
41
- "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
42
- "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
43
- "paper_solving": PromptTemplates.get_paper_solving_prompt,
44
- "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
45
- "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
46
- "paper_creation": PromptTemplates.get_paper_creation_prompt,
47
- "check_quiz": PromptTemplates.get_check_quiz_prompt,
48
- "check_assignment": PromptTemplates.get_check_assignment_prompt,
49
- "check_paper": PromptTemplates.get_check_paper_prompt
50
- }
51
-
52
- prompt_template = prompt_template_mapping.get(prompt_type, PromptTemplates.get_quiz_solving_prompt)()
53
-
54
- # Apply the selected prompt template to the bot
55
- trainer.set_custom_prompt_template(data.bot_id, prompt_template)
56
-
57
- # Complete bot creation
58
- trainer.create_bot(data.bot_id)
59
-
60
- return CreateBotResponse(bot_id=data.bot_id)
61
- except Exception as e:
62
- raise HTTPException(status_code=500, detail=str(e))
63
-
64
-
65
  @router.post("/upload_document")
66
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
67
  """
68
  Saves the uploaded file temporarily and adds it to the bot's knowledge base.
69
  """
70
  try:
 
71
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
72
  contents = await file.read()
73
  tmp.write(contents)
74
  tmp_path = tmp.name
75
 
 
76
  trainer.add_document_from_path(tmp_path, bot_id)
77
- os.remove(tmp_path)
78
 
 
 
79
  return {"message": "Document uploaded and added successfully."}
80
  except Exception as e:
81
  raise HTTPException(status_code=500, detail=str(e))
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
85
  def new_chat(bot_id: str):
86
  """
@@ -98,17 +103,42 @@ def send_query(query_request: QueryRequest):
98
  """
99
  Processes a query and returns the bot's response along with any web sources.
100
  """
 
101
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  response, web_sources = trainer.get_response(
103
  query_request.query, query_request.bot_id, query_request.chat_id
104
  )
105
-
106
  return QueryResponse(response=response, web_sources=web_sources)
107
  except Exception as e:
108
  raise HTTPException(status_code=500, detail=str(e))
109
 
110
 
111
-
112
  @router.get("/list_chats/{bot_id}")
113
  def list_chats(bot_id: str):
114
  """
@@ -125,6 +155,7 @@ def list_chats(bot_id: str):
125
  def chat_history(chat_id: str, bot_id: str):
126
  """
127
  Returns the chat history for a given chat session.
 
128
  """
129
  try:
130
  history = trainer.get_chat_by_id(chat_id=chat_id)
 
4
  from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
  from fastapi.encoders import jsonable_encoder
6
  from bson import ObjectId
7
+ from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
 
 
 
8
  from trainer_manager import get_trainer
9
+ from config import CUSTOM_PROMPT
10
  from prompt_templates import PromptTemplates
11
 
12
  router = APIRouter()
13
  trainer = get_trainer()
14
 
 
15
  @router.post("/initialize_bot", response_model=InitializeBotResponse)
16
  def initialize_bot():
17
  """
18
+ Initializes a new bot and returns the generated bot_id.
 
19
  """
20
  try:
21
  bot_id = trainer.initialize_bot_id()
 
24
  raise HTTPException(status_code=500, detail=str(e))
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  @router.post("/upload_document")
28
  async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
29
  """
30
  Saves the uploaded file temporarily and adds it to the bot's knowledge base.
31
  """
32
  try:
33
+ # Save the file to a temporary location.
34
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
35
  contents = await file.read()
36
  tmp.write(contents)
37
  tmp_path = tmp.name
38
 
39
+ # Add the document using the temporary file path.
40
  trainer.add_document_from_path(tmp_path, bot_id)
 
41
 
42
+ # Remove the temporary file.
43
+ os.remove(tmp_path)
44
  return {"message": "Document uploaded and added successfully."}
45
  except Exception as e:
46
  raise HTTPException(status_code=500, detail=str(e))
47
 
48
 
49
+ @router.post("/create_bot/{bot_id}")
50
+ def create_bot(bot_id: str, prompt_type: str = Form(...)):
51
+ """
52
+ Creates (builds) the bot for the given bot_id using the specified prompt type.
53
+ This endpoint assigns the appropriate custom prompt template before finalizing the bot creation.
54
+ """
55
+ try:
56
+ # Determine the appropriate prompt template based on the provided prompt_type.
57
+ if prompt_type == "university":
58
+ prompt_template = PromptTemplates.get_university_chatbot_prompt()
59
+ elif prompt_type == "quiz_solving":
60
+ prompt_template = PromptTemplates.get_quiz_solving_prompt()
61
+ elif prompt_type == "assignment_solving":
62
+ prompt_template = PromptTemplates.get_assignment_solving_prompt()
63
+ elif prompt_type == "paper_solving":
64
+ prompt_template = PromptTemplates.get_paper_solving_prompt()
65
+ elif prompt_type == "quiz_creation":
66
+ prompt_template = PromptTemplates.get_quiz_creation_prompt()
67
+ elif prompt_type == "assignment_creation":
68
+ prompt_template = PromptTemplates.get_assignment_creation_prompt()
69
+ elif prompt_type == "paper_creation":
70
+ prompt_template = PromptTemplates.get_paper_creation_prompt()
71
+ elif prompt_type == "check_quiz":
72
+ prompt_template = PromptTemplates.get_check_quiz_prompt()
73
+ elif prompt_type == "check_assignment":
74
+ prompt_template = PromptTemplates.get_check_assignment_prompt()
75
+ elif prompt_type == "check_paper":
76
+ prompt_template = PromptTemplates.get_check_paper_prompt()
77
+ else:
78
+ # Default prompt template if prompt_type is not recognized.
79
+ prompt_template = PromptTemplates.get_quiz_solving_prompt()
80
+
81
+ # Set the custom prompt template and create the bot.
82
+ trainer.set_custom_prompt_template(bot_id, prompt_template)
83
+ trainer.create_bot(bot_id)
84
+ return {"message": f"Bot {bot_id} created successfully."}
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=str(e))
87
+
88
+
89
  @router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
90
  def new_chat(bot_id: str):
91
  """
 
103
  """
104
  Processes a query and returns the bot's response along with any web sources.
105
  """
106
+ prompt_type = query_request.prompt_type
107
  try:
108
+ # Assign the prompt template based on the provided prompt_type.
109
+ if prompt_type == "university":
110
+ prompt_template = PromptTemplates.get_university_chatbot_prompt()
111
+ elif prompt_type == "quiz_solving":
112
+ prompt_template = PromptTemplates.get_quiz_solving_prompt()
113
+ elif prompt_type == "assignment_solving":
114
+ prompt_template = PromptTemplates.get_assignment_solving_prompt()
115
+ elif prompt_type == "paper_solving":
116
+ prompt_template = PromptTemplates.get_paper_solving_prompt()
117
+ elif prompt_type == "quiz_creation":
118
+ prompt_template = PromptTemplates.get_quiz_creation_prompt()
119
+ elif prompt_type == "assignment_creation":
120
+ prompt_template = PromptTemplates.get_assignment_creation_prompt()
121
+ elif prompt_type == "paper_creation":
122
+ prompt_template = PromptTemplates.get_paper_creation_prompt()
123
+ elif prompt_type == "check_quiz":
124
+ prompt_template = PromptTemplates.get_check_quiz_prompt()
125
+ elif prompt_type == "check_assignment":
126
+ prompt_template = PromptTemplates.get_check_assignment_prompt()
127
+ elif prompt_type == "check_paper":
128
+ prompt_template = PromptTemplates.get_check_paper_prompt()
129
+ else:
130
+ # Default prompt template if prompt_type is not recognized.
131
+ prompt_template = PromptTemplates.get_quiz_solving_prompt()
132
+
133
+ trainer.set_custom_prompt_template(query_request.bot_id, prompt_template)
134
  response, web_sources = trainer.get_response(
135
  query_request.query, query_request.bot_id, query_request.chat_id
136
  )
 
137
  return QueryResponse(response=response, web_sources=web_sources)
138
  except Exception as e:
139
  raise HTTPException(status_code=500, detail=str(e))
140
 
141
 
 
142
  @router.get("/list_chats/{bot_id}")
143
  def list_chats(bot_id: str):
144
  """
 
155
  def chat_history(chat_id: str, bot_id: str):
156
  """
157
  Returns the chat history for a given chat session.
158
+ ObjectId instances in the history are converted to strings.
159
  """
160
  try:
161
  history = trainer.get_chat_by_id(chat_id=chat_id)