localsavageai committed on
Commit
30c0b2f
·
verified ·
1 Parent(s): d8f5f8c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -88
app.py CHANGED
@@ -9,21 +9,21 @@ from langchain.embeddings.base import Embeddings
9
  from tqdm import tqdm
10
 
11
  # Configuration
12
- QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo" # Gradio API for Qwen2.5 chat
13
  CHUNK_SIZE = 800
14
  TOP_K_RESULTS = 150
15
  SIMILARITY_THRESHOLD = 0.4
16
- PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345") # Use environment variable for password
17
 
18
  BASE_SYSTEM_PROMPT = """
19
  Répondez en français selon ces règles :
20
 
21
- 1. Utilisez EXCLUSIVEMENT le contexte fourni
22
  2. Structurez la réponse en :
23
- - Définition principale
24
- - Caractéristiques clés (3 points maximum)
25
- - Relations avec d'autres concepts
26
- 3. Si aucune information pertinente, indiquez-le clairement
27
 
28
  Contexte :
29
  {context}
@@ -54,7 +54,7 @@ class LocalEmbeddings(Embeddings):
54
  return self.model.encode(text).tolist()
55
 
56
  def split_text_into_chunks(text: str) -> List[str]:
57
- """Split text with overlap and sentence preservation"""
58
  chunks = []
59
  start = 0
60
  text_length = len(text)
@@ -98,7 +98,7 @@ def create_new_database(file_content: str, db_name: str, password: str, progress
98
  if os.path.exists(faiss_file) or os.path.exists(pkl_file):
99
  return f"Database '{db_name}' already exists.", []
100
 
101
- # Initialize embeddings and split text
102
  chunks = split_text_into_chunks(file_content)
103
  if not chunks:
104
  return "No valid chunks generated. Database creation failed.", []
@@ -118,21 +118,18 @@ def create_new_database(file_content: str, db_name: str, password: str, progress
118
  embedding=embeddings
119
  )
120
 
121
- # Save FAISS database
122
- try:
123
- vector_store.save_local(".")
124
- logging.info(f"FAISS database saved to: {faiss_file} and {pkl_file}")
125
- except Exception as e:
126
- logging.error(f"FAISS save error: {str(e)}")
127
- return "Failed to save FAISS database. Please check logs for details.", []
128
-
129
- # Verify files were created
130
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
131
  return "Failed to save FAISS database files. Please check file permissions.", []
 
132
  logging.info(f"FAISS database files created: {faiss_file}, {pkl_file}")
133
 
134
  # Update the list of available databases
135
  db_list = [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
 
136
  return f"Database '{db_name}' created successfully.", db_list
137
 
138
  except Exception as e:
@@ -151,23 +148,19 @@ def generate_response(user_input: str, db_name: str) -> str:
151
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
152
  return f"Database '{db_name}' does not exist."
153
 
154
- try:
155
- vector_store = FAISS.load_local(".", embeddings, allow_dangerous_deserialization=True)
156
- except Exception as e:
157
- logging.error(f"FAISS load error: {str(e)}")
158
- return "Failed to load FAISS database. Please check logs for details."
159
 
160
- # Contextual search
161
  docs_scores = vector_store.similarity_search_with_score(
162
  user_input,
163
  k=TOP_K_RESULTS * 3
164
  )
165
 
166
- # Filter results
167
  filtered_docs = [
168
  (doc, score) for doc, score in docs_scores
169
  if score < SIMILARITY_THRESHOLD
170
  ]
 
171
  filtered_docs.sort(key=lambda x: x[1])
172
 
173
  if not filtered_docs:
@@ -175,14 +168,13 @@ def generate_response(user_input: str, db_name: str) -> str:
175
 
176
  best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
177
 
178
- # Build context
179
  context = "\n".join(
180
  f"=== Source {i+1} ===\n{doc.page_content}\n"
181
  for i, doc in enumerate(best_docs)
182
  )
183
 
184
- # Call Qwen API
185
  client = Client(QWEN_API_URL, verbose=False)
 
186
  response = client.predict(
187
  query=user_input,
188
  history=[],
@@ -190,7 +182,6 @@ def generate_response(user_input: str, db_name: str) -> str:
190
  api_name="/model_chat"
191
  )
192
 
193
- # Extract response
194
  if isinstance(response, tuple) and len(response) >= 2:
195
  chat_history = response[1]
196
  if chat_history and len(chat_history[-1]) >= 2:
@@ -207,7 +198,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
  model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device, trust_remote_code=True)
208
  embeddings = LocalEmbeddings(model)
209
 
210
- # Gradio interface
211
  with gr.Blocks() as app:
212
  gr.Markdown("# Local Tech Knowledge Assistant")
213
 
@@ -220,64 +211,5 @@ with gr.Blocks() as app:
220
 
221
  with gr.Tab("Create Database"):
222
  gr.Markdown("## Create a New FAISS Database")
223
- file_input = gr.File(label="Upload .txt File")
224
- db_name_input = gr.Textbox(label="Enter Desired Database Name (Alphanumeric Only)")
225
- password_input = gr.Textbox(label="Enter Password", type="password")
226
- create_output = gr.Textbox(label="Status")
227
- create_button = gr.Button("Create Database")
228
-
229
- def handle_create(file, db_name, password, progress=gr.Progress()):
230
- if not file or not db_name or not password:
231
- return "Please provide all required inputs.", []
232
-
233
- # Check if the file is valid
234
- if isinstance(file, str): # Gradio provides the file path as a string
235
- try:
236
- with open(file, "r", encoding="utf-8") as f:
237
- file_content = f.read()
238
- except Exception as e:
239
- return f"Error reading file: {str(e)}", []
240
- else:
241
- return "Invalid file format. Please upload a .txt file.", []
242
-
243
- result, db_list = create_new_database(file_content, db_name, password, progress)
244
- return result, db_list
245
 
246
- create_button.click(
247
- handle_create,
248
- inputs=[file_input, db_name_input, password_input],
249
- outputs=[create_output, db_list_state]
250
- )
251
-
252
- with gr.Tab("Chat with Database"):
253
- gr.Markdown("## Chat with Existing Databases")
254
- db_select = gr.Dropdown(choices=[], label="Select Database")
255
- chatbot = gr.Chatbot(height=500)
256
- msg = gr.Textbox(label="Votre question")
257
- clear = gr.ClearButton([msg, chatbot])
258
-
259
- def chat_response(message: str, db_name: str, history: List[Tuple[str, str]]):
260
- if not db_name:
261
- return "", history + [("System", "Please select a database to chat with.")]
262
- response = generate_response(message, db_name)
263
- return "", history + [(message, response)]
264
-
265
- msg.submit(
266
- chat_response,
267
- inputs=[msg, db_select, chatbot],
268
- outputs=[msg, chatbot],
269
- queue=True
270
- )
271
-
272
- # Update dropdown on page load
273
- db_select.choices = update_db_list()
274
-
275
- # Update dropdown when db_list_state changes
276
- db_list_state.change(
277
- lambda dbs: gr.Dropdown.update(choices=dbs),
278
- inputs=db_list_state,
279
- outputs=db_select
280
- )
281
 
282
- if __name__ == "__main__":
283
- app.launch(server_name="0.0.0.0", server_port=7860)
 
9
  from tqdm import tqdm
10
 
11
  # Configuration
12
+ QWEN_API_URL = os.getenv("QWEN_API_URL", "Qwen/Qwen2.5-Max-Demo") # Environment variable for Qwen API URL
13
  CHUNK_SIZE = 800
14
  TOP_K_RESULTS = 150
15
  SIMILARITY_THRESHOLD = 0.4
16
+ PASSWORD_HASH = os.getenv("PASSWORD_HASH", "abc12345") # Environment variable for password
17
 
18
  BASE_SYSTEM_PROMPT = """
19
  Répondez en français selon ces règles :
20
 
21
+ 1. Utilisez EXCLUSIVEMENT le contexte fourni.
22
  2. Structurez la réponse en :
23
+ - Définition principale.
24
+ - Caractéristiques clés (3 points maximum).
25
+ - Relations avec d'autres concepts.
26
+ 3. Si aucune information pertinente, indiquez-le clairement.
27
 
28
  Contexte :
29
  {context}
 
54
  return self.model.encode(text).tolist()
55
 
56
  def split_text_into_chunks(text: str) -> List[str]:
57
+ """Split text into chunks with overlap and sentence preservation"""
58
  chunks = []
59
  start = 0
60
  text_length = len(text)
 
98
  if os.path.exists(faiss_file) or os.path.exists(pkl_file):
99
  return f"Database '{db_name}' already exists.", []
100
 
101
+ # Initialize embeddings and split text into chunks
102
  chunks = split_text_into_chunks(file_content)
103
  if not chunks:
104
  return "No valid chunks generated. Database creation failed.", []
 
118
  embedding=embeddings
119
  )
120
 
121
+ # Save FAISS database locally
122
+ vector_store.save_local(".")
123
+
124
+ # Verify files were created successfully
 
 
 
 
 
125
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
126
  return "Failed to save FAISS database files. Please check file permissions.", []
127
+
128
  logging.info(f"FAISS database files created: {faiss_file}, {pkl_file}")
129
 
130
  # Update the list of available databases
131
  db_list = [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
132
+
133
  return f"Database '{db_name}' created successfully.", db_list
134
 
135
  except Exception as e:
 
148
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
149
  return f"Database '{db_name}' does not exist."
150
 
151
+ vector_store = FAISS.load_local(".", embeddings, allow_dangerous_deserialization=True)
 
 
 
 
152
 
153
+ # Perform contextual search in the database
154
  docs_scores = vector_store.similarity_search_with_score(
155
  user_input,
156
  k=TOP_K_RESULTS * 3
157
  )
158
 
 
159
  filtered_docs = [
160
  (doc, score) for doc, score in docs_scores
161
  if score < SIMILARITY_THRESHOLD
162
  ]
163
+
164
  filtered_docs.sort(key=lambda x: x[1])
165
 
166
  if not filtered_docs:
 
168
 
169
  best_docs = [doc for doc, _ in filtered_docs[:TOP_K_RESULTS]]
170
 
 
171
  context = "\n".join(
172
  f"=== Source {i+1} ===\n{doc.page_content}\n"
173
  for i, doc in enumerate(best_docs)
174
  )
175
 
 
176
  client = Client(QWEN_API_URL, verbose=False)
177
+
178
  response = client.predict(
179
  query=user_input,
180
  history=[],
 
182
  api_name="/model_chat"
183
  )
184
 
 
185
  if isinstance(response, tuple) and len(response) >= 2:
186
  chat_history = response[1]
187
  if chat_history and len(chat_history[-1]) >= 2:
 
198
  model = SentenceTransformer("cnmoro/snowflake-arctic-embed-m-v2.0-cpu", device=device, trust_remote_code=True)
199
  embeddings = LocalEmbeddings(model)
200
 
201
+ # Gradio interface setup remains unchanged from your original code.
202
  with gr.Blocks() as app:
203
  gr.Markdown("# Local Tech Knowledge Assistant")
204
 
 
211
 
212
  with gr.Tab("Create Database"):
213
  gr.Markdown("## Create a New FAISS Database")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215