Update app.py
app.py
CHANGED
@@ -24,20 +24,11 @@ logger = logging.getLogger(__name__)
 # Environment setup for Hugging Face token
 os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN", "default-token")
 if os.environ["HUGGINGFACEHUB_API_TOKEN"] == "default-token":
-    logger.warning("HUGGINGFACEHUB_API_TOKEN not set.
+    logger.warning("HUGGINGFACEHUB_API_TOKEN not set. Model may not work.")

-# Model and embedding
-LLM_MODELS = {
-    "High Accuracy (Mixtral-8x7B)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "Balanced (Gemma-2-2B)": "google/gemma-2-2b-it",
-    "Lightweight (Mistral-7B)": "mistralai/Mistral-7B-Instruct-v0.2"
-}
-
-EMBEDDING_MODELS = {
-    "Lightweight (MiniLM-L6)": "sentence-transformers/all-MiniLM-L6-v2",
-    "Balanced (MPNet-Base)": "sentence-transformers/all-mpnet-base-v2",
-    "High Accuracy (BGE-Large)": "BAAI/bge-large-en-v1.5"
-}
+# Model and embedding configuration
+LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"

 # Global state
 vector_store = None
@@ -93,7 +84,7 @@ def load_documents(files):
     return documents

 # Function to process documents and create vector store
-def process_documents(files, chunk_size, chunk_overlap, embedding_model):
+def process_documents(files, chunk_size, chunk_overlap):
     global vector_store
     if not files:
         return "Please upload at least one document.", None
@@ -128,14 +119,13 @@ def process_documents(files, chunk_size, chunk_overlap, embedding_model):

     # Create embeddings
     try:
-        embeddings = HuggingFaceEmbeddings(model_name=
+        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
     except Exception as e:
-        logger.error(f"Error initializing embeddings
+        logger.error(f"Error initializing embeddings: {str(e)}")
         return f"Error initializing embeddings: {str(e)}", None

     # Create vector store
     try:
-        # Use in-memory Chroma client to avoid filesystem issues
         collection_name = f"doctalk_collection_{int(time.time())}"
         client = chromadb.Client()
         vector_store = Chroma.from_documents(
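For reference, the pattern this hunk settles on (a pinned embedding model plus a purely in-memory Chroma client) can be exercised on its own. A minimal sketch follows, assuming current langchain-community / langchain-huggingface / langchain-text-splitters import paths; the text splitter and sample document are illustrative assumptions, not code taken from app.py.

# Sketch only: in-memory Chroma store with a pinned embedding model.
# Import paths are assumptions; adjust to whatever app.py actually imports.
import time

import chromadb
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"

docs = [Document(page_content="DocTalk answers questions about uploaded documents.")]
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
client = chromadb.Client()  # in-memory client, nothing persisted to disk
vector_store = Chroma.from_documents(
    documents=chunks,
    embedding=embeddings,
    client=client,
    collection_name=f"doctalk_collection_{int(time.time())}",
)
print(vector_store.similarity_search("What does DocTalk do?", k=1)[0].page_content)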
@@ -154,21 +144,20 @@ def process_documents(files, chunk_size, chunk_overlap, embedding_model):
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((requests.exceptions.HTTPError, requests.exceptions.ConnectionError))
 )
-def initialize_qa_chain(llm_model, temperature):
+def initialize_qa_chain(temperature):
     global qa_chain
     if not vector_store:
         return "Please process documents first.", None

     try:
         llm = HuggingFaceEndpoint(
-            repo_id=
+            repo_id=LLM_MODEL,
             task="text-generation",
             temperature=float(temperature),
             max_new_tokens=512,
             huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
             timeout=30
         )
-        # Dynamically set k based on vector store size
         collection = vector_store._collection
         doc_count = collection.count()
         k = min(3, doc_count) if doc_count > 0 else 1
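For reference, a minimal sketch of the endpoint-plus-retry pattern this hunk pins down: a fixed Mixtral repo id and a tenacity decorator that retries transient HTTP failures. The import path, the stop_after_attempt(3) limit, and the generate() wrapper are assumptions for illustration; only the endpoint parameters and the wait/retry arguments mirror the diff.

# Sketch only: pinned Inference API endpoint wrapped in tenacity retries.
import os

import requests
from langchain_huggingface import HuggingFaceEndpoint
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"

@retry(
    stop=stop_after_attempt(3),  # assumed limit; the stop= line is not shown in the diff
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((requests.exceptions.HTTPError, requests.exceptions.ConnectionError)),
)
def generate(prompt: str, temperature: float = 0.7) -> str:
    llm = HuggingFaceEndpoint(
        repo_id=LLM_MODEL,
        task="text-generation",
        temperature=float(temperature),
        max_new_tokens=512,
        huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
        timeout=30,
    )
    return llm.invoke(prompt)

if __name__ == "__main__":
    print(generate("Summarize retrieval-augmented generation in one sentence."))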
@@ -177,17 +166,17 @@ def initialize_qa_chain(llm_model, temperature):
             retriever=vector_store.as_retriever(search_kwargs={"k": k}),
             memory=memory
         )
-        logger.info(f"Initialized QA chain with {
+        logger.info(f"Initialized QA chain with {LLM_MODEL} and k={k}.")
         return "QA Doctor: QA chain initialized successfully.", None
     except requests.exceptions.HTTPError as e:
-        logger.error(f"HTTP error initializing QA chain
+        logger.error(f"HTTP error initializing QA chain: {str(e)}")
         if "503" in str(e):
-            return
+            return "Error: Hugging Face API temporarily unavailable. Please wait and retry.", None
         elif "403" in str(e):
-            return
+            return "Error: Access denied. Check your HF token or upgrade to Pro at https://huggingface.co/settings/billing.", None
         return f"Error initializing QA chain: {str(e)}.", None
     except Exception as e:
-        logger.error(f"Error initializing QA chain
+        logger.error(f"Error initializing QA chain: {str(e)}")
         return f"Error initializing QA chain: {str(e)}. Ensure your HF token is valid.", None

 # Function to handle user query with retry logic
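The except blocks above turn raw HTTP failures into user-facing strings by looking for the status code in the exception text. A small stand-alone sketch of that mapping; the helper name is hypothetical and the messages paraphrase (and shorten) the ones in the diff.

# Sketch only: map Hugging Face Inference API failures to user-facing strings.
import requests

def friendly_hf_error(exc: Exception) -> str:
    text = str(exc)
    if isinstance(exc, requests.exceptions.HTTPError):
        if "503" in text:
            return "Error: Hugging Face API temporarily unavailable. Please wait and retry."
        if "403" in text:
            return "Error: Access denied. Check your HF token or upgrade to Pro."
    return f"Error initializing QA chain: {text}."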
@@ -196,7 +185,7 @@ def initialize_qa_chain(llm_model, temperature):
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((requests.exceptions.HTTPError, requests.exceptions.ConnectionError))
 )
-def answer_question(question, llm_model, embedding_model, temperature, chunk_size, chunk_overlap):
+def answer_question(question, temperature, chunk_size, chunk_overlap):
     global chat_history
     if not vector_store:
         return "Please process documents first.", chat_history
@@ -214,9 +203,9 @@ def answer_question(question, llm_model, embedding_model, temperature, chunk_siz
     except requests.exceptions.HTTPError as e:
         logger.error(f"HTTP error answering question: {str(e)}")
         if "503" in str(e):
-            return
+            return "Error: Hugging Face API temporarily unavailable. Please wait and retry.", chat_history
         elif "403" in str(e):
-            return
+            return "Error: Access denied. Check your HF token or upgrade to Pro at https://huggingface.co/settings/billing.", chat_history
         return f"Error answering question: {str(e)}", chat_history
     except Exception as e:
         logger.error(f"Error answering question: {str(e)}")
@@ -261,7 +250,7 @@ def reset_app():
 # Gradio interface
 with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as demo:
     gr.Markdown("# DocTalk: Document Q&A Chatbot")
-    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX),
+    gr.Markdown("Upload documents (PDF, TXT, DOCX, PPTX), tune parameters, and ask questions! Uses Mixtral-8x7B and BGE-Large for high accuracy.")

     with gr.Row():
         with gr.Column(scale=2):
@@ -272,8 +261,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as
             status = gr.Textbox(label="Status", interactive=False)

         with gr.Column(scale=1):
-            llm_model = gr.Dropdown(choices=list(LLM_MODELS.keys()), label="Select LLM Model", value="High Accuracy (Mixtral-8x7B)")
-            embedding_model = gr.Dropdown(choices=list(EMBEDDING_MODELS.keys()), label="Select Embedding Model", value="Lightweight (MiniLM-L6)")
             temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
             chunk_size = gr.Slider(minimum=500, maximum=2000, step=100, value=1000, label="Chunk Size")
             chunk_overlap = gr.Slider(minimum=0, maximum=500, step=50, value=100, label="Chunk Overlap")
@@ -289,17 +276,17 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DocTalk: Document Q&A Chatbot") as
     # Event handlers
     process_button.click(
         fn=process_documents,
-        inputs=[file_upload, chunk_size, chunk_overlap
+        inputs=[file_upload, chunk_size, chunk_overlap],
         outputs=[status, chat_display]
     )
     init_button.click(
         fn=initialize_qa_chain,
-        inputs=[
+        inputs=[temperature],
         outputs=[status, chat_display]
     )
     question.submit(
         fn=answer_question,
-        inputs=[question,
+        inputs=[question, temperature, chunk_size, chunk_overlap],
         outputs=[answer, chat_display]
     )
     export_button.click(
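The event-handler hunk simply drops the model dropdowns from each inputs= list so that the argument order still matches the slimmed-down handler signatures. A stripped-down sketch of that wiring is below; the component choices and handler bodies are placeholders for illustration, not the real app.py logic.

# Sketch only: Gradio inputs must line up positionally with the handler parameters.
import gradio as gr

def process_documents(files, chunk_size, chunk_overlap):
    return f"Processed {len(files or [])} file(s) (chunk_size={chunk_size}, overlap={chunk_overlap}).", ""

def initialize_qa_chain(temperature):
    return f"QA chain ready (temperature={temperature}).", ""

def answer_question(question, temperature, chunk_size, chunk_overlap):
    return f"Echo: {question}", ""

with gr.Blocks() as demo:
    file_upload = gr.File(file_count="multiple", label="Upload Documents")
    temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
    chunk_size = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
    chunk_overlap = gr.Slider(0, 500, value=100, step=50, label="Chunk Overlap")
    question = gr.Textbox(label="Ask a Question")
    status = gr.Textbox(label="Status", interactive=False)
    answer = gr.Textbox(label="Answer", interactive=False)
    chat_display = gr.Textbox(label="Chat History", interactive=False)
    process_button = gr.Button("Process Documents")
    init_button = gr.Button("Initialize QA Chain")

    process_button.click(fn=process_documents, inputs=[file_upload, chunk_size, chunk_overlap], outputs=[status, chat_display])
    init_button.click(fn=initialize_qa_chain, inputs=[temperature], outputs=[status, chat_display])
    question.submit(fn=answer_question, inputs=[question, temperature, chunk_size, chunk_overlap], outputs=[answer, chat_display])

if __name__ == "__main__":
    demo.launch()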