# routes/chat.py
import uuid

from fastapi import APIRouter, HTTPException, Path
from fastapi.responses import StreamingResponse
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

import config
from models import ChatIDOut, MessageIn

router = APIRouter(prefix="/chat", tags=["chat"])

# ─── LLM & Prompt Setup ──────────────────────────────────────────────────────

def get_llm() -> ChatGroq:
    if not config.CHATGROQ_API_KEY:
        raise RuntimeError("CHATGROQ_API_KEY not set in environment")
    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=config.CHATGROQ_API_KEY,
    )

llm = get_llm()

SYSTEM_PROMPT = """
You are an assistant specialized in solving quizzes. Your goal is to provide accurate,
concise, and contextually relevant answers.
"""
qa_template = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("user", "{question}"),
    ]
)
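
# Note: qa_template above is wired for single-turn Q&A; the /message endpoint
# below builds its message list by hand instead, so it can splice the stored
# MongoDB history in between the system prompt and the new question.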

# ─── MongoDB History Setup ───────────────────────────────────────────────────

# In-memory registry of live sessions. Message contents persist in MongoDB,
# but this map is lost on process restart.
chat_sessions: dict[str, MongoDBChatMessageHistory] = {}

def create_history(session_id: str) -> MongoDBChatMessageHistory:
    """Create a MongoDB-backed history for a new session and register it."""
    history = MongoDBChatMessageHistory(
        session_id=session_id,
        connection_string=config.CONNECTION_STRING,
        database_name="Education_chatbot",
        collection_name="chat_histories",
    )
    chat_sessions[session_id] = history
    return history

def get_history(session_id: str) -> MongoDBChatMessageHistory:
    """Look up a session's history, or 404 if it was never created."""
    history = chat_sessions.get(session_id)
    if not history:
        raise HTTPException(status_code=404, detail="Chat session not found")
    return history

# ─── Summarization (to control token use) ────────────────────────────────────

def summarize_if_needed(history: MongoDBChatMessageHistory, threshold: int = 10):
    """Collapse the history into one summary message once it exceeds `threshold` messages."""
    msgs = history.messages
    if len(msgs) <= threshold:
        return
    summarization_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a summarization assistant."),
            (
                "user",
                "Here is the chat history:\n\n{chat_history}\n\n"
                "Summarize the above chat messages into a single concise message with key details.",
            ),
        ]
    )
    # Render the stored messages as plain "User:/Assistant:" lines for the prompt
    text_history = "\n".join(
        f"{'User' if m.type == 'human' else 'Assistant'}: {m.content}"
        for m in msgs
    )
    summary_chain = summarization_prompt | llm
    summary = summary_chain.invoke({"chat_history": text_history})
    history.clear()
    history.add_ai_message(f"[Summary] {summary.content}")
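
# The cleared history is replaced by a single AI message, so subsequent turns
# see the summary as prior assistant context rather than the raw transcript.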

# ─── Endpoints ───────────────────────────────────────────────────────────────

@router.post("", response_model=ChatIDOut)
async def create_chat():
    """
    Create a new chat session and return its ID.
    """
    session_id = str(uuid.uuid4())
    create_history(session_id)
    return ChatIDOut(chat_id=session_id)

@router.post("/{chat_id}/message")
async def post_message(
    chat_id: str = Path(..., description="The chat session ID"),
    payload: MessageIn | None = None,
):
    """
    Send a question and stream back the assistant's answer.
    """
    history = get_history(chat_id)
    question = payload.question.strip() if payload else ""
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    # Summarize old turns if the history has grown too long
    summarize_if_needed(history)
    # Build the full conversation for the LLM
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in history.messages:
        role = "user" if msg.type == "human" else "assistant"
        messages.append({"role": role, "content": msg.content})
    messages.append({"role": "user", "content": question})
    # Persist the user turn before streaming the reply
    history.add_user_message(question)

    async def stream_generator():
        full_response = ""
        # Pass the messages list as the positional `input` to .astream();
        # async iteration avoids blocking the event loop between tokens
        async for chunk in llm.astream(messages):
            # 1) Try AIMessageChunk.content
            content = getattr(chunk, "content", None)
            # 2) Fall back to a dict-based chunk
            if content is None and isinstance(chunk, dict):
                content = (
                    chunk.get("content")
                    or chunk.get("choices", [{}])[0]
                    .get("delta", {})
                    .get("content")
                )
            if not content:
                continue
            # Yield the delta and accumulate the full answer
            yield content
            full_response += content
        # Save the final AI message once the stream completes
        history.add_ai_message(full_response)

    return StreamingResponse(stream_generator(), media_type="text/plain")
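
# Example usage (a sketch): assumes this router is mounted on a FastAPI app
# served at http://localhost:8000 (adjust host/port for your deployment) and
# that MessageIn exposes a `question` field, as used above:
#
#   # 1) Create a session and note the returned chat_id
#   curl -X POST http://localhost:8000/chat
#   # -> {"chat_id": "2f1c..."}
#
#   # 2) Ask a question and stream the plain-text answer (-N disables buffering)
#   curl -N -X POST http://localhost:8000/chat/<chat_id>/message \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the capital of France?"}'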