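"""Gradio app for a small French-language RAG assistant: builds FAISS
"semantic memory" databases from uploaded .txt files using local
sentence-transformers embeddings, then answers questions against a selected
database through the Qwen2.5-Max demo Space."""
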
import os
import logging
from typing import List, Optional, Tuple

import torch
import gradio as gr
import spaces  # kept for Hugging Face Spaces compatibility
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from gradio_client import Client
from tqdm import tqdm

# Configuration
DATABASE_DIR = "semantic_memory"
QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # Hugging Face Space ID for the Qwen2.5-Max chat demo
CHUNK_SIZE = 800  # characters per chunk
TOP_K_RESULTS = 150  # maximum number of chunks passed to the model as context
SIMILARITY_THRESHOLD = 0.4  # L2 distance cutoff; lower scores are closer matches
PASSWORD_HASH = "abc12345"  # Plaintext for demo only; compare against a real hash in production

BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :

1. Utilisez EXCLUSIVEMENT le contexte fourni
2. Structurez la réponse en :
   - Définition principale
   - Caractéristiques clés (3 points maximum)
   - Relations avec d'autres concepts
3. Si aucune information pertinente, indiquez-le clairement

Contexte :
{context}
"""

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("mtc_chat.log"),
        logging.StreamHandler()
    ]
)

class LocalEmbeddings(Embeddings):
    """Local sentence-transformers embeddings"""
    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = []
        for text in tqdm(texts, desc="Creating embeddings"):
            embeddings.append(self.model.encode(text).tolist())
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()

def split_text_into_chunks(text: str) -> List[str]:
    """Split text into chunks, breaking at sentence boundaries where possible"""
    chunks = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = min(start + CHUNK_SIZE, text_length)
        chunk = text[start:end]

        # Find the last sentence-ending punctuation or paragraph break
        last_punct = max(
            chunk.rfind('.'),
            chunk.rfind('!'),
            chunk.rfind('?'),
            chunk.rfind('\n\n')
        )

        # Only cut at the punctuation if it leaves a reasonably sized chunk
        if last_punct > CHUNK_SIZE // 2:
            end = start + last_punct + 1

        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        start = end if end > start else start + CHUNK_SIZE

    return chunks

def initialize_vector_store(embeddings: Embeddings, db_name: str) -> Optional[FAISS]:
    """Load an existing FAISS vector store; otherwise create its directory and return None"""
    db_path = os.path.join(DATABASE_DIR, db_name)
    if os.path.exists(db_path):
        try:
            logging.info(f"Loading existing database: {db_name}")
            return FAISS.load_local(
                db_path,
                embeddings,
                allow_dangerous_deserialization=True
            )
        except Exception as e:
            logging.error(f"FAISS load error: {str(e)}")
            raise

    logging.info(f"Creating new vector database: {db_name}")
    os.makedirs(db_path, exist_ok=True)
    return None

def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> str:
    """Create a new FAISS database from uploaded file"""
    if password != PASSWORD_HASH:
        return "Incorrect password. Database creation failed."

    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed."

    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed."

    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if os.path.exists(db_path):
            return f"Database '{db_name}' already exists."

        # Initialize embeddings and split text
        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed."

        logging.info(f"Creating {len(chunks)} chunks...")
        progress(0, desc="Starting embedding process...")
        
        # Create embeddings with progress tracking
        embeddings_list = []
        for i, chunk in enumerate(chunks):
            progress(i / len(chunks), desc=f"Embedding chunk {i+1}/{len(chunks)}")
            embeddings_list.append(embeddings.embed_query(chunk))
        
        # Create FAISS database
        vector_store = FAISS.from_embeddings(
            text_embeddings=list(zip(chunks, embeddings_list)),
            embedding=embeddings
        )
        vector_store.save_local(db_path)
        logging.info(f"Vector store '{db_name}' initialized successfully")
        return f"Database '{db_name}' created successfully."
    except Exception as e:
        logging.error(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}"

def generate_response(user_input: str, db_name: str) -> Optional[str]:
    """Generate response using Qwen2.5 MAX"""
    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if not os.path.exists(db_path):
            return f"Database '{db_name}' does not exist."

        vector_store = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True
        )

        # Contextual search
        docs_scores = vector_store.similarity_search_with_score(
            user_input, 
            k=TOP_K_RESULTS * 3
        )
        
        # FAISS scores are L2 distances, so lower means more similar;
        # keep only sufficiently close matches, best first
        filtered_docs = [
            (doc, score) for doc, score in docs_scores
            if score < SIMILARITY_THRESHOLD
        ]
        filtered_docs.sort(key=lambda x: x[1])
        
        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."
        
        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
        
        # Build context
        context = "\n".join(
            f"=== Source {i+1} ===\n{doc.page_content}\n" 
            for i, doc in enumerate(best_docs)
        )
        
        # Call Qwen API
        client = Client(QWEN_API_URL, verbose=False)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat"
        )

        # The Space returns a tuple whose second element is the chat history;
        # take the assistant's reply from the last turn
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]
        
        return "Réponse indisponible - Veuillez reformuler votre question."

    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
        return None

# Initialize the embedding model (vector stores are loaded per database on demand)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device, trust_remote_code=True)
embeddings = LocalEmbeddings(model)

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Local Tech Knowledge Assistant")

    with gr.Tab("Create Database"):
        gr.Markdown("## Create a New FAISS Database")
        file_input = gr.File(label="Upload .txt File", file_types=[".txt"])
        db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
        password_input = gr.Textbox(label="Enter Password", type="password")
        create_output = gr.Textbox(label="Status")
        create_button = gr.Button("Create Database")
        
        def handle_create(file, db_name, password, progress=gr.Progress()):
            if not file or not db_name or not password:
                return "Please provide all required inputs."
            
            # gr.File yields a path string (or a file object with a .name path,
            # depending on the Gradio version)
            file_path = file if isinstance(file, str) else getattr(file, "name", None)
            if not file_path:
                return "Invalid file format. Please upload a .txt file."
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    file_content = f.read()
            except Exception as e:
                return f"Error reading file: {str(e)}"
            
            return create_new_database(file_content, db_name, password, progress)
        
        create_button.click(
            handle_create,
            inputs=[file_input, db_name_input, password_input],
            outputs=create_output
        )

    with gr.Tab("Chat with Database"):
        gr.Markdown("## Chat with Existing Databases")
        db_select = gr.Dropdown(choices=[], label="Select Database")
        chatbot = gr.Chatbot(height=500)
        msg = gr.Textbox(label="Votre question")
        clear = gr.ClearButton([msg, chatbot])
        
        def update_db_list():
            if not os.path.exists(DATABASE_DIR):
                return []
            return [name for name in os.listdir(DATABASE_DIR) if os.path.isdir(os.path.join(DATABASE_DIR, name))]
        
        def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
            response = generate_response(message, db_name)
            return "", history + [(message, response or "Erreur de génération - Veuillez réessayer.")]
        
        msg.submit(
            chat_response,
            inputs=[msg, db_select, chatbot],
            outputs=[msg, chatbot],
            queue=True
        )
        
        # Update database list on page load
        db_select.choices = update_db_list()

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
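
# Minimal local run (package list is an assumption inferred from the imports above):
#   pip install gradio gradio_client spaces sentence-transformers \
#       langchain langchain-community faiss-cpu torch tqdm
#   python app.py   # hypothetical filename; then open http://localhost:7860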