localsavageai committed on
Commit
1301284
·
verified ·
1 Parent(s): 3c84337

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +251 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import numpy as np
4
+ from typing import List, Optional, Tuple
5
+ import torch
6
+ import gradio as gr
7
+ import spaces
8
+ from sentence_transformers import SentenceTransformer
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain.embeddings.base import Embeddings
11
+ from gradio_client import Client
12
+ import requests
13
+ from tqdm import tqdm
14
+
15
# Configuration
DATABASE_DIR = "semantic_memory"  # parent directory holding one sub-directory per database
QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo"  # identifier handed to gradio_client.Client
CHUNK_SIZE = 800  # maximum number of characters per text chunk
TOP_K_RESULTS = 150  # maximum number of chunks placed in the prompt context
SIMILARITY_THRESHOLD = 0.4  # search hits are kept only when their score is below this
# NOTE(security): a credential committed to source control is readable by anyone
# with access to the repository. Set DATABASE_PASSWORD in the environment to
# override it; the historical value is kept only as a backward-compatible fallback.
PASSWORD = os.environ.get("DATABASE_PASSWORD") or "abc12345"

# System prompt template; `{context}` is filled in with the retrieved chunks.
BASE_SYSTEM_PROMPT = """
Répondez en français selon ces règles :

1. Utilisez EXCLUSIVEMENT le contexte fourni
2. Structurez la réponse en :
- Définition principale
- Caractéristiques clés (3 points maximum)
- Relations avec d'autres concepts
3. Si aucune information pertinente, indiquez-le clairement

Contexte :
{context}
"""
36
+
37
# Logging setup: INFO and above are written both to "mtc_chat.log" and to the
# console stream, each record prefixed with its timestamp and level.
logging.basicConfig(
    handlers=[logging.FileHandler("mtc_chat.log"), logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
46
+
47
class LocalEmbeddings(Embeddings):
    """Adapter exposing a local sentence-transformers model as LangChain embeddings.

    Args:
        model: Object with an ``encode`` method; this module passes a
            ``SentenceTransformer`` instance.
    """

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text and return one vector per text, in input order.

        All texts are handed to ``encode`` in a single call so the model can
        batch them, instead of running one forward pass per text.
        """
        if not texts:
            return []
        vectors = self.model.encode(list(texts), show_progress_bar=True)
        return [vector.tolist() for vector in vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector as a list of floats."""
        return self.model.encode(text).tolist()
60
+
61
def split_text_into_chunks(text: str, chunk_size: Optional[int] = None) -> List[str]:
    """Split text into consecutive chunks, preferring to cut after punctuation.

    The text is walked left to right in windows of at most ``chunk_size``
    characters. When a window is longer than half of ``chunk_size`` and
    contains '.', '!', '?' or a blank line, the chunk is cut right after the
    last such mark. Chunks do not overlap. Each chunk is stripped of
    surrounding whitespace, and chunks that end up empty are dropped.

    Args:
        text: The text to split.
        chunk_size: Maximum chunk length in characters; defaults to the
            module-level ``CHUNK_SIZE`` when omitted.

    Returns:
        The non-empty chunks in document order.

    Raises:
        ValueError: If the effective chunk size is not positive.
    """
    size = CHUNK_SIZE if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    chunks: List[str] = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = min(start + size, text_length)
        window = text[start:end]

        # Position of the last sentence-ending mark inside the window (-1 if none).
        last_punct = max(window.rfind(mark) for mark in ('.', '!', '?', '\n\n'))

        if last_punct != -1 and (end - start) > size // 2:
            end = start + last_punct + 1

        piece = text[start:end].strip()
        # A whitespace-only window (e.g. the newline after the final sentence)
        # would otherwise be stored, and later embedded, as an empty chunk.
        if piece:
            chunks.append(piece)

        # `end` is always greater than `start` here, so the loop always advances.
        start = end

    return chunks
86
+
87
def initialize_vector_store(embeddings: Embeddings, db_name: str) -> Optional[FAISS]:
    """Load the FAISS store named ``db_name``, or prepare an empty directory for it.

    Args:
        embeddings: Embedding implementation passed to ``FAISS.load_local``.
        db_name: Name of the sub-directory of ``DATABASE_DIR`` to use.

    Returns:
        The loaded vector store when the directory already exists. Otherwise
        the directory is created and ``None`` is returned, so callers must
        handle the "nothing to load yet" case.

    Raises:
        Exception: Whatever ``FAISS.load_local`` raises is logged and re-raised.
    """
    db_path = os.path.join(DATABASE_DIR, db_name)
    if os.path.exists(db_path):
        logging.info(f"Loading existing database: {db_name}")
        try:
            # NOTE(security): allow_dangerous_deserialization=True means the
            # stored index is unpickled; only load directories this app wrote.
            return FAISS.load_local(
                db_path,
                embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception as e:
            logging.error(f"FAISS load error: {str(e)}")
            raise

    logging.info(f"Creating new vector database: {db_name}")
    os.makedirs(db_path, exist_ok=True)
    return None
105
+
106
def create_new_database(file_content: str, db_name: str, password: str) -> str:
    """Build and save a new FAISS database from the text of an uploaded file.

    Args:
        file_content: Full text to index.
        db_name: Name of the new database; must be alphanumeric because it is
            used as a directory name under ``DATABASE_DIR``.
        password: Password supplied by the user; must match ``PASSWORD``.

    Returns:
        A human-readable status message. Failures are reported through the
        message rather than raised.
    """
    import hmac  # function-scope import: only this check needs it

    # Constant-time comparison so response timing does not leak how much of
    # the submitted password matched.
    password_ok = isinstance(password, str) and hmac.compare_digest(
        password.encode("utf-8"), PASSWORD.encode("utf-8")
    )
    if not password_ok:
        return "Incorrect password. Database creation failed."

    if not file_content.strip():
        return "Uploaded file is empty. Database creation failed."

    if not db_name.isalnum():
        return "Database name must be alphanumeric. Database creation failed."

    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if os.path.exists(db_path):
            return f"Database '{db_name}' already exists."

        chunks = split_text_into_chunks(file_content)
        if not chunks:
            return "No valid chunks generated. Database creation failed."

        logging.info(f"Creating {len(chunks)} chunks...")
        # `embeddings` is the module-level embedding adapter.
        vector_store = FAISS.from_texts(chunks, embeddings)
        vector_store.save_local(db_path)
        logging.info(f"Vector store '{db_name}' initialized successfully")
        return f"Database '{db_name}' created successfully."
    except Exception as e:
        logging.error(f"Database creation failed: {str(e)}")
        return f"Error creating database: {str(e)}"
135
+
136
def generate_response(user_input: str, db_name: str) -> Optional[str]:
    """Answer ``user_input`` from the chunks stored in database ``db_name``.

    The database is searched for the chunks closest to the question, those
    whose score is below ``SIMILARITY_THRESHOLD`` are kept (lowest score
    first, at most ``TOP_K_RESULTS``), and they are sent as context to the
    remote chat model together with the question.

    Args:
        user_input: The user's question.
        db_name: Name of a database directory directly under ``DATABASE_DIR``.

    Returns:
        The model's reply, or a user-facing message when no database is
        selected, the database does not exist, nothing matches, or the reply
        cannot be extracted. ``None`` when an unexpected error occurred (it is
        logged with its traceback).
    """
    # Without this guard an unselected dropdown (None) makes os.path.join
    # raise and the user only sees a generic failure.
    if not db_name:
        return "Veuillez sélectionner une base de données."

    # Only plain directory names are valid; anything with a path component
    # could point outside DATABASE_DIR and would then be unpickled below.
    if db_name in (".", "..") or os.path.basename(db_name) != db_name:
        return f"Database '{db_name}' does not exist."

    try:
        db_path = os.path.join(DATABASE_DIR, db_name)
        if not os.path.exists(db_path):
            return f"Database '{db_name}' does not exist."

        # NOTE(security): allow_dangerous_deserialization=True unpickles the
        # stored index; only directories written by this app should be loaded.
        vector_store = FAISS.load_local(
            db_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )

        # Over-fetch so that enough candidates remain after threshold filtering.
        docs_scores = vector_store.similarity_search_with_score(
            user_input,
            k=TOP_K_RESULTS * 3,
        )

        # Keep hits under the threshold, best (lowest score) first.
        filtered_docs = sorted(
            ((doc, score) for doc, score in docs_scores if score < SIMILARITY_THRESHOLD),
            key=lambda pair: pair[1],
        )

        if not filtered_docs:
            return "Aucune correspondance trouvée. Essayez des termes plus spécifiques."

        best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]

        context = "\n".join(
            f"=== Source {i+1} ===\n{doc.page_content}\n"
            for i, doc in enumerate(best_docs)
        )

        client = Client(QWEN_API_URL, verbose=False)
        response = client.predict(
            query=user_input,
            history=[],
            system=BASE_SYSTEM_PROMPT.format(context=context),
            api_name="/model_chat",
        )

        # The reply is read from the last entry of the second element of the
        # returned tuple; any other shape is treated as "no answer".
        if isinstance(response, tuple) and len(response) >= 2:
            chat_history = response[1]
            if chat_history and len(chat_history[-1]) >= 2:
                return chat_history[-1][1]

        return "Réponse indisponible - Veuillez reformuler votre question."

    except Exception as e:
        logging.error(f"Generation error: {str(e)}", exc_info=True)
        return None
193
+
194
# Model setup: use CUDA when torch reports it as available, otherwise the CPU,
# then wrap the loaded sentence-transformers model in the LangChain adapter.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SentenceTransformer(
    "cnmoro/snowflake-arctic-embed-m-v2.0-cpu",
    trust_remote_code=True,
    device=device,
)
embeddings = LocalEmbeddings(model)
198
+
199
# Gradio interface: one tab to build a database from an uploaded text file and
# one tab to chat against an existing database.
with gr.Blocks() as app:
    gr.Markdown("# Local Tech Knowledge Assistant")

    def update_db_list():
        """Return the names of the database directories found in DATABASE_DIR."""
        if not os.path.exists(DATABASE_DIR):
            return []
        return [name for name in os.listdir(DATABASE_DIR) if os.path.isdir(os.path.join(DATABASE_DIR, name))]

    def refresh_db_choices():
        """Return a dropdown update carrying the current list of databases."""
        return gr.Dropdown(choices=update_db_list())

    with gr.Tab("Create Database"):
        gr.Markdown("## Create a New FAISS Database")
        file_input = gr.File(label="Upload .txt File")
        db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
        password_input = gr.Textbox(label="Enter Password", type="password")
        create_output = gr.Textbox(label="Status")
        create_button = gr.Button("Create Database")

        def handle_create(file, db_name, password):
            """Read the uploaded file as UTF-8 text and create the database from it."""
            if not file or not db_name or not password:
                return "Please provide all required inputs."

            # gr.File hands over a file path by default and raw bytes only
            # with type="binary"; calling .decode() on the path fails, so both
            # forms are handled here.
            try:
                if isinstance(file, (bytes, bytearray)):
                    file_content = bytes(file).decode("utf-8")
                else:
                    file_path = file if isinstance(file, str) else file.name
                    with open(file_path, encoding="utf-8") as uploaded:
                        file_content = uploaded.read()
            except (OSError, UnicodeDecodeError) as e:
                logging.error(f"Could not read uploaded file: {str(e)}")
                return f"Could not read uploaded file: {str(e)}"

            return create_new_database(file_content, db_name, password)

        create_event = create_button.click(
            handle_create,
            inputs=[file_input, db_name_input, password_input],
            outputs=create_output
        )

    with gr.Tab("Chat with Database"):
        gr.Markdown("## Chat with Existing Databases")
        # Choices are supplied at construction time; assigning plain strings
        # to `.choices` afterwards bypasses the component's own normalisation.
        db_select = gr.Dropdown(choices=update_db_list(), label="Select Database")
        chatbot = gr.Chatbot(height=500)
        msg = gr.Textbox(label="Votre question")
        clear = gr.ClearButton([msg, chatbot])

        def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
            """Append the question and its answer to the history and clear the textbox."""
            response = generate_response(message, db_name)
            return "", (history or []) + [(message, response or "Erreur de génération - Veuillez réessayer.")]

        msg.submit(
            chat_response,
            inputs=[msg, db_select, chatbot],
            outputs=[msg, chatbot],
            queue=True
        )

    # Keep the dropdown current: refresh it after each creation attempt and
    # whenever the page is loaded.
    create_event.then(refresh_db_choices, outputs=db_select)
    app.load(refresh_db_choices, outputs=db_select)
249
+
250
if __name__ == "__main__":
    # Serve on port 7860, bound to every network interface.
    app.launch(server_port=7860, server_name="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.23.2
2
+ sentence-transformers
3
+ torch
4
+ langchain
5
+ langchain-community
6
+ faiss-cpu
7
+ gradio-client
8
+ tqdm
9
+ requests
10
+ numpy