localsavageai commited on
Commit
c26c573
·
verified ·
1 Parent(s): 79b023d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -49
app.py CHANGED
@@ -1,15 +1,11 @@
1
  import os
2
  import logging
3
- import numpy as np
4
- from typing import List, Optional, Tuple
5
  import torch
6
  import gradio as gr
7
- import spaces
8
  from sentence_transformers import SentenceTransformer
9
  from langchain_community.vectorstores import FAISS
10
  from langchain.embeddings.base import Embeddings
11
- from gradio_client import Client
12
- import requests
13
  from tqdm import tqdm
14
 
15
  # Configuration
@@ -17,7 +13,7 @@ QWEN_API_URL = "Qwen/Qwen2.5-Max-Demo" # Gradio API for Qwen2.5 chat
17
  CHUNK_SIZE = 800
18
  TOP_K_RESULTS = 150
19
  SIMILARITY_THRESHOLD = 0.4
20
- PASSWORD_HASH = "abc12345" # Replace with hashed password in production
21
 
22
  BASE_SYSTEM_PROMPT = """
23
  Répondez en français selon ces règles :
@@ -83,30 +79,29 @@ def split_text_into_chunks(text: str) -> List[str]:
83
 
84
  return chunks
85
 
86
- def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> str:
87
  """Create a new FAISS database from uploaded file"""
88
  if password != PASSWORD_HASH:
89
- return "Incorrect password. Database creation failed."
90
 
91
  if not file_content.strip():
92
- return "Uploaded file is empty. Database creation failed."
93
 
94
  if not db_name.isalnum():
95
- return "Database name must be alphanumeric. Database creation failed."
96
 
97
  try:
98
- # Define file names for the FAISS database
99
  faiss_file = f"{db_name}-index.faiss"
100
  pkl_file = f"{db_name}-index.pkl"
101
-
102
  # Check if the database already exists
103
  if os.path.exists(faiss_file) or os.path.exists(pkl_file):
104
- return f"Database '{db_name}' already exists."
105
 
106
  # Initialize embeddings and split text
107
  chunks = split_text_into_chunks(file_content)
108
  if not chunks:
109
- return "No valid chunks generated. Database creation failed."
110
 
111
  logging.info(f"Creating {len(chunks)} chunks...")
112
  progress(0, desc="Starting embedding process...")
@@ -122,39 +117,34 @@ def create_new_database(file_content: str, db_name: str, password: str, progress
122
  text_embeddings=list(zip(chunks, embeddings_list)),
123
  embedding=embeddings
124
  )
125
-
126
- # Save the FAISS database to the root directory
127
  vector_store.save_local(".")
128
  logging.info(f"FAISS database saved to: {faiss_file} and {pkl_file}")
129
 
130
- # Rename the default FAISS files to match the desired naming convention
131
- os.rename("index.faiss", faiss_file)
132
- os.rename("index.pkl", pkl_file)
133
-
134
  # Verify files were created
135
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
136
- return f"Failed to save FAISS database files: {faiss_file} or {pkl_file}"
137
- logging.info(f"FAISS database files: {faiss_file}, {pkl_file}")
 
 
 
 
138
 
139
- return f"Database '{db_name}' created successfully."
140
  except Exception as e:
141
  logging.error(f"Database creation failed: {str(e)}")
142
- return f"Error creating database: {str(e)}"
143
 
144
- def generate_response(user_input: str, db_name: str) -> Optional[str]:
145
  """Generate response using Qwen2.5 MAX"""
146
  try:
147
  if not db_name:
148
  return "Please select a database to chat with."
149
 
150
- # Define file names for the FAISS database
151
  faiss_file = f"{db_name}-index.faiss"
152
  pkl_file = f"{db_name}-index.pkl"
153
 
154
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
155
  return f"Database '{db_name}' does not exist."
156
 
157
- # Load the FAISS database
158
  vector_store = FAISS.load_local(".", embeddings, allow_dangerous_deserialization=True)
159
 
160
  # Contextual search
@@ -200,7 +190,7 @@ def generate_response(user_input: str, db_name: str) -> Optional[str]:
200
 
201
  except Exception as e:
202
  logging.error(f"Generation error: {str(e)}", exc_info=True)
203
- return None
204
 
205
  # Initialize models and vector store
206
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -216,11 +206,7 @@ with gr.Blocks() as app:
216
 
217
  def update_db_list():
218
  """Update the list of available databases"""
219
- return [
220
- name.replace("-index.faiss", "") # Remove "-index.faiss" suffix for display
221
- for name in os.listdir(".")
222
- if name.endswith("-index.faiss")
223
- ]
224
 
225
  with gr.Tab("Create Database"):
226
  gr.Markdown("## Create a New FAISS Database")
@@ -232,7 +218,7 @@ with gr.Blocks() as app:
232
 
233
  def handle_create(file, db_name, password, progress=gr.Progress()):
234
  if not file or not db_name or not password:
235
- return "Please provide all required inputs."
236
 
237
  # Check if the file is valid
238
  if isinstance(file, str): # Gradio provides the file path as a string
@@ -240,15 +226,12 @@ with gr.Blocks() as app:
240
  with open(file, "r", encoding="utf-8") as f:
241
  file_content = f.read()
242
  except Exception as e:
243
- return f"Error reading file: {str(e)}"
244
  else:
245
- return "Invalid file format. Please upload a .txt file."
246
 
247
- result = create_new_database(file_content, db_name, password, progress)
248
- if "created successfully" in result:
249
- # Update the database list
250
- return result, update_db_list()
251
- return result, None
252
 
253
  create_button.click(
254
  handle_create,
@@ -267,8 +250,8 @@ with gr.Blocks() as app:
267
  if not db_name:
268
  return "", history + [("System", "Please select a database to chat with.")]
269
  response = generate_response(message, db_name)
270
- return "", history + [(message, response or "Erreur de génération - Veuillez réessayer.")]
271
-
272
  msg.submit(
273
  chat_response,
274
  inputs=[msg, db_select, chatbot],
@@ -287,10 +270,4 @@ with gr.Blocks() as app:
287
  )
288
 
289
  if __name__ == "__main__":
290
- # Log existing databases at startup
291
- logging.info("Existing databases:")
292
- for name in os.listdir("."):
293
- if name.endswith("-index.faiss"):
294
- logging.info(f"- {name}")
295
-
296
  app.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import logging
3
+ from typing import List, Tuple
 
4
  import torch
5
  import gradio as gr
 
6
  from sentence_transformers import SentenceTransformer
7
  from langchain_community.vectorstores import FAISS
8
  from langchain.embeddings.base import Embeddings
 
 
9
  from tqdm import tqdm
10
 
11
  # Configuration
 
13
  CHUNK_SIZE = 800
14
  TOP_K_RESULTS = 150
15
  SIMILARITY_THRESHOLD = 0.4
16
+ PASSWORD_HASH = os.getenv("PASSWORD_HASH", "default_password") # Use environment variable for password
17
 
18
  BASE_SYSTEM_PROMPT = """
19
  Répondez en français selon ces règles :
 
79
 
80
  return chunks
81
 
82
+ def create_new_database(file_content: str, db_name: str, password: str, progress=gr.Progress()) -> Tuple[str, List[str]]:
83
  """Create a new FAISS database from uploaded file"""
84
  if password != PASSWORD_HASH:
85
+ return "Incorrect password. Database creation failed.", []
86
 
87
  if not file_content.strip():
88
+ return "Uploaded file is empty. Database creation failed.", []
89
 
90
  if not db_name.isalnum():
91
+ return "Database name must be alphanumeric. Database creation failed.", []
92
 
93
  try:
 
94
  faiss_file = f"{db_name}-index.faiss"
95
  pkl_file = f"{db_name}-index.pkl"
96
+
97
  # Check if the database already exists
98
  if os.path.exists(faiss_file) or os.path.exists(pkl_file):
99
+ return f"Database '{db_name}' already exists.", []
100
 
101
  # Initialize embeddings and split text
102
  chunks = split_text_into_chunks(file_content)
103
  if not chunks:
104
+ return "No valid chunks generated. Database creation failed.", []
105
 
106
  logging.info(f"Creating {len(chunks)} chunks...")
107
  progress(0, desc="Starting embedding process...")
 
117
  text_embeddings=list(zip(chunks, embeddings_list)),
118
  embedding=embeddings
119
  )
 
 
120
  vector_store.save_local(".")
121
  logging.info(f"FAISS database saved to: {faiss_file} and {pkl_file}")
122
 
 
 
 
 
123
  # Verify files were created
124
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
125
+ return f"Failed to save FAISS database files.", []
126
+ logging.info(f"FAISS database files created: {faiss_file}, {pkl_file}")
127
+
128
+ # Update the list of available databases
129
+ db_list = [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
130
+ return f"Database '{db_name}' created successfully.", db_list
131
 
 
132
  except Exception as e:
133
  logging.error(f"Database creation failed: {str(e)}")
134
+ return f"Error creating database: {str(e)}", []
135
 
136
+ def generate_response(user_input: str, db_name: str) -> str:
137
  """Generate response using Qwen2.5 MAX"""
138
  try:
139
  if not db_name:
140
  return "Please select a database to chat with."
141
 
 
142
  faiss_file = f"{db_name}-index.faiss"
143
  pkl_file = f"{db_name}-index.pkl"
144
 
145
  if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
146
  return f"Database '{db_name}' does not exist."
147
 
 
148
  vector_store = FAISS.load_local(".", embeddings, allow_dangerous_deserialization=True)
149
 
150
  # Contextual search
 
190
 
191
  except Exception as e:
192
  logging.error(f"Generation error: {str(e)}", exc_info=True)
193
+ return "Erreur de génération - Veuillez réessayer."
194
 
195
  # Initialize models and vector store
196
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
206
 
207
  def update_db_list():
208
  """Update the list of available databases"""
209
+ return [os.path.splitext(f)[0].replace("-index", "") for f in os.listdir(".") if f.endswith(".faiss")]
 
 
 
 
210
 
211
  with gr.Tab("Create Database"):
212
  gr.Markdown("## Create a New FAISS Database")
 
218
 
219
  def handle_create(file, db_name, password, progress=gr.Progress()):
220
  if not file or not db_name or not password:
221
+ return "Please provide all required inputs.", []
222
 
223
  # Check if the file is valid
224
  if isinstance(file, str): # Gradio provides the file path as a string
 
226
  with open(file, "r", encoding="utf-8") as f:
227
  file_content = f.read()
228
  except Exception as e:
229
+ return f"Error reading file: {str(e)}", []
230
  else:
231
+ return "Invalid file format. Please upload a .txt file.", []
232
 
233
+ result, db_list = create_new_database(file_content, db_name, password, progress)
234
+ return result, db_list
 
 
 
235
 
236
  create_button.click(
237
  handle_create,
 
250
  if not db_name:
251
  return "", history + [("System", "Please select a database to chat with.")]
252
  response = generate_response(message, db_name)
253
+ return "", history + [(message, response)]
254
+
255
  msg.submit(
256
  chat_response,
257
  inputs=[msg, db_select, chatbot],
 
270
  )
271
 
272
  if __name__ == "__main__":
 
 
 
 
 
 
273
  app.launch(server_name="0.0.0.0", server_port=7860)