logasanjeev committed on
Commit 88ec884 · verified · 1 parent: 8b40c0d

Update app.py

Files changed (1)
  1. app.py +21 -34
app.py CHANGED
@@ -24,20 +24,11 @@ logger = logging.getLogger(__name__)
 # Environment setup for Hugging Face token
 os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "default-token")
 if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token":
-    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Some models may not work.")
+    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Model may not work.")
 
-# Model and embedding options
-LLM_MODELS = {
-    "High Accuracy (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "Balanced (Gemma-2-2B)": "google/gemma-2-2b-it",
-    "Lightweight (Mistral-7B)": "mistralai/Mistral-7B-Instruct-v0.2"
-}
-
-EMBEDDING_MODELS = {
-    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
-    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
-    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
-}
+# Model and embedding configuration
+LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
 
 # Global state
 vector_store = None
@@ -93,7 +84,7 @@ def load_documents(files):
     return documents
 
 # Function to process documents and create vector store
-def process_documents(files, chunk_size, chunk_overlap, embedding_model):
+def process_documents(files, chunk_size, chunk_overlap):
     global vector_store
     if not files:
         return "Please upload at least one document.", None
@@ -128,14 +119,13 @@ def process_documents(files, chunk_size, chunk_overlap, embedding_model):
 
     # Create embeddings
     try:
-        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODELS[embedding_model])
+        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
     except Exception as e:
-        logger.error(f"Error initializing embeddings for {embedding_model}: {str(e)}")
+        logger.error(f"Error initializing embeddings: {str(e)}")
         return f"Error initializing embeddings: {str(e)}", None
 
     # Create vector store
     try:
-        # Use in-memory Chroma client to avoid filesystem issues
         collection_name = f"doctalk_collection_{int(time.time())}"
         client = chromadb.Client()
         vector_store = Chroma.from_documents(
@@ -154,21 +144,20 @@ def process_documents(files, chunk_size, chunk_overlap, embedding_model):
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((requests.exceptions.HTTPError, requests.exceptions.ConnectionError))
 )
-def initialize_qa_chain(llm_model, temperature):
+def initialize_qa_chain(temperature):
     global qa_chain
     if not vector_store:
         return "Please process documents first.", None
 
     try:
         llm = HuggingFaceEndpoint(
-            repo_id=LLM_MODELS[llm_model],
+            repo_id=LLM_MODEL,
             task="text-generation",
             temperature=float(temperature),
             max_new_tokens=512,
             huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
             timeout=30
         )
-        # Dynamically set k based on vector store size
        collection = vector_store._collection
        doc_count = collection.count()
        k = min(3, doc_count) if doc_count > 0 else 1
@@ -177,17 +166,17 @@ def initialize_qa_chain(llm_model, temperature):
             retriever=vector_store.as_retriever(search_kwargs={"k": k}),
             memory=memory
         )
-        logger.info(f"Initialized QA chain with {llm_model} and k={k}.")
+        logger.info(f"Initialized QA chain with {LLM_MODEL} and k={k}.")
         return "QA Doctor: QA chain initialized successfully.", None
     except requests.exceptions.HTTPError as e:
-        logger.error(f"HTTP error initializing QA chain for {llm_model}: {str(e)}")
+        logger.error(f"HTTP error initializing QA chain: {str(e)}")
         if "503" in str(e):
-            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'High Accuracy (Mixtral-8x7B)' or wait and retry.", None
+            return "Error: Hugging Face API temporarily unavailable. Please wait and retry.", None
         elif "403" in str(e):
-            return f"Error: Access denied for {llm_model}. Free-tier API limits models >10GB. Try 'High Accuracy (Mixtral-8x7B)' or upgrade to Pro at https://huggingface.co/settings/billing.", None
+            return "Error: Access denied. Check your HF token or upgrade to Pro at https://huggingface.co/settings/billing.", None
         return f"Error initializing QA chain: {str(e)}.", None
     except Exception as e:
-        logger.error(f"Error initializing QA chain for {llm_model}: {str(e)}")
+        logger.error(f"Error initializing QA chain: {str(e)}")
         return f"Error initializing QA chain: {str(e)}. Ensure your HF token is valid.", None
 
 # Function to handle user query with retry logic
@@ -196,7 +185,7 @@ def initialize_qa_chain(llm_model, temperature):
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((requests.exceptions.HTTPError, requests.exceptions.ConnectionError))
 )
-def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
+def answer_question(question, temperature, chunk_size, chunk_overlap):
     global chat_history
     if not vector_store:
         return "Please process documents first.", chat_history
@@ -214,9 +203,9 @@ def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
     except requests.exceptions.HTTPError as e:
         logger.error(f"HTTP error answering question: {str(e)}")
         if "503" in str(e):
-            return f"Error: Hugging Face API temporarily unavailable for {llm_model}. Try 'High Accuracy (Mixtral-8x7B)' or wait and retry.", chat_history
+            return "Error: Hugging Face API temporarily unavailable. Please wait and retry.", chat_history
         elif "403" in str(e):
-            return f"Error: Access denied for {llm_model}. Free-tier API limits models >10GB. Try 'High Accuracy (Mixtral-8x7B)' or upgrade to Pro at https://huggingface.co/settings/billing.", chat_history
+            return "Error: Access denied. Check your HF token or upgrade to Pro at https://huggingface.co/settings/billing.", chat_history
         return f"Error answering question: {str(e)}", chat_history
     except Exception as e:
         logger.error(f"Error answering question: {str(e)}")
@@ -261,7 +250,7 @@ def reset_app():
 # Gradio interface
 with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
     gr.Markdown("# DocTalk: Document Q&A Chatbot")
-    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), select models, tune parameters, and ask questions!")
+    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), tune parameters, and ask questions! Uses Mixtral-8x7B and BGE-Large for high accuracy.")
 
     with gr.Row():
         with gr.Column(scale=2):
@@ -272,8 +261,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
             status = gr.Textbox(label="Status", interactive=False)
 
         with gr.Column(scale=1):
-            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="High Accuracy (Mixtral-8x7B)")
-            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
             temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
             chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
             chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
@@ -289,17 +276,17 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
     # Event handlers
     process_button.click(
         fn=process_documents,
-        inputs=[file_upload, chunk_size, chunk_overlap, embedding_model],
+        inputs=[file_upload, chunk_size, chunk_overlap],
         outputs=[status, chat_display]
     )
    init_button.click(
         fn=initialize_qa_chain,
-        inputs=[llm_model, temperature],
+        inputs=[temperature],
         outputs=[status, chat_display]
     )
     question.submit(
         fn=answer_question,
-        inputs=[question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap],
+        inputs=[question, temperature, chunk_size, chunk_overlap],
         outputs=[answer, chat_display]
     )
     export_button.click(
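For reference, a minimal runnable sketch of the retrieval setup this commit settles on: one hardcoded LLM ID, one hardcoded embedding model, embeddings built once, and an ephemeral in-memory Chroma collection with a timestamped name. The import paths are assumptions (the diff does not show app.py's import block), and the placeholder document stands in for the real chunked uploads.

```python
# Sketch of the post-commit setup. Import paths assumed; the diff
# omits app.py's import block.
import time

import chromadb
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"

# Embeddings are created once from the fixed model ID.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# Ephemeral client: collections live only in memory, and the
# timestamped name avoids colliding with an earlier collection.
client = chromadb.Client()
vector_store = Chroma.from_documents(
    documents=[Document(page_content="placeholder chunk")],
    embedding=embeddings,
    client=client,
    collection_name=f"doctalk_collection_{int(time.time())}",
)

# k is clamped to the collection size so the retriever never asks
# for more neighbors than there are stored chunks.
doc_count = vector_store._collection.count()
k = min(3, doc_count) if doc_count > 0 else 1
retriever = vector_store.as_retriever(search_kwargs={"k": k})
```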
 
250
  # Gradio interface
251
  with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
252
  gr.Markdown("# DocTalk: Document Q&A Chatbot")
253
+ gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), tune parameters, and ask questions! Uses Mixtral-8x7B and BGE-Large for high accuracy.")
254
 
255
  with gr.Row():
256
  with gr.Column(scale=2):
 
261
  status = gr.Textbox(label="Status", interactive=False)
262
 
263
  with gr.Column(scale=1):
 
 
264
  temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
265
  chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
266
  chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
 
276
  # Event handlers
277
  process_button.click(
278
  fn=process_documents,
279
+ inputs=[file_upload, chunk_size, chunk_overlap],
280
  outputs=[status, chat_display]
281
  )
282
  init_button.click(
283
  fn=initialize_qa_chain,
284
+ inputs=[temperature],
285
  outputs=[status, chat_display]
286
  )
287
  question.submit(
288
  fn=answer_question,
289
+ inputs=[question, temperature, chunk_size, chunk_overlap],
290
  outputs=[answer, chat_display]
291
  )
292
  export_button.click(
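And the slimmed-down event wiring in isolation: with the model dropdowns removed, the handlers receive only the tuning sliders. Handler bodies are stubs, and the component constructors for the file upload, buttons, and chat display are assumptions (the diff shows only the sliders and status box); input/output lists mirror the diff exactly.

```python
# Stub handlers with the new, narrower signatures; the real logic
# lives in app.py.
import gradio as gr

def process_documents(files, chunk_size, chunk_overlap):
    return "Processed.", None

def initialize_qa_chain(temperature):
    return "QA chain initialized.", None

def answer_question(question, temperature, chunk_size, chunk_overlap):
    return f"(stub) You asked: {question}", []

with gr.Blocks() as demo:
    file_upload = gr.File(file_count="multiple", label="Upload Documents")
    temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
    chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
    chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
    process_button = gr.Button("Process Documents")
    init_button = gr.Button("Initialize QA Chain")
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")
    status = gr.Textbox(label="Status", interactive=False)
    chat_display = gr.Chatbot(label="Chat History")

    # Note: the inputs lists no longer carry llm_model/embedding_model.
    process_button.click(process_documents, [file_upload, chunk_size, chunk_overlap], [status, chat_display])
    init_button.click(initialize_qa_chain, [temperature], [status, chat_display])
    question.submit(answer_question, [question, temperature, chunk_size, chunk_overlap], [answer, chat_display])

if __name__ == "__main__":
    demo.launch()
```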