mominah commited on
Commit
d280d87
Β·
verified Β·
1 Parent(s): b8a0141

Create chat.py

Browse files
Files changed (1) hide show
  1. chat.py +133 -0
chat.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # routes/chat.py
2
+ import uuid
3
+ from fastapi import APIRouter, HTTPException, Path
4
+ from fastapi.responses import StreamingResponse
5
+ from langchain_groq import ChatGroq
6
+ from langchain.prompts import ChatPromptTemplate
7
+ from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
8
+
9
+ import config
10
+ from models import ChatIDOut, MessageIn
11
+
12
+ router = APIRouter(prefix="/chat", tags=["chat"])
13
+
14
+ # ─── LLM & Prompt Setup ─────────────────────────────────────────────────
15
+ def get_llm() -> ChatGroq:
16
+ if not config.CHATGROQ_API_KEY:
17
+ raise RuntimeError("CHATGROQ_API_KEY not set in environment")
18
+ return ChatGroq(
19
+ model="llama-3.3-70b-versatile",
20
+ temperature=0,
21
+ max_tokens=1024,
22
+ api_key=config.CHATGROQ_API_KEY
23
+ )
24
+
25
+ llm = get_llm()
26
+
27
+ SYSTEM_PROMPT = """
28
+ You are an assistant specialized in solving quizzes. Your goal is to provide accurate,
29
+ concise, and contextually relevant answers.
30
+ """
31
+ qa_template = ChatPromptTemplate.from_messages(
32
+ [
33
+ ("system", SYSTEM_PROMPT),
34
+ ("user", "{question}"),
35
+ ]
36
+ )
37
+
38
+ # ─── MongoDB History Setup ────────────────────────────────────────────────
39
+ chat_sessions: dict[str, MongoDBChatMessageHistory] = {}
40
+
41
+ def create_history(session_id: str) -> MongoDBChatMessageHistory:
42
+ history = MongoDBChatMessageHistory(
43
+ session_id=session_id,
44
+ connection_string=config.CONNECTION_STRING,
45
+ database_name="Education_chatbot",
46
+ collection_name="chat_histories",
47
+ )
48
+ chat_sessions[session_id] = history
49
+ return history
50
+
51
+ def get_history(session_id: str) -> MongoDBChatMessageHistory:
52
+ history = chat_sessions.get(session_id)
53
+ if not history:
54
+ raise HTTPException(status_code=404, detail="Chat session not found")
55
+ return history
56
+
57
+ # ─── Summarization (to control token use) ─────────────────────────────────
58
+ def summarize_if_needed(history: MongoDBChatMessageHistory, threshold: int = 10):
59
+ msgs = history.messages
60
+ if len(msgs) <= threshold:
61
+ return
62
+
63
+ summarization_prompt = ChatPromptTemplate.from_messages(
64
+ [
65
+ ("system", "You are a summarization assistant."),
66
+ ("user",
67
+ "Here is the chat history:\n\n{chat_history}\n\n"
68
+ "Summarize the above chat messages into a single concise message with key details."
69
+ ),
70
+ ]
71
+ )
72
+ text_history = "\n".join(
73
+ f"{'User' if m.type=='human' else 'Assistant'}: {m.content}"
74
+ for m in msgs
75
+ )
76
+ summary_chain = summarization_prompt | llm
77
+ summary = summary_chain.invoke({"chat_history": text_history})
78
+
79
+ history.clear()
80
+ history.add_ai_message(f"[Summary] {summary.content}")
81
+
82
+ # ─── Endpoints ────────────────────────────────────────────────────────────
83
+
84
+ @router.post("", response_model=ChatIDOut)
85
+ async def create_chat():
86
+ """
87
+ Create a new chat session and return its ID.
88
+ """
89
+ session_id = str(uuid.uuid4())
90
+ create_history(session_id)
91
+ return ChatIDOut(chat_id=session_id)
92
+
93
+ @router.post("/{chat_id}/message")
94
+ async def post_message(
95
+ chat_id: str = Path(..., description="The chat session ID"),
96
+ payload: MessageIn = None
97
+ ):
98
+ """
99
+ Send a question and stream back the assistant's answer.
100
+ """
101
+ history = get_history(chat_id)
102
+ question = (payload and payload.question.strip()) or ""
103
+ if not question:
104
+ raise HTTPException(status_code=400, detail="Question cannot be empty")
105
+
106
+ # Summarize old turns if too long
107
+ summarize_if_needed(history)
108
+
109
+ # Build conversation for the LLM
110
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
111
+ for msg in history.messages:
112
+ role = "user" if msg.type == "human" else "assistant"
113
+ messages.append({"role": role, "content": msg.content})
114
+ messages.append({"role": "user", "content": question})
115
+
116
+ # Persist user turn
117
+ history.add_user_message(question)
118
+
119
+ async def stream_generator():
120
+ full_response = ""
121
+ for chunk in llm.stream(messages=messages):
122
+ # adjust based on actual ChatGroq chunk schema
123
+ content = (
124
+ chunk.get("content")
125
+ or chunk.get("choices", [{}])[0].get("delta", {}).get("content")
126
+ )
127
+ if content:
128
+ yield content
129
+ full_response += content
130
+ # save final AI message
131
+ history.add_ai_message(full_response)
132
+
133
+ return StreamingResponse(stream_generator(), media_type="text/plain")