# raq_chatbot / app.py
# Author: vanaraj07 — revision d17cb8f
import gradio as gr
import os
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
import openai
# Load the question/answer dataset.
csv_path = 'train_data.csv'
if not os.path.isfile(csv_path):
    raise FileNotFoundError(f"Could not find CSV at {csv_path}")
# Malformed lines are skipped and any row with a missing field is dropped, so
# every remaining question has an answer.
# NOTE(review): read_csv treats the first line as a header; if the CSV has no
# header row, the first Q/A pair is consumed as column names — confirm format.
df = pd.read_csv(csv_path, on_bad_lines='skip').dropna()
if df.shape[1] != 2:
    # Fail with a readable message instead of pandas' "Length mismatch" from
    # the column assignment below.
    raise ValueError(
        f"Expected 2 columns (question, answer) in {csv_path}, "
        f"found {df.shape[1]}: {list(df.columns)}"
    )
if df.empty:
    raise ValueError(f"No complete question/answer rows in {csv_path}")
# Normalise the header so the rest of the file can rely on these names.
df.columns = ['Question', 'Answer']
# STEP 3: Lexical (TF-IDF) structures over the stored questions.
questions = df['Question'].to_list()
answers = df['Answer'].to_list()
# One "Q: ...\nA: ..." document per row; these strings are what gets embedded,
# reranked and finally shown to the LLM as context.
qa_pairs = ["Q: {}\nA: {}".format(q, a) for q, a in zip(questions, answers)]
# The vocabulary (capped at 5000 terms) is learned from the questions only, so
# the lexical stage matches user queries against stored questions, not answers.
tfidf = TfidfVectorizer(max_features=5000)
tfidf.fit(questions)
tfidf_matrix = tfidf.transform(questions)
# STEP 4: Dense embeddings of every "Q: ...\nA: ..." document.
embedder = SentenceTransformer("all-mpnet-base-v2")
qa_embeddings = embedder.encode(qa_pairs, convert_to_numpy=True)

# FAISS HNSW index over those embeddings (M=32, efConstruction=200).
# efConstruction is assigned before add() so it applies while the graph is built.
# NOTE(review): nothing in this file queries `index` — get_top_k_matches ranks
# candidates with NumPy directly. Either use it or drop it to cut startup cost.
dim = qa_embeddings.shape[1]
index = faiss.IndexHNSWFlat(dim, 32)
index.hnsw.efConstruction = 200
index.add(qa_embeddings)
# STEP 5: Together AI setup (OpenAI-compatible endpoint).
# The API key comes from the environment rather than the source file: a key
# committed to source is readable by anyone with access to the repository.
# NOTE(review): the key that used to be hard-coded here must be treated as
# leaked and revoked/rotated in the Together AI dashboard.
openai.api_key = os.environ.get("TOGETHER_API_KEY")
if not openai.api_key:
    # Fail at startup with a clear message rather than on every chat request.
    raise RuntimeError(
        "TOGETHER_API_KEY environment variable is not set; "
        "set it to your Together AI API key."
    )
openai.api_base = "https://api.together.xyz/v1"
# STEP 6: Hybrid context retriever (TF-IDF -> embeddings -> cross-encoder).
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")


def get_top_k_matches(query, lex_n=50, sem_k=20, ce_k=5):
    """Return up to ``ce_k`` "Q: ...\\nA: ..." strings most relevant to *query*.

    Stages:
      1. Lexical: take the ``lex_n`` stored questions with the highest TF-IDF
         cosine similarity to the query.
      2. Semantic: among those, keep the ``sem_k`` Q+A pairs whose embeddings
         are closest (Euclidean distance) to the query embedding.
      3. Rerank the survivors with the cross-encoder and keep the best ``ce_k``.

    If the query shares no TF-IDF vocabulary with any stored question (every
    lexical score is 0), stage 1 carries no signal, so stage 2 searches the
    whole corpus instead of an arbitrary slice of it.

    Returns:
        List of context strings, best first. It may be shorter than ``ce_k``,
        and is empty when there is no data or no candidates survive.
    """
    if not qa_pairs:
        return []

    # 1. Lexical filter
    q_tfidf = tfidf.transform([query])
    lex_scores = cosine_similarity(q_tfidf, tfidf_matrix).flatten()
    if lex_scores.max() > 0:
        # Highest scores first; [:lex_n] (unlike [-lex_n:]) is correct for 0.
        pool_idxs = np.argsort(lex_scores)[::-1][:lex_n]
    else:
        # No word overlap at all: sorting all-zero scores would pick an
        # arbitrary subset and hide semantically relevant rows from stage 2.
        pool_idxs = np.arange(len(qa_pairs))

    # 2. Semantic narrowing inside the pool
    q_emb = embedder.encode([query], convert_to_numpy=True)
    dists = np.linalg.norm(qa_embeddings[pool_idxs] - q_emb, axis=1)
    cand_idxs = pool_idxs[np.argsort(dists)[:sem_k]]

    # 3. Cross-encoder rerank for precision
    candidates = [qa_pairs[i] for i in cand_idxs]
    if not candidates:
        return []
    ce_scores = cross_encoder.predict([[query, cand] for cand in candidates])
    ranked = sorted(zip(ce_scores, candidates), key=lambda pair: pair[0], reverse=True)
    return [ctx for _, ctx in ranked[:ce_k]]
# STEP 7: Prompt builder.
def generate_prompt(user_query, context):
    """Build the single-turn prompt sent to the LLM.

    Args:
        user_query: the student's question, inserted verbatim.
        context: retrieved context, either a ready-made string or a list/tuple
            of context strings (as returned by get_top_k_matches). Sequences
            are joined with blank lines so the model sees plain text instead of
            a Python list repr full of brackets, quotes and escaped newlines.

    Returns:
        The complete prompt string.
    """
    if isinstance(context, (list, tuple)):
        context = "\n\n".join(str(item) for item in context)
    return f"""
You are a smart and friendly assistant helping students with academic-related queries.
Below is a question from a student. You have been given multiple pieces of relevant academic context pulled from the official college documentation. Carefully analyze all the given Q&A context and generate the most accurate, clear, and helpful answer for the student.
### Student's Question:
{user_query}
### Top Contexts:
{context}
### Instructions:
- Use all relevant context to form your answer.
- Avoid repeating the same sentences. Summarize smartly.
- Keep your answer polite and student-friendly.
- If not found, reply: "I'm sorry, I couldn't find this information in the provided academic context."
- If user asked questions like Hi , hello, how are you, good night , good morning answer to those questions only
- Use casual, conversational language with a touch of humor. Feel free to use emojis and friendly expressions to make interactions enjoyable
### Your Final Answer:
"""
# STEP 8: Retrieve context, build the prompt and query the LLM.
def ask_bot(question):
    """Answer *question* with retrieval-augmented generation.

    Context comes from get_top_k_matches and is wrapped by generate_prompt.
    The prompt is sent as one user message to the Llama 3.3 70B model using the
    module-level openai configuration (temperature 0.5, at most 1024 tokens).

    NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 openai-python API;
    with openai>=1.0 this call raises — confirm the pinned library version.

    Returns:
        The text of the first choice's message.
    """
    prompt = generate_prompt(question, get_top_k_matches(question))
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        messages=chat_messages,
        temperature=0.5,
        max_tokens=1024,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
# Chat handler: answer one message and record it in the history.
def qa_pipeline(query, history=None):
    """Answer *query* and append the (query, answer) pair to *history*.

    Args:
        query: the user's message. Blank messages are ignored.
        history: list of (user, bot) tuples, appended to in place and returned.
            ``None`` (the default) starts a fresh list. The previous ``[]``
            default was one list shared by every call that omitted it.

    Returns:
        ``("", history)``: the empty string clears the input textbox.
        Failures from ask_bot are not raised; they are shown in the chat as
        "⚠️ Error: ..." entries.
    """
    if history is None:
        history = []
    if not query or not query.strip():
        # Nothing to ask; don't spend an LLM call on a blank message.
        return "", history
    try:
        response = ask_bot(query)
    except Exception as e:  # UI boundary: surface any failure in the chat.
        response = f"⚠️ Error: {str(e)}"
    history.append((query, response))
    return "", history
# Gradio user interface.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 KCT Smart Chatbot")
    gr.Markdown("Ask academic or college-related questions. Powered by your custom dataset.")
    chatbot = gr.Chatbot(label="KCT Chatbot", height=400)
    msg = gr.Textbox(label="Enter your question here")
    clear = gr.Button("🧹 Clear Chat")

    def user_submit(user_input, chat_history):
        """Submit handler: delegate to qa_pipeline (returns new textbox value and history)."""
        return qa_pipeline(user_input, chat_history)

    # Pressing Enter answers the question, clears the textbox and updates the chat.
    msg.submit(user_submit, inputs=[msg, chatbot], outputs=[msg, chatbot])
    # Reset the chat display immediately, bypassing the queue.
    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.launch(share=True)