Update app.py
app.py
CHANGED
@@ -1,48 +1,16 @@
 import os
 import json
-import
-import
-import
-from duckduckgo_search import DDGS
+import logging
+import shutil
+from tempfile import NamedTemporaryFile
 from typing import List
 from pydantic import BaseModel, Field
-from tempfile import NamedTemporaryFile
-from langchain_community.vectorstores import FAISS
-from langchain_core.vectorstores import VectorStore
 from langchain_core.documents import Document
 from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from llama_parse import LlamaParse
-
-from huggingface_hub import InferenceClient
-import inspect
-import logging
-import shutil
-
-
-# Set up basic configuration for logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# Environment variables and configurations
-huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
-llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
-ACCOUNT_ID = os.environ.get("CLOUDFARE_ACCOUNT_ID")
-API_TOKEN = os.environ.get("CLOUDFLARE_AUTH_TOKEN")
-API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/a17f03e0f049ccae0c15cdcf3b9737ce/ai/run/"
-
-print(f"ACCOUNT_ID: {ACCOUNT_ID}")
-print(f"CLOUDFLARE_AUTH_TOKEN: {API_TOKEN[:5]}..." if API_TOKEN else "Not set")
-
-MODELS = [
-    "mistralai/Mistral-7B-Instruct-v0.3",
-    "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "@cf/meta/llama-3.1-8b-instruct",
-    "mistralai/Mistral-Nemo-Instruct-2407",
-    "gpt-4o-mini",
-    "claude-3-haiku",
-    "llama-3.1-70b",
-    "mixtral-8x7b"
-]
+import gradio as gr
 
 # Initialize LlamaParse
 llama_parser = LlamaParse(
@@ -73,7 +41,6 @@ def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[
 def get_embeddings():
     return HuggingFaceEmbeddings(model_name="avsolatorio/GIST-Embedding-v0")
 
-# Add this at the beginning of your script, after imports
 DOCUMENTS_FILE = "uploaded_documents.json"
 
 def load_documents():
@@ -89,7 +56,6 @@ def save_documents(documents):
 # Replace the global uploaded_documents with this
 uploaded_documents = load_documents()
 
-# Modify the update_vectors function
 def update_vectors(files, parser):
     global uploaded_documents
     logging.info(f"Entering update_vectors with {len(files)} files and parser: {parser}")
@@ -185,316 +151,6 @@ def delete_documents(selected_docs):
 
     return f"Deleted documents: {', '.join(deleted_docs)}", display_documents()
 
-def generate_chunked_response(prompt, model, max_tokens=10000, num_calls=3, temperature=0.2, should_stop=False):
-    print(f"Starting generate_chunked_response with {num_calls} calls")
-    full_response = ""
-    messages = [{"role": "user", "content": prompt}]
-
-    if model == "@cf/meta/llama-3.1-8b-instruct":
-        # Cloudflare API
-        for i in range(num_calls):
-            print(f"Starting Cloudflare API call {i+1}")
-            if should_stop:
-                print("Stop clicked, breaking loop")
-                break
-            try:
-                response = requests.post(
-                    f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/@cf/meta/llama-3.1-8b-instruct",
-                    headers={"Authorization": f"Bearer {API_TOKEN}"},
-                    json={
-                        "stream": True,
-                        "messages": [
-                            {"role": "system", "content": "You are a friendly assistant"},
-                            {"role": "user", "content": prompt}
-                        ],
-                        "max_tokens": max_tokens,
-                        "temperature": temperature
-                    },
-                    stream=True
-                )
-
-                for line in response.iter_lines():
-                    if should_stop:
-                        print("Stop clicked during streaming, breaking")
-                        break
-                    if line:
-                        try:
-                            json_data = json.loads(line.decode('utf-8').split('data: ')[1])
-                            chunk = json_data['response']
-                            full_response += chunk
-                        except json.JSONDecodeError:
-                            continue
-                print(f"Cloudflare API call {i+1} completed")
-            except Exception as e:
-                print(f"Error in generating response from Cloudflare: {str(e)}")
-    else:
-        # Original Hugging Face API logic
-        client = InferenceClient(model, token=huggingface_token)
-
-        for i in range(num_calls):
-            print(f"Starting Hugging Face API call {i+1}")
-            if should_stop:
-                print("Stop clicked, breaking loop")
-                break
-            try:
-                for message in client.chat_completion(
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    stream=True,
-                ):
-                    if should_stop:
-                        print("Stop clicked during streaming, breaking")
-                        break
-                    if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                        chunk = message.choices[0].delta.content
-                        full_response += chunk
-                print(f"Hugging Face API call {i+1} completed")
-            except Exception as e:
-                print(f"Error in generating response from Hugging Face: {str(e)}")
-
-    # Clean up the response
-    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
-    clean_response = clean_response.replace("Using the following context:", "").strip()
-    clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
-
-    # Remove duplicate paragraphs and sentences
-    paragraphs = clean_response.split('\n\n')
-    unique_paragraphs = []
-    for paragraph in paragraphs:
-        if paragraph not in unique_paragraphs:
-            sentences = paragraph.split('. ')
-            unique_sentences = []
-            for sentence in sentences:
-                if sentence not in unique_sentences:
-                    unique_sentences.append(sentence)
-            unique_paragraphs.append('. '.join(unique_sentences))
-
-    final_response = '\n\n'.join(unique_paragraphs)
-
-    print(f"Final clean response: {final_response[:100]}...")
-    return final_response
-
-def duckduckgo_search(query):
-    with DDGS() as ddgs:
-        results = ddgs.text(query, max_results=5)
-    return results
-
-class CitingSources(BaseModel):
-    sources: List[str] = Field(
-        ...,
-        description="List of sources to cite. Should be an URL of the source."
-    )
-def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
-    if not message.strip():
-        return "", history
-
-    history = history + [(message, "")]
-
-    try:
-        for response in respond(message, history, model, temperature, num_calls, use_web_search):
-            history[-1] = (message, response)
-            yield history
-    except gr.CancelledError:
-        yield history
-    except Exception as e:
-        logging.error(f"Unexpected error in chatbot_interface: {str(e)}")
-        history[-1] = (message, f"An unexpected error occurred: {str(e)}")
-        yield history
-
-def retry_last_response(history, use_web_search, model, temperature, num_calls):
-    if not history:
-        return history
-
-    last_user_msg = history[-1][0]
-    history = history[:-1]  # Remove the last response
-
-    return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
-
-def respond(message, history, model, temperature, num_calls, use_web_search, selected_docs):
-    logging.info(f"User Query: {message}")
-    logging.info(f"Model Used: {model}")
-    logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
-    logging.info(f"Selected Documents: {selected_docs}")
-
-    try:
-        if use_web_search:
-            for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
-                response = f"{main_content}\n\n{sources}"
-                first_line = response.split('\n')[0] if response else ''
-                logging.info(f"Generated Response (first line): {first_line}")
-                yield response
-        else:
-            embed = get_embeddings()
-            if os.path.exists("faiss_database"):
-                database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-                retriever = database.as_retriever(search_kwargs={"k": 20})
-
-                # Filter relevant documents based on user selection
-                all_relevant_docs = retriever.get_relevant_documents(message)
-                relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
-
-                if not relevant_docs:
-                    logging.info("No relevant information found in the selected documents.")
-                    yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
-                    return
-
-                context_str = "\n".join([doc.page_content for doc in relevant_docs])
-            else:
-                context_str = "No documents available."
-                logging.info("No documents available.")
-                yield "No documents available. Please upload PDF documents to answer questions."
-                return
-
-            if model in ["gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"]:
-                # Use DuckDuckGo Chat API
-                logging.info(f"Calling DuckDuckGo Chat API with model: {model}")
-                response = chat(message, model=model, timeout=30)
-                yield response
-            else:
-                # Use Hugging Face API or Cloudflare API
-                for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
-                    first_line = partial_response.split('\n')[0] if partial_response else ''
-                    logging.info(f"Generated Response (first line): {first_line}")
-                    yield partial_response
-    except Exception as e:
-        logging.error(f"Error with {model}: {str(e)}")
-        if "microsoft/Phi-3-mini-4k-instruct" in model:
-            logging.info("Falling back to Mistral model due to Phi-3 error")
-            fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
-            yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search, selected_docs)
-        else:
-            yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
-
-from duckduckgo_search import DDGS
-
-def chat(keywords: str, model: str, timeout: int = 30) -> str:
-    """Initiates a chat session with DuckDuckGo AI.
-
-    Args:
-        keywords (str): The initial message or question to send to the AI.
-        model (str): The model to use: "gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b".
-        timeout (int): Timeout value for the HTTP client. Defaults to 30.
-
-    Returns:
-        str: The response from the AI.
-    """
-    logging.info(f"Calling DuckDuckGo Chat API with model: {model}")
-
-    try:
-        with DDGS() as ddgs:
-            return ddgs.chat(keywords, model=model, timeout=timeout)
-    except Exception as e:
-        logging.error(f"Error in DuckDuckGo chat: {str(e)}")
-        return "Error in DuckDuckGo chat. Please try again later."
-
-def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
-    headers = {
-        "Authorization": f"Bearer {API_TOKEN}",
-        "Content-Type": "application/json"
-    }
-    model = "@cf/meta/llama-3.1-8b-instruct"
-
-    if search_type == "pdf":
-        instruction = f"""Using the following context from the PDF documents:
-{context}
-Write a detailed and complete response that answers the following user question: '{query}'"""
-    else: # web search
-        instruction = f"""Using the following context:
-{context}
-Write a detailed and complete research document that fulfills the following user request: '{query}'
-After writing the document, please provide a list of sources used in your response."""
-
-    inputs = [
-        {"role": "system", "content": instruction},
-        {"role": "user", "content": query}
-    ]
-
-    payload = {
-        "messages": inputs,
-        "stream": True,
-        "temperature": temperature,
-        "max_tokens": 32000
-    }
-
-    full_response = ""
-    for i in range(num_calls):
-        try:
-            with requests.post(f"{API_BASE_URL}{model}", headers=headers, json=payload, stream=True) as response:
-                if response.status_code == 200:
-                    for line in response.iter_lines():
-                        if line:
-                            try:
-                                json_response = json.loads(line.decode('utf-8').split('data: ')[1])
-                                if 'response' in json_response:
-                                    chunk = json_response['response']
-                                    full_response += chunk
-                                    yield full_response
-                            except (json.JSONDecodeError, IndexError) as e:
-                                logging.error(f"Error parsing streaming response: {str(e)}")
-                                continue
-                else:
-                    logging.error(f"HTTP Error: {response.status_code}, Response: {response.text}")
-                    yield f"I apologize, but I encountered an HTTP error: {response.status_code}. Please try again later."
-        except Exception as e:
-            logging.error(f"Error in generating response from Cloudflare: {str(e)}")
-            yield f"I apologize, but an error occurred: {str(e)}. Please try again later."
-
-    if not full_response:
-        yield "I apologize, but I couldn't generate a response at this time. Please try again later."
-
-def create_web_search_vectors(search_results):
-    embed = get_embeddings()
-
-    documents = []
-    for result in search_results:
-        if 'body' in result:
-            content = f"{result['title']}\n{result['body']}\nSource: {result['href']}"
-            documents.append(Document(page_content=content, metadata={"source": result['href']}))
-
-    return FAISS.from_documents(documents, embed)
-
-def get_response_with_search(query, model, num_calls=3, temperature=0.2):
-    search_results = duckduckgo_search(query)
-    web_search_database = create_web_search_vectors(search_results)
-
-    if not web_search_database:
-        yield "No web search results available. Please try again.", ""
-        return
-
-    retriever = web_search_database.as_retriever(search_kwargs={"k": 5})
-    relevant_docs = retriever.get_relevant_documents(query)
-
-    context = "\n".join([doc.page_content for doc in relevant_docs])
-
-    prompt = f"""Using the following context from web search results:
-{context}
-You are an expert AI assistant, write a detailed and complete research document that fulfills the following user request: '{query}'
-Base your entire response strictly on the information retrieved from trusted sources. Importantly, only include information that is directly supported by the retrieved content.
-If any part of the information cannot be verified from the given sources, clearly state that it could not be confirmed.
-After writing the document, please provide a list of sources used in your response."""
-
-    if model == "@cf/meta/llama-3.1-8b-instruct":
-        # Use Cloudflare API
-        for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
-            yield response, ""  # Yield streaming response without sources
-    else:
-        # Use Hugging Face API
-        client = InferenceClient(model, token=huggingface_token)
-
-        main_content = ""
-        for i in range(num_calls):
-            for message in client.chat_completion(
-                messages=[{"role": "user", "content": prompt}],
-                max_tokens=10000,
-                temperature=temperature,
-                stream=True,
-            ):
-                if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                    chunk = message.choices[0].delta.content
-                    main_content += chunk
-                    yield main_content, ""  # Yield partial main content without sources
-
 def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
 
@@ -565,25 +221,36 @@ Write a detailed and complete response that answers the following user question:
 
     logging.info("Finished generating response")
 
-def
-if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def chatbot_interface(message, history, use_web_search, model, temperature, num_calls, selected_docs):
+    if not message.strip():
+        return "", history
+
+    history = history + [(message, "")]
+
+    try:
+        if use_web_search:
+            for response in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                history[-1] = (message, response)
+                yield history
+        else:
+            for response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
+                history[-1] = (message, response)
+                yield history
+    except gr.CancelledError:
+        yield history
+    except Exception as e:
+        logging.error(f"Unexpected error in chatbot_interface: {str(e)}")
+        history[-1] = (message, f"An unexpected error occurred: {str(e)}")
+        yield history
+
+def retry_last_response(history, use_web_search, model, temperature, num_calls, selected_docs):
+    if not history:
+        return history
+
+    last_user_msg = history[-1][0]
+    history = history[:-1]  # Remove the last response
+
+    return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls, selected_docs)
 
 def display_documents():
     return gr.CheckboxGroup(
@@ -601,7 +268,7 @@ def initial_conversation():
         "4. For any queries feel free to reach out @[email protected] or discord - shreyas094\n\n"
         "To get started, upload some PDFs or ask me a question!")
     ]
-
+
 def refresh_documents():
     global uploaded_documents
     uploaded_documents = load_documents()
@@ -615,7 +282,7 @@ use_web_search = gr.Checkbox(label="Use Web Search", value=True)
 custom_placeholder = "Ask a question (Note: You can toggle between Web Search and PDF Chat in Additional Inputs below)"
 
 demo = gr.ChatInterface(
-
+    chatbot_interface,
     additional_inputs=[
         gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[3]),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
@@ -652,13 +319,13 @@ demo = gr.ChatInterface(
     cache_examples=False,
     analytics_enabled=False,
     textbox=gr.Textbox(placeholder=custom_placeholder, container=False, scale=7),
-    chatbot
-
-
-
-
-
-    )
+    chatbot=gr.Chatbot(
+        show_copy_button=True,
+        likeable=True,
+        layout="bubble",
+        height=400,
+        value=initial_conversation()
+    )
 )
 
 # Add file upload functionality
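
The diff cuts off at the "# Add file upload functionality" comment, so the wiring that follows is not shown. As a minimal sketch only, assuming the update_vectors(files, parser) function and imports shown above, and using hypothetical component names (file_input, parser_radio, update_button, update_output), the hook-up in Gradio typically looks like this:

with demo:
    # Hypothetical wiring -- the real code lies outside the shown diff.
    file_input = gr.File(label="Upload PDF files", file_count="multiple", file_types=[".pdf"])
    parser_radio = gr.Radio(choices=["llamaparse", "pypdf"], value="llamaparse", label="Parser")
    update_button = gr.Button("Upload and index")
    update_output = gr.Textbox(label="Update status", interactive=False)

    # Assumes update_vectors returns a status string; judging by
    # delete_documents above, it may also return a refreshed document list.
    update_button.click(update_vectors, inputs=[file_input, parser_radio], outputs=update_output)

Here demo is the gr.ChatInterface built above; since ChatInterface is a Blocks subclass, re-entering it as a context manager is the usual way to append extra components.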