# Chatbot / app.py
import json
import os

import chromadb
import gradio as gr
import numpy as np
import pandas as pd
import requests
from langchain.text_splitter import RecursiveCharacterTextSplitter
from openai import OpenAI
from sentence_transformers import SentenceTransformer, util
# HF_TOKEN = os.getenv("HF_TOKEN")  # optional: token for the Hugging Face Inference API backend
API_KEY = os.environ.get("OPENROUTER_API_KEY")
# Load the Excel file
df = pd.read_excel("web_documents.xlsx", engine='openpyxl')
# Initialize Chroma Persistent Client
client = chromadb.PersistentClient(path="./db")
# Create (or get) the Chroma collection
collection = client.get_or_create_collection(
    name="rag_web_db_cosine_full_documents",
    metadata={"hnsw:space": "cosine"}
)
# Load the embedding model
embedding_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
# Initialize the text splitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=150)
total_chunks = 0
# Process each row in the DataFrame: chunk the content, embed it, and store it in Chroma
for idx, row in df.iterrows():
    content = str(row['Content'])        # ensure the content is a string
    metadata_str = str(row['Metadata'])
    # Store the raw metadata string under a single key (Chroma metadata values must be scalars)
    metadata = {"metadata": metadata_str}

    # Split content into overlapping chunks
    chunks = text_splitter.split_text(content)
    total_chunks += len(chunks)

    # Generate embeddings for each chunk
    chunk_embeddings = embedding_model.encode(chunks)

    # Add each chunk to the Chroma collection with a unique id
    for i, chunk in enumerate(chunks):
        collection.add(
            documents=[chunk],
            metadatas=[metadata],
            ids=[f"{idx}_chunk_{i}"],
            embeddings=[chunk_embeddings[i]]
        )
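# Illustrative sanity check (an assumption, not part of the original pipeline): after ingestion
# you can confirm chunks were stored by running a small query against the collection.
# Kept behind a flag so it does not run in the deployed Space.
RUN_INGESTION_CHECK = False
if RUN_INGESTION_CHECK:
    print(f"Total chunks indexed: {total_chunks} | Collection size: {collection.count()}")
    sample = collection.query(
        query_embeddings=[embedding_model.encode("How do I create a quiz in Moodle?")],
        n_results=2
    )
    print("Sample retrieval:", sample["documents"])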
# ---------------------- Config ----------------------
SIMILARITY_THRESHOLD = 0.80
client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY)  # replace with your own OpenRouter key
# ---------------------- Models ----------------------
# High-accuracy model for semantic search
semantic_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# For ChromaDB
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L6-v2")
# ---------------------- Load QA Data ----------------------
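# Expected qa.json shape (illustrative assumption, inferred from how it is used below):
# a flat mapping of question -> answer, e.g.
# {
#   "How do I create a quiz in Moodle?": "Go to your course, turn editing on, ...",
#   "How do I reset a student's password?": "Open Site administration > Users, ..."
# }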
with open("qa.json", "r", encoding="utf-8") as f:
    qa_data = json.load(f)
qa_questions = list(qa_data.keys())
qa_answers = list(qa_data.values())
qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
# ---------------------- CAG ----------------------
def retrieve_from_cag(user_query):
    """Cache-Augmented Generation: return a stored answer if the query closely matches a known question."""
    query_embedding = semantic_model.encode(user_query, convert_to_tensor=True)
    cosine_scores = util.cos_sim(query_embedding, qa_embeddings)[0]
    best_idx = int(np.argmax(cosine_scores))
    best_score = float(cosine_scores[best_idx])

    print(f"[CAG] Best score: {best_score:.4f} | Closest question: {qa_questions[best_idx]}")

    if best_score >= SIMILARITY_THRESHOLD:
        return qa_answers[best_idx], best_score
    else:
        return None, best_score
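# Illustrative usage sketch (assumption: qa.json contains a question similar to the one below).
# Defined but never called, so it has no effect on the running app.
def _demo_cag_lookup():
    answer, score = retrieve_from_cag("How do I reset a student's password in Moodle?")
    if answer is not None:
        print(f"Cache hit (score {score:.2f}): {answer}")
    else:
        print(f"Cache miss (score {score:.2f}), falling back to RAG retrieval.")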
# ---------------------- RAG ----------------------
#client = chromadb.Client()
#collection = client.get_collection(name="rag_web_db_cosine_full_documents")
# Assuming you have a persistent Chroma client setup
#client = PersistentClient("./db_new/db_new")# Replace with the correct path if needed
#collection = client.get_collection(name="rag_web_db_cosine_full_documents")
# ---------------------- RAG retrieval ----------------------
def retrieve_from_rag(user_query):
    print("Searching in RAG...")
    query_embedding = embedding_model.encode(user_query)
    results = collection.query(query_embeddings=[query_embedding], n_results=3)

    if not results or not results.get('documents'):
        return None

    documents = []
    for i, content in enumerate(results['documents'][0]):
        metadata = results['metadatas'][0][i]
        documents.append({
            "content": content.strip(),
            "metadata": metadata
        })

    print("Documents retrieved:", documents)
    return documents
# ---------------------- Generation function (OpenRouter) ----------------------
def generate_via_openrouter(context, query):
    print("\n--- Generating via OpenRouter ---")
    print("Context received:", context)

    prompt = f"""<s>[INST]
You are a Moodle expert assistant.
Instructions:
- Always respond in the same language as the question.
- Use only the provided documents below to answer.
- If the answer is not in the documents, simply say: "I don't know." / "Je ne sais pas."
- Cite only the sources you use, indicated at the end of each document like (Source: https://example.com).

Documents :
{context}

Question : {query}
Answer :
[/INST]
"""

    try:
        response = client1.chat.completions.create(
            model="mistralai/mistral-small-3.1-24b-instruct:free",
            messages=[{"role": "user", "content": prompt}]
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Generation error: {e}")
        return "An error occurred during generation."
# ---------------------- Generation function (Huggingface) ----------------------
def generate_via_huggingface(context, query, max_new_tokens=512, hf_token="your_huggingface_token"):
    print("\n--- Generating via Huggingface ---")
    print("Context received:", context)

    prompt = f"""<s>[INST]
You are a Moodle expert assistant.
Rules:
- Answer only based on the provided documents.
- If the answer is not found, reply: "I don't know."
- Only cite sources mentioned (metadata 'source').

Documents:
{context}

Question: {query}
Answer:
[/INST]
"""

    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens
        }
    }

    response = requests.post(API_URL, headers=headers, json=payload)

    if response.status_code == 200:
        result = response.json()
        if isinstance(result, list) and "generated_text" in result[0]:
            return result[0]["generated_text"].strip()
        else:
            return "Error: Unexpected response format."
    else:
        return f"Error {response.status_code}: {response.text}"
# ---------------------- Main Chatbot ----------------------
def chatbot(query):
    print("\n==== New Query ====")
    print("User Query:", query)

    # Try to retrieve from CAG (cache)
    answer, score = retrieve_from_cag(query)
    if answer:
        print("Answer retrieved from CAG cache.")
        return answer

    # If not found, retrieve from RAG
    docs = retrieve_from_rag(query)
    if docs:
        context_blocks = []
        for doc in docs:
            content = doc.get("content", "").strip()
            metadata = doc.get("metadata") or {}

            source = "Source inconnue"
            if isinstance(metadata, dict):
                source_field = metadata.get("metadata", "")
                if isinstance(source_field, str) and source_field.startswith("source:"):
                    source = source_field.replace("source:", "").strip()

            context_blocks.append(f"{content}\n(Source: {source})")

        context = "\n\n".join(context_blocks)

        # Choose the generation backend (OpenRouter or Huggingface)
        response = generate_via_openrouter(context, query)
        return response
    else:
        print("No relevant documents found.")
        return "Je ne sais pas."
# ---------------------- Gradio App ----------------------
# Define the chatbot response function
def ask(user_message, chat_history):
    if not user_message:
        return chat_history, chat_history, ""

    # Get the chatbot response and update the chat history
    response = chatbot(user_message)
    chat_history.append((user_message, response))
    return chat_history, chat_history, ""

# Initialize chat history with a welcome message
initial_message = (None, "Hello, how can I help you with Moodle?")

# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chat_history = gr.State([initial_message])
    chatbot_ui = gr.Chatbot(value=[initial_message])
    question = gr.Textbox(placeholder="Ask me anything about Moodle...", show_label=False)
    clear_button = gr.Button("Clear")

    question.submit(ask, [question, chat_history], [chatbot_ui, chat_history, question])
    clear_button.click(lambda: ([initial_message], [initial_message], ""), None, [chatbot_ui, chat_history, question], queue=False)

demo.queue()
demo.launch(share=False)