# routes/chat.py
import uuid
from fastapi import APIRouter, Body, HTTPException, Path
from fastapi.responses import StreamingResponse
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

import config
from models import ChatIDOut, MessageIn
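
# models.py isn't shown in this file; a minimal sketch of what it likely
# contains, inferred from how ChatIDOut and MessageIn are used below (the
# field names chat_id and question appear in this module; the rest is an
# assumption):
#
#   from pydantic import BaseModel
#
#   class ChatIDOut(BaseModel):
#       chat_id: str
#
#   class MessageIn(BaseModel):
#       question: str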

router = APIRouter(prefix="/chat", tags=["chat"])

# ─── LLM & Prompt Setup ──────────────────────────────────────────────────────
def get_llm() -> ChatGroq:
    if not config.CHATGROQ_API_KEY:
        raise RuntimeError("CHATGROQ_API_KEY not set in environment")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=config.CHATGROQ_API_KEY
    )

llm = get_llm()

SYSTEM_PROMPT = """
You are an assistant specialized in solving quizzes. Your goal is to provide accurate,
concise, and contextually relevant answers.
"""
qa_template = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("user", "{question}"),
    ]
)
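
# qa_template isn't referenced by the endpoints below (they assemble the
# message list by hand), but it can answer a one-off, history-free question
# through LCEL; a usage sketch:
#
#   answer = (qa_template | llm).invoke({"question": "What is 2 + 2?"})
#   print(answer.content)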

# ─── MongoDB History Setup ───────────────────────────────────────────────────
chat_sessions: dict[str, MongoDBChatMessageHistory] = {}

def create_history(session_id: str) -> MongoDBChatMessageHistory:
    history = MongoDBChatMessageHistory(
        session_id=session_id,
        connection_string=config.CONNECTION_STRING,
        database_name="Education_chatbot",
        collection_name="chat_histories",
    )
    chat_sessions[session_id] = history
    return history

def get_history(session_id: str) -> MongoDBChatMessageHistory:
    history = chat_sessions.get(session_id)
    if not history:
        raise HTTPException(status_code=404, detail="Chat session not found")
    return history
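
# Note: chat_sessions is process-local, so existing sessions become
# unreachable after a restart even though their messages persist in MongoDB.
# One possible mitigation (a sketch, not wired in here) is to rebuild the
# handle on demand, since MongoDBChatMessageHistory is keyed by session_id:
#
#   def get_or_rebuild_history(session_id: str) -> MongoDBChatMessageHistory:
#       # caveat: this accepts any well-formed ID, even ones never created
#       return chat_sessions.get(session_id) or create_history(session_id)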

# ─── Summarization (to control token use) ────────────────────────────────────
def summarize_if_needed(history: MongoDBChatMessageHistory, threshold: int = 10) -> None:
    msgs = history.messages
    if len(msgs) <= threshold:
        return

    summarization_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a summarization assistant."),
            ("user",
                "Here is the chat history:\n\n{chat_history}\n\n"
                "Summarize the above chat messages into a single concise message with key details."
            ),
        ]
    )
    text_history = "\n".join(
        f"{'User' if m.type == 'human' else 'Assistant'}: {m.content}"
        for m in msgs
    )
    summary_chain = summarization_prompt | llm
    summary = summary_chain.invoke({"chat_history": text_history})

    history.clear()
    history.add_ai_message(f"[Summary] {summary.content}")

# ─── Endpoints ──────────────────────────────────────────────────────────────

@router.post("", response_model=ChatIDOut)
async def create_chat():
    """
    Create a new chat session and return its ID.
    """
    session_id = str(uuid.uuid4())
    create_history(session_id)
    return ChatIDOut(chat_id=session_id)

@router.post("/{chat_id}/message")
async def post_message(
    chat_id: str = Path(..., description="The chat session ID"),
    payload: MessageIn = Body(...),
):
    """
    Send a question and stream back the assistant's answer.
    """
    history = get_history(chat_id)
    question = payload.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    # Summarize old turns if too long
    summarize_if_needed(history)

    # Build conversation for the LLM
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in history.messages:
        role = "user" if msg.type == "human" else "assistant"
        messages.append({"role": role, "content": msg.content})
    messages.append({"role": "user", "content": question})

    # Persist user turn
    history.add_user_message(question)

    async def stream_generator():
        full_response = ""
        # astream() keeps the event loop free while tokens arrive; iterating
        # the synchronous llm.stream() here would block it
        async for chunk in llm.astream(messages):
            # 1) Try AIMessageChunk.content
            content = getattr(chunk, "content", None)
            # 2) Fallback to dict-based chunk
            if content is None and isinstance(chunk, dict):
                content = (
                    chunk.get("content")
                    or chunk.get("choices", [{}])[0]
                             .get("delta", {})
                             .get("content")
                )
            if not content:
                continue
            # Yield and accumulate
            yield content
            full_response += content

        # Save final AI message
        history.add_ai_message(full_response)

    return StreamingResponse(stream_generator(), media_type="text/plain")
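
# ─── Example usage (sketch) ──────────────────────────────────────────────────
# Assuming the FastAPI app includes this router (e.g. app.include_router(router))
# and serves on localhost:8000 — host and port are assumptions; adjust to your
# deployment:
#
#   curl -X POST http://localhost:8000/chat
#   # -> {"chat_id": "<uuid>"}
#
#   curl -N -X POST http://localhost:8000/chat/<chat_id>/message \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Which planet is known as the Red Planet?"}'
#   # -> streams the assistant's answer as plain text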