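# Assumed dependencies (the original pins no requirements); the legacy
# openai.ChatCompletion interface used below needs openai<1.0:
#   pip install gradio pandas scikit-learn numpy faiss-cpu sentence-transformers "openai<1.0"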
import gradio as gr
import os
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
import openai
# STEP 2: Load the Q&A dataset (expects exactly two columns: question, answer)
csv_path = 'train_data.csv'
if not os.path.isfile(csv_path):
    raise FileNotFoundError(f"Could not find CSV at {csv_path}")
df = pd.read_csv(csv_path, on_bad_lines='skip').dropna()
df.columns = ['Question', 'Answer']
# STEP 3: Build TF-IDF structures
questions = df['Question'].tolist()
answers = df['Answer'].tolist()
qa_pairs = [f"Q: {q}\nA: {a}" for q, a in zip(questions, answers)]
tfidf = TfidfVectorizer(max_features=5000).fit(questions)
tfidf_matrix = tfidf.transform(questions)
# STEP 4: Embed the full Q+A pairs and index them with FAISS (HNSW graph, M=32)
embedder = SentenceTransformer("all-mpnet-base-v2")
qa_embeddings = embedder.encode(qa_pairs, convert_to_numpy=True)
dim = qa_embeddings.shape[1]
index = faiss.IndexHNSWFlat(dim, 32)
index.hnsw.efConstruction = 200
index.add(qa_embeddings)
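# Note: the HNSW index above is built but never queried — the hybrid retriever
# below reranks TF-IDF candidates by brute-force L2 distance instead. A minimal
# sketch of querying the index directly (hypothetical query string):
#   q = embedder.encode(["example question"], convert_to_numpy=True)
#   dists, idxs = index.search(q, 5)  # distances and row indices of top-5 pairs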
# STEP 5: Together AI setup (OpenAI-compatible endpoint, legacy openai<1.0 client)
# Never hard-code API keys; read the key from the environment
# (export TOGETHER_API_KEY=... before launching)
openai.api_key = os.environ.get("TOGETHER_API_KEY")
openai.api_base = "https://api.together.xyz/v1"
# STEP 6: Smarter hybrid context retriever
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

def get_top_k_matches(query, lex_n=50, sem_k=20, ce_k=5):
    # Stage 1: lexical filter — keep the lex_n questions with the highest
    # TF-IDF cosine similarity to the query
    q_tfidf = tfidf.transform([query])
    lex_scores = cosine_similarity(q_tfidf, tfidf_matrix).flatten()
    lex_idxs = np.argsort(lex_scores)[-lex_n:][::-1]
    # Stage 2: semantic filter — embed the query and keep the sem_k lexical
    # candidates closest in embedding space (L2 distance)
    q_emb = embedder.encode([query], convert_to_numpy=True)
    sub_embs = qa_embeddings[lex_idxs]
    dists = np.linalg.norm(sub_embs - q_emb, axis=1)
    top_sem_idxs = np.argsort(dists)[:sem_k]
    cand_idxs = [lex_idxs[i] for i in top_sem_idxs]
    # Stage 3: cross-encoder rerank for precision; return the ce_k best contexts
    candidates = [qa_pairs[i] for i in cand_idxs]
    pairs = [[query, cand] for cand in candidates]
    ce_scores = cross_encoder.predict(pairs)
    scored = sorted(zip(ce_scores, candidates), reverse=True)
    top_contexts = [ctx for _, ctx in scored[:ce_k]]
    return top_contexts
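# Usage sketch (hypothetical query; results depend on your dataset):
#   get_top_k_matches("When do end-semester exams start?")
#   -> up to 5 "Q: ...\nA: ..." context strings, best rerank score first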
# STEP 7: Prompt generator
def generate_prompt(user_query, context):
    return f"""
You are a smart and friendly assistant helping students with academic queries.
Below is a question from a student, along with several pieces of relevant academic context pulled from the official college documentation. Carefully analyze all of the given Q&A context and generate the most accurate, clear, and helpful answer for the student.
### Student's Question:
{user_query}
### Top Contexts:
{context}
### Instructions:
- Use all relevant context to form your answer.
- Avoid repeating the same sentences; summarize smartly.
- Keep your answer polite and student-friendly.
- If the answer is not in the context, reply: "I'm sorry, I couldn't find this information in the provided academic context."
- For greetings such as "Hi", "Hello", "How are you?", "Good morning", or "Good night", respond to the greeting only.
- Use casual, conversational language with a touch of humor. Feel free to use emojis and friendly expressions to make interactions enjoyable.
### Your Final Answer:
"""
# STEP 8: Ask a question and get a response
def ask_bot(question):
    # Join the retrieved contexts into one block so the prompt doesn't
    # interpolate a raw Python list
    context = "\n\n".join(get_top_k_matches(question))
    prompt = generate_prompt(question, context)
    response = openai.ChatCompletion.create(  # legacy openai<1.0 interface
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.5,
        max_tokens=1024,
    )
    return response.choices[0].message.content
# Query function wired to the chat UI
def qa_pipeline(query, history=None):
    history = history or []  # avoid a mutable default argument
    try:
        response = ask_bot(query)
        history.append((query, response))
    except Exception as e:
        history.append((query, f"⚠️ Error: {e}"))
    return "", history
# Launch UI with Blocks
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 KCT Smart Chatbot")
    gr.Markdown("Ask academic or college-related questions. Powered by your custom dataset.")
    chatbot = gr.Chatbot(label="KCT Chatbot", height=400)
    msg = gr.Textbox(label="Enter your question here")
    clear = gr.Button("🧹 Clear Chat")

    # On send: route the textbox contents through the QA pipeline
    def user_submit(user_input, chat_history):
        return qa_pipeline(user_input, chat_history)

    msg.submit(user_submit, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(share=True)