Spaces:
Runtime error
Runtime error
import os | |
import logging | |
import numpy as np | |
from typing import List, Optional, Tuple | |
import torch | |
import gradio as gr | |
import spaces | |
from sentence_transformers import SentenceTransformer | |
from langchain_community.vectorstores import FAISS | |
from langchain.embeddings.base import Embeddings | |
from gradio_client import Client | |
import requests | |
from tqdm import tqdm | |
# Configuration
DATABASE_DIR = "semantic_memory"  # root folder; each database is one sub-directory
QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # Hugging Face Space id handed to gradio_client.Client
CHUNK_SIZE = 800  # maximum number of characters per text chunk
TOP_K_RESULTS = 150  # number of retrieved chunks placed in the prompt context
SIMILARITY_THRESHOLD = 0.4  # a retrieved chunk is kept only if its search score is below this
# Admin password required to create databases. Set it through the ADMIN_PASSWORD
# environment variable (e.g. a Space secret); the literal is only a development fallback.
# NOTE: despite its name, this value is compared as plaintext, not as a hash.
PASSWORD_HASH = os.environ.get("ADMIN_PASSWORD", "abc12345")
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :
1. Utilisez EXCLUSIVEMENT le contexte fourni
2. Structurez la réponse en :
- Définition principale
- Caractéristiques clés (3 points maximum)
- Relations avec d'autres concepts
3. Si aucune information pertinente, indiquez-le clairement
Contexte :
{context}
"""
# Configure logging: INFO and above are written both to "mtc_chat.log" and to the console.
_log_handlers = [
    logging.FileHandler("mtc_chat.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class LocalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a local sentence-transformers model."""

    def __init__(self, model):
        # ``model`` must expose ``encode`` (e.g. a SentenceTransformer instance).
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts, returning one vector (list of floats) per text, in input order.

        All texts are passed to ``encode`` in a single call so the model can batch them,
        instead of running one forward pass per text.
        """
        if not texts:
            return []
        vectors = self.model.encode(list(texts), show_progress_bar=True)
        # Works whether ``encode`` returns a 2-D array or a sequence of 1-D vectors.
        return [vector.tolist() for vector in vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its vector as a list of floats."""
        return self.model.encode(text).tolist()
def split_text_into_chunks(text: str, chunk_size: Optional[int] = None) -> List[str]:
    """Split ``text`` into consecutive, non-overlapping chunks, preferring sentence ends.

    Each window of at most ``chunk_size`` characters is shortened to end at its last
    ``.``, ``!``, ``?`` or blank-line paragraph break, but only when that boundary lies
    past the middle of the window; otherwise the window is cut at exactly ``chunk_size``
    characters. The final window (which already reaches the end of the text) is never
    shortened. Chunks are whitespace-stripped and empty ones are dropped.

    Args:
        text: The text to split.
        chunk_size: Maximum characters per chunk; defaults to ``CHUNK_SIZE``.

    Returns:
        The chunks in document order (an empty list for blank input).

    Raises:
        ValueError: If ``chunk_size`` is not positive (it would never advance).
    """
    size = CHUNK_SIZE if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    chunks: List[str] = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + size, text_length)
        if end < text_length:
            window = text[start:end]
            # Offset just past the last sentence terminator / paragraph break (0 if none).
            cut = max(
                window.rfind('.'),
                window.rfind('!'),
                window.rfind('?'),
                window.rfind('\n\n'),
            ) + 1
            # Honour the boundary only if it keeps more than half a chunk; an early
            # boundary would otherwise emit a tiny fragment.
            if cut > size // 2:
                end = start + cut
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        start = end
    return chunks
def initialize_vector_store(embeddings: Embeddings, db_name: str) -> Optional[FAISS]:
    """Load the FAISS store saved under ``DATABASE_DIR/db_name``, or prepare its folder.

    Args:
        embeddings: Embedding model the store is loaded with.
        db_name: Name of the sub-directory of ``DATABASE_DIR`` holding the store.

    Returns:
        The loaded store, or ``None`` when no saved index exists yet; in that case the
        (still empty) database directory is created.

    Raises:
        Exception: Re-raises whatever ``FAISS.load_local`` raises, after logging it.
    """
    db_path = os.path.join(DATABASE_DIR, db_name)
    # FAISS.save_local writes "index.faiss" (plus "index.pkl") under the default index
    # name. Test for the index file rather than the folder, so the empty folder created
    # below is not later mistaken for a saved database and handed to load_local.
    if os.path.isfile(os.path.join(db_path, "index.faiss")):
        logging.info(f"Loading existing database: {db_name}")
        try:
            # NOTE(review): allow_dangerous_deserialization unpickles index.pkl; this is
            # only acceptable for stores written by this application itself.
            return FAISS.load_local(
                db_path,
                embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception as e:
            logging.error(f"FAISS load error: {str(e)}")
            raise
    logging.info(f"Creating new vector database: {db_name}")
    os.makedirs(db_path, exist_ok=True)
    return None
def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> str:
    """Build a FAISS database named ``db_name`` from ``file_content`` and save it to disk.

    The text is split with ``split_text_into_chunks``, each chunk is embedded with the
    module-level ``embeddings`` object, and the store is saved to ``DATABASE_DIR/db_name``.

    Args:
        file_content: Full text of the uploaded file.
        db_name: Target database name; must be alphanumeric and not already exist.
        password: Admin password, checked against ``PASSWORD_HASH``.
        progress: Gradio progress tracker, called with a fraction and a description.

    Returns:
        A human-readable status message. Failures are reported in the message (and
        logged) rather than raised.
    """
    import hmac  # local import: constant-time comparison for the admin password

    # compare_digest does not leak, through timing, how much of the password matched.
    # Compare bytes because compare_digest rejects non-ASCII ``str`` arguments.
    supplied = (password or "").encode("utf-8")
    if not hmac.compare_digest(supplied, PASSWORD_HASH.encode("utf-8")):
        return "Incorrect password. Database creation failed."
    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed."
    # Restricting names to alphanumerics also keeps path separators out of db_path.
    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed."
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if os.path.exists(db_path):
            return f"Database '{db_name}' already exists."

        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed."
        total = len(chunks)
        logging.info(f"Creating {total} chunks...")
        progress(0, desc="Starting embedding process...")

        # Embed chunk by chunk so the Gradio progress bar can advance per chunk.
        embeddings_list = []
        for i, chunk in enumerate(chunks):
            progress(i / total, desc=f"Embedding chunk {i+1}/{total}")
            embeddings_list.append(embeddings.embed_query(chunk))

        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings,
        )
        vector_store.save_local(db_path)
        logging.info(f"Vector store '{db_name}' initialized successfully")
        return f"Database '{db_name}' created successfully."
    except Exception as e:
        # UI boundary: record the traceback, but report the failure as a status message.
        logging.exception(f"Database creation failed: {e}")
        return f"Error creating database: {e}"
def generate_response(user_input: str, db_name: str) -> Optional[str]:
    """Answer ``user_input`` with Qwen2.5 MAX, grounded on chunks from database ``db_name``.

    Loads the saved FAISS index with the module-level ``embeddings``, retrieves
    ``TOP_K_RESULTS * 3`` candidates, keeps those scoring below ``SIMILARITY_THRESHOLD``
    (lowest score first), and sends the best ``TOP_K_RESULTS`` as context to the Qwen
    Gradio Space.

    Returns:
        The model's reply; a user-facing message when no database or question was
        given, the database is missing, nothing passes the threshold, or the reply
        cannot be parsed; ``None`` if an exception occurred (logged with traceback).
    """
    # The dropdown yields None when nothing is selected; answer clearly instead of
    # failing inside os.path.join and surfacing only a generic error.
    if not db_name:
        return "Veuillez sélectionner une base de données."
    if not user_input or not user_input.strip():
        return "Veuillez saisir une question."
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if not os.path.exists(db_path):
            return f"Database '{db_name}' does not exist."
        # NOTE(review): allow_dangerous_deserialization unpickles the stored docstore;
        # acceptable only because these indexes are written by this app itself.
        vector_store = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )

        # Over-fetch, then filter by score and keep the best TOP_K_RESULTS.
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
            k=TOP_K_RESULTS * 3,
        )
        # Assumes the score is a distance (lower = closer), as with FAISS's default
        # L2 index — TODO confirm the threshold suits this embedding model.
        filtered_docs = sorted(
            ((doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD),
            key=lambda pair: pair[1],
        )
        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]

        context = "\n".join(
            f"=== Source {i} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs, start=1)
        )

        client = Client(QWEN_API_URL, verbose=False)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat",
        )
        # Assumed reply shape: a tuple whose second item is the chat history as
        # [user, assistant] pairs — TODO confirm against the Space's API page.
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]
        return "Réponse indisponible - Veuillez reformuler votre question."
    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
        return None
# Initialize models and vector store.
# The embedding model is loaded once, at import time, on the GPU when torch can see one.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
# trust_remote_code=True lets the model repository execute its own Python code on load.
model = SentenceTransformer(
    "cnmoro/snowflake-arctic-embed-m-v2.0-cpu",
    device=device,
    trust_remote_code=True,
)
# Shared adapter read by create_new_database and generate_response.
embeddings = LocalEmbeddings(model)
# Gradio interface
def update_db_list() -> List[str]:
    """Return the sorted names of the database directories under DATABASE_DIR.

    Returns an empty list when DATABASE_DIR does not exist (or is not a directory).
    """
    if not os.path.isdir(DATABASE_DIR):
        return []
    return sorted(
        name
        for name in os.listdir(DATABASE_DIR)
        if os.path.isdir(os.path.join(DATABASE_DIR, name))
    )


def refresh_db_choices():
    """Return a Gradio update that reloads the database dropdown's choices from disk."""
    return gr.update(choices=update_db_list())


def handle_create(file, db_name, password, progress=gr.Progress()):
    """Read the uploaded text file and build a database from it via create_new_database.

    Returns a status message for the "Status" textbox.
    """
    if not file or not db_name or not password:
        return "Please provide all required inputs."
    # Depending on the Gradio version / gr.File ``type``, the upload arrives either as a
    # path string or as a file-like object whose ``name`` attribute holds the path.
    file_path = file if isinstance(file, str) else getattr(file, "name", None)
    if not isinstance(file_path, str):
        return "Invalid file format. Please upload a .txt file."
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            file_content = f.read()
    except (OSError, UnicodeError) as e:
        return f"Error reading file: {e}"
    return create_new_database(file_content, db_name, password, progress)


def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
    """Answer ``message`` against ``db_name`` and append the exchange to the chat history.

    Returns ``("", new_history)`` so the question box is cleared after each submission.
    """
    answer = generate_response(message, db_name)
    new_history = list(history or [])
    new_history.append((message, answer or "Erreur de génération - Veuillez réessayer."))
    return "", new_history


with gr.Blocks() as app:
    gr.Markdown("# Local Tech Knowledge Assistant")

    with gr.Tab("Create Database"):
        gr.Markdown("## Create a New FAISS Database")
        file_input = gr.File(label="Upload .txt File", file_types=[".txt"])
        db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
        password_input = gr.Textbox(label="Enter Password", type="password")
        create_output = gr.Textbox(label="Status")
        create_button = gr.Button("Create Database")

    with gr.Tab("Chat with Database"):
        gr.Markdown("## Chat with Existing Databases")
        # Filled from disk at build time; refreshed again on page load and after creation.
        db_select = gr.Dropdown(choices=update_db_list(), label="Select Database")
        chatbot = gr.Chatbot(height=500)
        msg = gr.Textbox(label="Votre question")
        clear = gr.ClearButton([msg, chatbot])

    # Events are wired after both tabs exist, so that a successful creation can
    # refresh the dropdown that lives in the chat tab.
    create_button.click(
        handle_create,
        inputs=[file_input, db_name_input, password_input],
        outputs=create_output,
    ).then(refresh_db_choices, outputs=db_select)

    msg.submit(
        chat_response,
        inputs=[msg, db_select, chatbot],
        outputs=[msg, chatbot],
        queue=True,
    )

    # Re-scan DATABASE_DIR on every page load so databases created in other
    # sessions (or since the process started) are listed.
    app.load(refresh_db_choices, outputs=db_select)
if __name__ == "__main__":
    # Listen on all network interfaces (not only localhost), on port 7860.
    _launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    app.launch(**_launch_options)