# Local-Solution / app.py
# localsavageai's picture
# Upload app.py
# d59a8ff verified
# raw
# history blame
# 9.58 kB
import os
import logging
import numpy as np
from typing import List, Optional, Tuple
import torch
import gradio as gr
import spaces
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from gradio_client import Client
import requests
from tqdm import tqdm
# Configuration
# Root folder holding one sub-directory per FAISS database.
DATABASE_DIR = "semantic_memory"
# Hugging Face Space id passed to gradio_client.Client for the Qwen2.5 chat API.
QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"
# Maximum number of characters per text chunk when building a database.
CHUNK_SIZE = 800
# Maximum number of retrieved chunks injected into the prompt context.
TOP_K_RESULTS = 150
# Retrieved chunks are kept only when their search score is strictly below this value
# (the retrieval code treats a lower score as a closer match).
SIMILARITY_THRESHOLD = 0.4
# Password required to create a database. It is read from the APP_PASSWORD environment
# variable (e.g. a deployment secret) so the real value does not have to live in source
# control; the previous hard-coded value remains the fallback for backward compatibility.
# NOTE: despite the name this is a plaintext password, not a hash.
PASSWORD_HASH = os.environ.get("APP_PASSWORD", "abc12345")
# System prompt template sent to the chat model; "{context}" is replaced by the
# retrieved chunks via str.format.
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :
1. Utilisez EXCLUSIVEMENT le contexte fourni
2. Structurez la réponse en :
- Définition principale
- Caractéristiques clés (3 points maximum)
- Relations avec d'autres concepts
3. Si aucune information pertinente, indiquez-le clairement
Contexte :
{context}
"""
# Configure logging: records at INFO and above are written both to the file
# "mtc_chat.log" and to a console stream handler (stderr by default).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.FileHandler("mtc_chat.log"), logging.StreamHandler()],
)
class LocalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a local sentence-transformers model.

    ``model`` must expose ``encode(text)`` returning an array-like object with a
    ``.tolist()`` method (in this file, a ``SentenceTransformer``).
    """

    def __init__(self, model):
        self.model = model

    def embed_query(self, text: str) -> List[float]:
        """Embed a single string and return the vector as a list of floats."""
        vector = self.model.encode(text)
        return vector.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text one at a time (with a tqdm progress bar); vectors keep input order."""
        return [
            self.model.encode(document).tolist()
            for document in tqdm(texts, desc="Creating embeddings")
        ]
def split_text_into_chunks(text: str, chunk_size: Optional[int] = None) -> List[str]:
    """Split text into non-overlapping chunks, preferring to cut at sentence ends.

    Each window of at most ``chunk_size`` characters is cut just after its last
    '.', '!', '?' or blank line (two consecutive newlines), but only when that cut
    keeps more than half of ``chunk_size``. Otherwise the window is cut at exactly
    ``chunk_size`` characters. The final window of the text is never trimmed.
    Chunks are stripped of surrounding whitespace and whitespace-only chunks are
    dropped.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters; defaults to ``CHUNK_SIZE``.

    Returns:
        The non-empty chunks in document order ([] for empty or blank text).

    Raises:
        ValueError: If ``chunk_size`` is not positive (it would never advance).
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # A sentence-boundary cut must keep more than this many characters; otherwise a
    # punctuation mark near the start of the window would yield a tiny chunk.
    min_cut = chunk_size // 2
    chunks = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + chunk_size, text_length)
        if end < text_length:
            window = text[start:end]
            cut = 0
            for separator in (".", "!", "?", "\n\n"):
                position = window.rfind(separator)
                if position != -1:
                    # Cut after the whole separator (both newlines for "\n\n").
                    cut = max(cut, position + len(separator))
            if cut > min_cut:
                end = start + cut
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # end > start always holds here, so the loop makes progress.
        start = end
    return chunks
def initialize_vector_store(embeddings: Embeddings, db_name: str) -> Optional[FAISS]:
    """Load an existing FAISS vector store, or return None when it does not exist yet.

    Args:
        embeddings: Embedding model handed to ``FAISS.load_local``.
        db_name: Database name, resolved to ``DATABASE_DIR/<db_name>``.

    Returns:
        The loaded ``FAISS`` store, or ``None`` if the database directory is absent.

    Raises:
        Exception: Any error from ``FAISS.load_local`` is logged and re-raised.
    """
    db_path = os.path.join(DATABASE_DIR, db_name)
    if not os.path.exists(db_path):
        # Deliberately do not create the directory here. An empty folder would be
        # listed as a database, would make database creation report "already exists",
        # and could not be loaded. FAISS.save_local creates the folder when the
        # index is actually saved.
        logging.info(f"Vector database not found: {db_name}")
        return None
    logging.info(f"Loading existing database: {db_name}")
    try:
        # NOTE(review): load_local deserializes pickled data; only load trusted folders.
        return FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    except Exception as e:
        logging.error(f"FAISS load error: {str(e)}")
        raise
def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> str:
    """Build and save a new FAISS database from the text of an uploaded file.

    The text is split with ``split_text_into_chunks``. Each chunk is embedded with the
    module-level ``embeddings`` object, and the index is saved to
    ``DATABASE_DIR/<db_name>``.

    Args:
        file_content: Full text of the uploaded file.
        db_name: Name of the new database; must be alphanumeric (used as a folder name).
        password: Password typed in the UI, checked against ``PASSWORD_HASH``.
        progress: Gradio progress tracker updated once per embedded chunk.

    Returns:
        A human-readable status message. Failures are reported through this string
        rather than raised.
    """
    import hmac  # local import: constant-time comparison for the password check

    # compare_digest avoids leaking the password through response timing. Comparing
    # UTF-8 bytes (not str) keeps it from raising TypeError on non-ASCII input.
    if not isinstance(password, str) or not hmac.compare_digest(
        password.encode("utf-8"), PASSWORD_HASH.encode("utf-8")
    ):
        return "Incorrect password. Database creation failed."
    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed."
    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed."
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if os.path.exists(db_path):
            return f"Database '{db_name}' already exists."
        # Drop blank chunks: they carry no content but would still be embedded and indexed.
        chunks = [chunk for chunk in split_text_into_chunks(file_content) if chunk.strip()]
        if not chunks:
            return "No valid chunks generated. Database creation failed."
        total = len(chunks)
        logging.info(f"Creating {total} chunks...")
        progress(0, desc="Starting embedding process...")
        # Embed chunk by chunk (rather than in one batch) so the UI can show progress.
        embeddings_list = []
        for i, chunk in enumerate(chunks):
            progress(i / total, desc=f"Embedding chunk {i+1}/{total}")
            embeddings_list.append(embeddings.embed_query(chunk))
        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings,
        )
        vector_store.save_local(db_path)
        logging.info(f"Vector store '{db_name}' initialized successfully")
        return f"Database '{db_name}' created successfully."
    except Exception as e:
        # logging.exception records the traceback as well as the message.
        logging.exception(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}"
def generate_response(user_input: str, db_name: str) -> Optional[str]:
    """Answer a question with Qwen2.5 MAX, grounded on chunks retrieved from a FAISS database.

    Steps:
    1. Fetch up to ``TOP_K_RESULTS * 3`` candidate chunks.
    2. Keep those scoring below ``SIMILARITY_THRESHOLD``, treating a lower score as
       a closer match.
    3. Put the best ``TOP_K_RESULTS`` of them into ``BASE_SYSTEM_PROMPT``.
    4. Send the question to the Gradio API at ``QWEN_API_URL``.

    Args:
        user_input: The user's question.
        db_name: Database name under ``DATABASE_DIR``; like at creation time it must
            be alphanumeric.

    Returns:
        The model's answer, or a user-facing message in any of these cases:
        - no database is selected, or it is invalid or missing;
        - the question is empty;
        - nothing relevant was retrieved;
        - the API reply has an unexpected shape.
        ``None`` if an exception occurs; it is logged with its traceback.
    """
    # Validate before touching the filesystem. The dropdown may be empty (None), and
    # only alphanumeric names can have been created. Anything else (e.g. "../x") is
    # rejected instead of being joined into a path and loaded with
    # allow_dangerous_deserialization=True.
    if not db_name:
        return "Please select a database."
    if not isinstance(db_name, str) or not db_name.isalnum():
        return f"Database '{db_name}' does not exist."
    if not isinstance(user_input, str) or not user_input.strip():
        return "Veuillez saisir une question."
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if not os.path.exists(db_path):
            return f"Database '{db_name}' does not exist."
        vector_store = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )
        # Over-fetch so enough candidates survive the threshold filter below.
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
            k=TOP_K_RESULTS * 3,
        )
        # Keep close matches only (lower score = closer), best first.
        filtered_docs = sorted(
            ((doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD),
            key=lambda pair: pair[1],
        )
        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
        context = "\n".join(
            f"=== Source {i+1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )
        client = Client(QWEN_API_URL, verbose=False)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat",
        )
        # Expected reply: a tuple whose second element is the chat history; the
        # answer is the second item of its last entry.
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]
        return "Réponse indisponible - Veuillez reformuler votre question."
    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
        return None
# Initialize the embedding model shared by database creation and querying;
# it is placed on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): trust_remote_code=True allows the model repository's custom code to run.
model = SentenceTransformer(
    "cnmoro/snowflake-arctic-embed-m-v2.0-cpu",
    trust_remote_code=True,
    device=device,
)
embeddings = LocalEmbeddings(model)
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Local Tech Knowledge Assistant")

    def update_db_list():
        """Return the names of DATABASE_DIR's sub-directories ([] if DATABASE_DIR is missing)."""
        if not os.path.exists(DATABASE_DIR):
            return []
        return [
            name
            for name in os.listdir(DATABASE_DIR)
            if os.path.isdir(os.path.join(DATABASE_DIR, name))
        ]

    def refresh_db_choices():
        """Return a Dropdown update carrying the current list of databases."""
        return gr.update(choices=update_db_list())

    with gr.Tab("Create Database"):
        gr.Markdown("## Create a New FAISS Database")
        file_input = gr.File(label="Upload .txt File")
        db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
        password_input = gr.Textbox(label="Enter Password", type="password")
        create_output = gr.Textbox(label="Status")
        create_button = gr.Button("Create Database")

        def handle_create(file, db_name, password, progress=gr.Progress()):
            """Read the uploaded file as UTF-8 text and pass it to create_new_database.

            Returns the status message displayed in the "Status" textbox.
            """
            if not file or not db_name or not password:
                return "Please provide all required inputs."
            # gr.File may deliver a path string or (in older Gradio versions) a
            # tempfile-like object whose ``.name`` is the path; accept both.
            file_path = file if isinstance(file, str) else getattr(file, "name", None)
            if not isinstance(file_path, str):
                return "Invalid file format. Please upload a .txt file."
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    file_content = f.read()
            except Exception as e:
                return f"Error reading file: {str(e)}"
            return create_new_database(file_content, db_name, password, progress)

    with gr.Tab("Chat with Database"):
        gr.Markdown("## Chat with Existing Databases")
        # Choices are passed to the constructor so the Dropdown normalizes them itself.
        db_select = gr.Dropdown(choices=update_db_list(), label="Select Database")
        chatbot = gr.Chatbot(height=500)
        msg = gr.Textbox(label="Votre question")
        clear = gr.ClearButton([msg, chatbot])

        def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
            """Answer ``message`` and append the (question, answer) pair to the chat history.

            Returns an empty string (clears the textbox) and the updated history.
            """
            history = history or []
            response = generate_response(message, db_name)
            return "", history + [(message, response or "Erreur de génération - Veuillez réessayer.")]

        msg.submit(
            chat_response,
            inputs=[msg, db_select, chatbot],
            outputs=[msg, chatbot],
            queue=True,
        )

    # Wired after both tabs so that db_select exists. The database list is refreshed
    # after every creation attempt and on every page load, so newly created databases
    # show up without restarting the app.
    create_button.click(
        handle_create,
        inputs=[file_input, db_name_input, password_input],
        outputs=create_output,
    ).then(refresh_db_choices, outputs=db_select)
    app.load(refresh_db_choices, outputs=db_select)
if __name__ == "__main__":
    # Serve the UI on all network interfaces, port 7860.
    app.launch(server_port=7860, server_name="0.0.0.0")