Update app.py
app.py CHANGED
@@ -419,261 +419,118 @@
import os
-import logging
-import math
import streamlit as st
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

-#
-# Define global variables
-device = 'cpu'
persist_directory = "db"

-# Streamlit app configuration
-st.set_page_config(page_title="RAG-based Chatbot", layout="wide")
-st.title("RAG-based Chatbot")

-# Load the model
-checkpoint = "MBZUAI/LaMini-T5-738M"
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

-#

    try:
-        doc = fitz.open(
        text = ""
-        for
-        return text
    except Exception as e:
-        logging.error(f"
-        return

-def data_ingestion():
-    """Function to load PDFs and create embeddings with improved error handling and efficiency."""
-    try:
-        logging.info("Starting data ingestion")
-
-        if not os.path.exists(uploaded_files_dir):
-            os.makedirs(uploaded_files_dir)
-
-        documents = []
-        for filename in os.listdir(uploaded_files_dir):
-            if filename.endswith(".pdf"):
-                file_path = os.path.join(uploaded_files_dir, filename)
-                logging.info(f"Processing file: {file_path}")
-
-                try:
-                    loader = PDFMinerLoader(file_path)
-                    loaded_docs = loader.load()
-                    if not loaded_docs:
-                        logging.warning(f"Skipping file with missing or invalid metadata: {file_path}")
-                        continue
-
-                    for doc in loaded_docs:
-                        if hasattr(doc, 'page_content') and len(doc.page_content.strip()) > 0:
-                            documents.append(doc)
-                        else:
-                            logging.warning(f"Skipping invalid document structure in {file_path}")
-                except ValueError as e:
-                    logging.error(f"Skipping {file_path}: {str(e)}")
-                    continue
-
-        if not documents:
-            logging.error("No valid documents found to process.")
-            return
-
-        logging.info(f"Total valid documents: {len(documents)}")
-
-        # Proceed with splitting and embedding documents
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-        texts = text_splitter.split_documents(documents)
-
-        logging.info(f"Total text chunks created: {len(texts)}")
-
-        if not texts:
-            logging.error("No valid text chunks to create embeddings.")
-            return
-
-        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
-        # Proceed to split and embed the documents
-        MAX_BATCH_SIZE = 5461
-        total_batches = math.ceil(len(texts) / MAX_BATCH_SIZE)
-
-        logging.info(f"Processing {len(texts)} text chunks in {total_batches} batches...")

-            text_batch = texts[batch_start:batch_end]
-
-            logging.info(f"Processing batch {i + 1}/{total_batches}, size: {len(text_batch)}")

-            else:
-                db.add_documents(text_batch)
-
-        db.persist()
-        logging.info("Data ingestion completed successfully")
-
-    except Exception as e:
-        logging.error(f"Error during data ingestion: {str(e)}")
-        raise
-
-def llm_pipeline():
-    """Set up the language model pipeline."""
-    logging.info("Setting up LLM pipeline")
-    pipe = pipeline(
-        'text2text-generation',
-        model=base_model,
-        tokenizer=tokenizer,
-        max_length=256,
-        do_sample=True,
-        temperature=0.3,
-        top_p=0.95,
-        device=device
-    )
-    local_llm = HuggingFacePipeline(pipeline=pipe)
-    logging.info("LLM pipeline setup complete")
-    return local_llm
-
-def qa_llm():
-    """Set up the question-answering chain."""
-    logging.info("Setting up QA model")
-    llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-    db = Chroma(persist_directory=persist_directory
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    logging.info("QA model setup complete")
-    return qa

-    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
-    retriever = db.as_retriever()  # Set up the retriever to use Chroma database
-
-    # Here we're just adding the full_text as a document for simplicity
-    db.add_documents([full_text])
-
-    # Set up the language model pipeline (assuming you already have a pipeline set up)
-    llm = llm_pipeline()
-
-    # Construct the retrieval chain using the retriever and LLM
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-
-    # Create a tailored prompt for the question (providing context to the chatbot)
-    tailored_prompt = f"""
-    You are a helpful RAG-based chatbot designed to assist with answering questions from any uploaded document.
-    You should answer the question using relevant information from the provided PDF text.
-    Please provide a clear, informative answer based on the document content.
-    User question: {user_question}
-    """
-
-    # Generate the answer using the retrieval-augmented generation model
-    generated_text = qa_chain({"query": tailored_prompt})
-
-    # Extract the generated answer
-    answer = generated_text['result']

    except Exception as e:
-        logging.error(f"Error
-        return "Sorry, I

-st.

-with

-        # Extract and display the full text from the PDF
-        st.subheader("Full Text from the PDF:")
-        full_text = extract_text_from_pdf(file_path)
-        if full_text:
-            st.text_area("PDF Text", full_text, height=300)
-        else:
-            st.warning("Failed to extract text from this PDF.")
-
-        # # Generate summary option
-        # if st.button("Generate Summary of Document"):
-        #     st.write("Summary: [Provide the generated summary here]")
-
-        # Run data ingestion when files are uploaded
-        data_ingestion()
-
-        # Display UI for Q&A
-        st.header("Ask a Question")
-        user_question = st.text_input("Enter your question here:")
-
-        if user_question:
-            answer = process_answer(user_question)
-            st.write(answer)
-
else:
-    st.
import os
import streamlit as st
import fitz  # PyMuPDF
+import logging
+import math
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
+from langchain.schema import Document

+# --- Configuration ---
+st.set_page_config(page_title="RAG PDF Chatbot", layout="wide")
+st.title("RAG-based PDF Chatbot")
persist_directory = "db"
+device = "cpu"

+# --- Logging ---
+logging.basicConfig(level=logging.INFO)

+# --- Load LLM ---
+@st.cache_resource
+def load_model():
+    checkpoint = "MBZUAI/LaMini-T5-738M"
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+    pipe = pipeline('text2text-generation', model=model, tokenizer=tokenizer, max_length=512)
+    return HuggingFacePipeline(pipeline=pipe)
+
+# --- Extract PDF Text ---
+def read_pdf(file):
    try:
+        doc = fitz.open(stream=file.read(), filetype="pdf")
        text = ""
+        for page in doc:
+            text += page.get_text()
+        return text.strip()
    except Exception as e:
+        logging.error(f"Failed to extract text: {e}")
+        return ""

+# --- Split Text into Chunks ---
+def split_text_into_chunks(text):
+    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+    return splitter.create_documents([text])

+# --- Create Vector DB ---
+def create_vectorstore(documents):
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+    db = Chroma.from_documents(documents, embeddings, persist_directory=persist_directory)
+    db.persist()
+    return db

+# --- Setup QA Chain ---
+def setup_qa(db):
+    retriever = db.as_retriever()
+    llm = load_model()
+    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, return_source_documents=True)

+# --- Process Answer ---
+def process_answer(user_question, full_text):
+    if not full_text:
+        return "No content was extracted from the PDF. Please try another file."

+    docs = split_text_into_chunks(full_text)
+    db = create_vectorstore(docs)
+    qa = setup_qa(db)

+    prompt = f"""
+    You are a helpful AI assistant. Based on the provided context from a PDF document,
+    generate an accurate, informative answer to the following question:

+    {user_question}
+    """
+    try:
+        result = qa({"query": prompt})
+        return result['result']
    except Exception as e:
+        logging.error(f"Error generating answer: {e}")
+        return "Sorry, I couldn't generate an answer due to an internal error."
+
+# --- UI Layout ---
+with st.sidebar:
+    st.header("Upload PDF")
+    uploaded_file = st.file_uploader("Choose a PDF", type=["pdf"])
+
+# --- Main Interface ---
+if uploaded_file:
+    st.success(f"You uploaded: {uploaded_file.name}")
+    full_text = read_pdf(uploaded_file)
+
+    if full_text:
+        st.subheader("PDF Preview")
+        with st.expander("View Extracted Text"):
+            st.write(full_text[:3000] + ("..." if len(full_text) > 3000 else ""))
+
+        st.subheader("Ask a Question")
+        user_question = st.text_input("Type your question about the PDF content")
+
+        if user_question:
+            with st.spinner("Thinking..."):
+                answer = process_answer(user_question, full_text)
+                st.markdown("### Answer")
+                st.write(answer)
+
+        with st.sidebar:
+            st.markdown("---")
+            st.markdown("**Suggestions:**")
+            st.caption('Try: "Summarize this document" or "What is the key idea?"')
+
+    else:
+        st.error("No text could be extracted from the PDF. Try another file.")
else:
+    st.info("Upload a PDF to begin.")
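
Below is a minimal sketch (not part of this commit) of how the refactored helpers compose outside the Streamlit UI. It assumes the functions from this diff (read_pdf, split_text_into_chunks, create_vectorstore, setup_qa) have been factored into a hypothetical rag_utils.py module, that the LaMini-T5 and all-MiniLM-L6-v2 weights can be downloaded, and it accepts that load_model() is wrapped in st.cache_resource, so calling it outside Streamlit may emit a caching warning.

# Hypothetical smoke test; rag_utils is an assumed module holding the helpers above.
from rag_utils import read_pdf, split_text_into_chunks, create_vectorstore, setup_qa

with open("sample.pdf", "rb") as f:        # any local PDF for testing
    full_text = read_pdf(f)                # PyMuPDF extraction; returns "" on failure

docs = split_text_into_chunks(full_text)   # 500-character chunks with 50 overlap
db = create_vectorstore(docs)              # Chroma store persisted under ./db
qa = setup_qa(db)                          # RetrievalQA over the LaMini-T5 pipeline

result = qa({"query": "What is this document about?"})
print(result["result"])                    # answer text; retrieved chunks are in result["source_documents"]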