import os
import shutil
import time

import gradio as gr
import torch
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline

# Set the Hugging Face cache directory (a writable location on Spaces)
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Check for GPU availability
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global variables
conversation_retrieval_chain = None
chat_history = []
llm_pipeline = None
embeddings = None
persist_directory = "/tmp/chroma_db"  # storage for the vector DB

def init_llm():
    """Initialize the LLM pipeline and the embedding model."""
    global llm_pipeline, embeddings

    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set in environment variables.")

    model_id = "tiiuae/falcon-7b-instruct"
    hf_pipeline = pipeline(
        "text-generation",
        model=model_id,
        device=DEVICE,
        token=hf_token,  # the token was fetched but never used in the original
        max_new_tokens=256,  # cap generation length; an assumption, not in the original
    )
    llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE},
    )
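
# Optional sanity check (a sketch, not part of the app flow):
#   init_llm()
#   print(llm_pipeline.invoke("Hello"))  # HuggingFacePipeline is a Runnable, so .invoke() works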

def process_document(file):
    """Load an uploaded PDF, index it, and build the retrieval chain."""
    global conversation_retrieval_chain

    if not llm_pipeline or not embeddings:
        init_llm()

    start_time = time.time()
    print(f"📄 Uploading PDF: {file.name}")

    try:
        # gr.File yields a temp file whose .name is its full path, so copy it
        # into /tmp/uploads (which must exist) instead of calling file.read()
        os.makedirs("/tmp/uploads", exist_ok=True)
        file_path = os.path.join("/tmp/uploads", os.path.basename(file.name))
        shutil.copy(file.name, file_path)
        print(f"✅ PDF saved at {file_path} in {time.time() - start_time:.2f}s")

        # Load the PDF
        start_time = time.time()
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        print(f"✅ PDF loaded in {time.time() - start_time:.2f}s")

        # Split the text into overlapping chunks
        start_time = time.time()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
        texts = text_splitter.split_documents(documents)
        print(f"✅ Text split in {time.time() - start_time:.2f}s")

        # Create the Chroma vector store
        start_time = time.time()
        db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)
        print(f"✅ ChromaDB created in {time.time() - start_time:.2f}s")
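
        # Note (not original behavior): each upload rebuilds the store; a saved
        # index could instead be reloaded with
        #   Chroma(persist_directory=persist_directory, embedding_function=embeddings)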

        # Create the retrieval chain
        conversation_retrieval_chain = ConversationalRetrievalChain.from_llm(
            llm=llm_pipeline, retriever=db.as_retriever()
        )
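        # as_retriever() defaults to returning 4 chunks per query; passing
        # search_kwargs={"k": ...} could tune that (an illustration, not in the original)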

        print("✅ Document processing complete!")
        return "🎉 PDF uploaded and processed successfully! You can now ask questions."

    except Exception as e:
        print(f"❌ Error processing PDF: {e}")
        return f"Error: {e}"

def process_prompt(prompt, chat_history_display):
    """Generate a response using the retrieval chain."""
    global conversation_retrieval_chain, chat_history

    if not conversation_retrieval_chain:
        return chat_history_display + [(prompt, "⚠️ No document uploaded. Please upload a PDF first.")]

    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["answer"]

    chat_history.append((prompt, answer))
    return chat_history
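
# chat_history holds (question, answer) tuples, which is the format
# ConversationalRetrievalChain expects for its chat_history input.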

# Define the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Personal Data Assistant</h1>")

    with gr.Row():
        dark_mode = gr.Checkbox(label="🌙 Toggle light/dark mode")

    with gr.Column():  # gr.Box() was removed in newer Gradio; gr.Column() replaces it
        gr.Markdown("Hello there! I'm your friendly data assistant, ready to answer questions about your data. Please upload a PDF file for me to analyze.")
        file_input = gr.File(label="Upload File")
        upload_button = gr.Button("📤 Upload File")
        status_output = gr.Textbox(label="Status", interactive=False)

    chat_history_display = gr.Chatbot(label="Chat History")

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message here...", scale=4)
        submit_button = gr.Button("📩", scale=1)
        clear_button = gr.Button("🗑️", scale=1)

    # Button click actions
    upload_button.click(process_document, inputs=file_input, outputs=status_output)
    submit_button.click(process_prompt, inputs=[user_input, chat_history_display], outputs=chat_history_display)
    # Clear both the on-screen chat and the chain's memory (list.clear() mutates the global in place)
    clear_button.click(lambda: chat_history.clear() or [], outputs=chat_history_display)
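    # A possible addition (not in the original): let Enter submit as well:
    #   user_input.submit(process_prompt, inputs=[user_input, chat_history_display],
    #                     outputs=chat_history_display)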

# Launch the Gradio app (share=True only matters for local runs; Spaces ignores it)
if __name__ == "__main__":
    demo.launch(share=True)