File size: 4,565 Bytes
982bc60
 
 
 
 
 
 
 
0240527
982bc60
 
 
0240527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982bc60
0240527
 
 
 
 
982bc60
 
 
0240527
 
982bc60
 
 
 
 
 
 
 
0240527
982bc60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# routes.py
import os
import shutil
import tempfile
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from models import InitializeBotResponse, DocumentPath, NewChatResponse, QueryRequest, QueryResponse
from trainer_manager import get_trainer
from config import CUSTOM_PROMPT
from prompt_templates import PromptTemplates

router = APIRouter()

def select_prompt(prompt_type: str):
    """Return the prompt template that matches ``prompt_type``.

    Each known type maps to a ``PromptTemplates`` getter. Candidates are
    compared with ``==`` in the order listed. Any unrecognised value falls
    back to the quiz-solving prompt.
    """
    known_types = (
        ("university", "get_university_chatbot_prompt"),
        ("quiz_solving", "get_quiz_solving_prompt"),
        ("assignment_solving", "get_assignment_solving_prompt"),
        ("paper_solving", "get_paper_solving_prompt"),
        ("quiz_creation", "get_quiz_creation_prompt"),
        ("assignment_creation", "get_assignment_creation_prompt"),
        ("paper_creation", "get_paper_creation_prompt"),
        ("check_quiz", "get_check_quiz_prompt"),
        ("check_assignment", "get_check_assignment_prompt"),
        ("check_paper", "get_check_paper_prompt"),
    )
    for type_name, getter_name in known_types:
        if prompt_type == type_name:
            return getattr(PromptTemplates, getter_name)()
    # No match: use the quiz-solving prompt as the default template.
    return PromptTemplates.get_quiz_solving_prompt()

@router.post("/initialize_bot", response_model=InitializeBotResponse)
def initialize_bot(prompt_type: str = "quiz_solving"):
    """
    Allocate a new bot id and attach a prompt template to it.

    ``prompt_type`` is read from the query string and resolved through
    ``select_prompt``. Any exception raised while allocating the id,
    choosing the prompt, or registering it is reported as HTTP 500, with
    the exception text as the detail.
    """
    trainer = get_trainer()
    try:
        new_bot_id = trainer.initialize_bot_id()
        template = select_prompt(prompt_type)
        trainer.set_custom_prompt_template(new_bot_id, template)
        return InitializeBotResponse(bot_id=new_bot_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@router.post("/upload_document")
async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)):
    """Persist an uploaded file temporarily and add it to a bot's documents.

    The upload is written to a named temporary file that keeps the original
    extension. The trainer presumably uses the extension to pick a loader
    (TODO confirm). That path is then passed to
    ``trainer.add_document_from_path``. The temporary file is always removed
    afterwards, whether ingestion succeeds or fails.

    Args:
        bot_id: Form field identifying the target bot.
        file: The uploaded document.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 500 with the underlying exception text on any failure.
    """
    trainer = get_trainer()
    tmp_path = None
    try:
        # UploadFile.filename may be None (Starlette types it as Optional);
        # treat that as "no extension" instead of crashing in splitext.
        suffix = os.path.splitext(file.filename or "")[1]
        contents = await file.read()

        # delete=False so the trainer can reopen the file by path after the
        # handle is closed. Cleanup therefore happens in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(contents)

        trainer.add_document_from_path(tmp_path, bot_id)
        return {"message": "Document uploaded and added successfully."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Previously the file was only removed on success, so every failed
        # ingestion leaked a temp file on disk.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup: never mask the request's real outcome.
                pass

@router.post("/create_bot/{bot_id}")
def create_bot(bot_id: str):
    """Ask the trainer to build the bot identified by ``bot_id``.

    On success, returns a confirmation message. Any exception from the
    trainer becomes HTTP 500, with the exception text as the detail.
    """
    trainer = get_trainer()
    try:
        trainer.create_bot(bot_id)
        confirmation = f"Bot {bot_id} created successfully."
        return {"message": confirmation}
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@router.post("/new_chat/{bot_id}", response_model=NewChatResponse)
def new_chat(bot_id: str):
    """Start a new chat session for ``bot_id`` and return its id.

    Any exception raised while creating the chat or building the response
    model is reported as HTTP 500, with the exception text as the detail.
    """
    trainer = get_trainer()
    try:
        created_chat_id = trainer.new_chat(bot_id)
        return NewChatResponse(chat_id=created_chat_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@router.post("/query", response_model=QueryResponse)
def send_query(query_request: QueryRequest):
    """Forward a user query to the trainer and return its answer.

    The trainer is expected to return a ``(response, web_sources)`` pair,
    which is wrapped in ``QueryResponse``. Any exception becomes HTTP 500,
    with the exception text as the detail.
    """
    trainer = get_trainer()
    req = query_request
    try:
        answer, sources = trainer.get_response(req.query, req.bot_id, req.chat_id)
        return QueryResponse(response=answer, web_sources=sources)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@router.get("/list_chats/{bot_id}")
def list_chats(bot_id: str):
    """Return whatever ``trainer.list_chats`` reports for ``bot_id``.

    Any exception from the trainer becomes HTTP 500, with the exception
    text as the detail.
    """
    trainer = get_trainer()
    try:
        return trainer.list_chats(bot_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

@router.get("/chat_history/{chat_id}")
def chat_history(chat_id: str, bot_id: str):
    """Return the stored history for ``chat_id``.

    Any exception from the trainer becomes HTTP 500, with the exception
    text as the detail.
    """
    # NOTE(review): ``bot_id`` is a required query parameter but is never
    # used; the lookup is by chat_id alone, so a chat is not checked against
    # the bot it belongs to. Confirm whether get_chat_by_id should be scoped.
    trainer = get_trainer()
    try:
        return trainer.get_chat_by_id(chat_id=chat_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))