pradeepsengarr committed
Commit bbd8a88 · verified · 1 parent: a908e1c

Update app.py

Files changed (1):
  1. app.py +91 -234
app.py CHANGED
@@ -419,261 +419,118 @@
 
 
import os
- import logging
- import math
import streamlit as st
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
- from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

- # Set up logging
- logging.basicConfig(level=logging.INFO)
-
- # Define global variables
- device = 'cpu'
persist_directory = "db"
- uploaded_files_dir = "uploaded_files"
-
- # Streamlit app configuration
- st.set_page_config(page_title="RAG-based Chatbot", layout="wide")
- st.title("RAG-based Chatbot")
-
- # Load the model
- checkpoint = "MBZUAI/LaMini-T5-738M"
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

- # Helper Functions

- def extract_text_from_pdf(file_path):
-     """Extract full text from a PDF using PyMuPDF (fitz)."""
    try:
-         doc = fitz.open(file_path)
        text = ""
-         for page_num in range(doc.page_count):
-             page = doc.load_page(page_num)
-             text += page.get_text("text")
-         return text
    except Exception as e:
-         logging.error(f"Error reading PDF {file_path}: {e}")
-         return None
-
- def data_ingestion():
-     """Function to load PDFs and create embeddings with improved error handling and efficiency."""
-     try:
-         logging.info("Starting data ingestion")
-
-         if not os.path.exists(uploaded_files_dir):
-             os.makedirs(uploaded_files_dir)
-
-         documents = []
-         for filename in os.listdir(uploaded_files_dir):
-             if filename.endswith(".pdf"):
-                 file_path = os.path.join(uploaded_files_dir, filename)
-                 logging.info(f"Processing file: {file_path}")
-
-                 try:
-                     loader = PDFMinerLoader(file_path)
-                     loaded_docs = loader.load()
-                     if not loaded_docs:
-                         logging.warning(f"Skipping file with missing or invalid metadata: {file_path}")
-                         continue
-
-                     for doc in loaded_docs:
-                         if hasattr(doc, 'page_content') and len(doc.page_content.strip()) > 0:
-                             documents.append(doc)
-                         else:
-                             logging.warning(f"Skipping invalid document structure in {file_path}")
-                 except ValueError as e:
-                     logging.error(f"Skipping {file_path}: {str(e)}")
-                     continue
-
-         if not documents:
-             logging.error("No valid documents found to process.")
-             return
-
-         logging.info(f"Total valid documents: {len(documents)}")
-
-         # Proceed with splitting and embedding documents
-         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-         texts = text_splitter.split_documents(documents)
-
-         logging.info(f"Total text chunks created: {len(texts)}")
-
-         if not texts:
-             logging.error("No valid text chunks to create embeddings.")
-             return
-
-         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
-         # Proceed to split and embed the documents
-         MAX_BATCH_SIZE = 5461
-         total_batches = math.ceil(len(texts) / MAX_BATCH_SIZE)
-
-         logging.info(f"Processing {len(texts)} text chunks in {total_batches} batches...")

-         db = None
-         for i in range(total_batches):
-             batch_start = i * MAX_BATCH_SIZE
-             batch_end = min((i + 1) * MAX_BATCH_SIZE, len(texts))
-             text_batch = texts[batch_start:batch_end]
-
-             logging.info(f"Processing batch {i + 1}/{total_batches}, size: {len(text_batch)}")

-             if db is None:
-                 db = Chroma.from_documents(text_batch, embeddings, persist_directory=persist_directory)
-             else:
-                 db.add_documents(text_batch)
-
-         db.persist()
-         logging.info("Data ingestion completed successfully")
-
-     except Exception as e:
-         logging.error(f"Error during data ingestion: {str(e)}")
-         raise
-
- def llm_pipeline():
-     """Set up the language model pipeline."""
-     logging.info("Setting up LLM pipeline")
-     pipe = pipeline(
-         'text2text-generation',
-         model=base_model,
-         tokenizer=tokenizer,
-         max_length=256,
-         do_sample=True,
-         temperature=0.3,
-         top_p=0.95,
-         device=device
-     )
-     local_llm = HuggingFacePipeline(pipeline=pipe)
-     logging.info("LLM pipeline setup complete")
-     return local_llm
-
- def qa_llm():
-     """Set up the question-answering chain."""
-     logging.info("Setting up QA model")
-     llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-     db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
-     retriever = db.as_retriever()  # Set up the retriever for the vector store
-     qa = RetrievalQA.from_chain_type(
-         llm=llm,
-         chain_type="stuff",
-         retriever=retriever,
-         return_source_documents=True
-     )
-     logging.info("QA model setup complete")
-     return qa

- def process_answer(user_question, full_text):
-     """Generate an answer to the user’s question based on the extracted text from the PDF."""
-     try:
-         logging.info("Processing user question")

-         # Set up the retriever with the PDF content (this could be your embedded database or a direct retrieval from full_text)
-         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
-         # Use Chroma for document storage and retrieval if you’re storing documents in a vector store
-         db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
-         retriever = db.as_retriever()  # Set up the retriever to use Chroma database
-
-         # Here we're just adding the full_text as a document for simplicity
-         db.add_documents([full_text])
-
-         # Set up the language model pipeline (assuming you already have a pipeline set up)
-         llm = llm_pipeline()
-
-         # Construct the retrieval chain using the retriever and LLM
-         qa_chain = RetrievalQA.from_chain_type(
-             llm=llm,
-             chain_type="stuff",
-             retriever=retriever,
-             return_source_documents=True
-         )
-
-         # Create a tailored prompt for the question (providing context to the chatbot)
-         tailored_prompt = f"""
-         You are a helpful RAG-based chatbot designed to assist with answering questions from any uploaded document.
-         You should answer the question using relevant information from the provided PDF text.
-         Please provide a clear, informative answer based on the document content.
-         User question: {user_question}
-         """
-
-         # Generate the answer using the retrieval-augmented generation model
-         generated_text = qa_chain({"query": tailored_prompt})
-
-         # Extract the generated answer
-         answer = generated_text['result']

-         # If the answer is empty or not very informative, provide a fallback message
-         if "not provide" in answer or "no information" in answer:
-             return "The document does not provide sufficient information to answer your question."

-         logging.info("Answer generated successfully")
-         return answer

    except Exception as e:
-         logging.error(f"Error during answer generation: {str(e)}")
-         return "Sorry, I encountered an issue while processing your question."
-
-
- # Streamlit UI Setup
- st.sidebar.header("File Upload")
- uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type=["pdf"], accept_multiple_files=True)
-
- if uploaded_files:
-     # Save uploaded files and extract their text
-     if not os.path.exists(uploaded_files_dir):
-         os.makedirs(uploaded_files_dir)
-
-     for uploaded_file in uploaded_files:
-         file_path = os.path.join(uploaded_files_dir, uploaded_file.name)
-         with open(file_path, "wb") as f:
-             f.write(uploaded_file.getbuffer())
-
-     st.sidebar.success(f"Uploaded {len(uploaded_files)} file(s) successfully!")
-
-     # Show the uploaded files' names
-     st.subheader("Uploaded PDF(s):")
-     for uploaded_file in uploaded_files:
-         st.write(uploaded_file.name)
-         # Display PDF preview link if possible
-         with open(file_path, "rb") as f:
-             file_bytes = f.read()
-             st.download_button(
-                 label="Download PDF",
-                 data=file_bytes,
-                 file_name=uploaded_file.name,
-                 mime="application/pdf",
-             )
-
-     # Extract and display the full text from the PDF
-     st.subheader("Full Text from the PDF:")
-     full_text = extract_text_from_pdf(file_path)
-     if full_text:
-         st.text_area("PDF Text", full_text, height=300)
-     else:
-         st.warning("Failed to extract text from this PDF.")
-
-     # # Generate summary option
-     # if st.button("Generate Summary of Document"):
-     #     st.write("Summary: [Provide the generated summary here]")
-
-     # Run data ingestion when files are uploaded
-     data_ingestion()
-
-     # Display UI for Q&A
-     st.header("Ask a Question")
-     user_question = st.text_input("Enter your question here:")
-
-     if user_question:
-         answer = process_answer(user_question)
-         st.write(answer)
-
else:
-     st.sidebar.info("Upload PDF files to get started!")
 
 
 
import os
import streamlit as st
import fitz  # PyMuPDF
+ import logging
+ import math
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
+ from langchain.schema import Document

+ # --- Configuration ---
+ st.set_page_config(page_title="📚 RAG PDF Chatbot", layout="wide")
+ st.title("📚 RAG-based PDF Chatbot")

persist_directory = "db"
+ device = "cpu"

+ # --- Logging ---
+ logging.basicConfig(level=logging.INFO)

+ # --- Load LLM ---
+ @st.cache_resource
+ def load_model():
+     checkpoint = "MBZUAI/LaMini-T5-738M"
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+     model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+     pipe = pipeline('text2text-generation', model=model, tokenizer=tokenizer, max_length=512)
+     return HuggingFacePipeline(pipeline=pipe)
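+ # Note: st.cache_resource keeps the returned pipeline cached across Streamlit
+ # reruns, so the LaMini-T5 checkpoint is downloaded and loaded only once per process.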
+
+ # --- Extract PDF Text ---
+ def read_pdf(file):
    try:
+         doc = fitz.open(stream=file.read(), filetype="pdf")
        text = ""
+         for page in doc:
+             text += page.get_text()
+         return text.strip()
    except Exception as e:
+         logging.error(f"Failed to extract text: {e}")
+         return ""
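+ # Note: fitz.open(stream=file.read(), filetype="pdf") parses the uploaded bytes
+ # in memory, so the file no longer has to be saved to an uploads directory first.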

+ # --- Split Text into Chunks ---
+ def split_text_into_chunks(text):
+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+     return splitter.create_documents([text])
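+ # Note: 500-character chunks with 50 characters of overlap; the overlap keeps
+ # sentences that straddle a chunk boundary retrievable from either side.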

+ # --- Create Vector DB ---
+ def create_vectorstore(documents):
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+     db = Chroma.from_documents(documents, embeddings, persist_directory=persist_directory)
+     db.persist()
+     return db
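+ # Note: db.persist() writes the Chroma index into persist_directory ("db"),
+ # so the embeddings survive across process restarts.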

+ # --- Setup QA Chain ---
+ def setup_qa(db):
+     retriever = db.as_retriever()
+     llm = load_model()
+     return RetrievalQA.from_chain_type(llm=llm, retriever=retriever, return_source_documents=True)
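+ # Note: with no chain_type given, RetrievalQA.from_chain_type defaults to
+ # "stuff", i.e. all retrieved chunks are packed into a single LLM prompt.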

+ # --- Process Answer ---
+ def process_answer(user_question, full_text):
+     if not full_text:
+         return "No content was extracted from the PDF. Please try another file."
+
+     docs = split_text_into_chunks(full_text)
+     db = create_vectorstore(docs)
+     qa = setup_qa(db)
+
+     prompt = f"""
+     You are a helpful AI assistant. Based on the provided context from a PDF document,
+     generate an accurate, informative answer to the following question:
+
+     {user_question}
+     """
+     try:
+         result = qa({"query": prompt})
+         return result['result']
    except Exception as e:
+         logging.error(f"Error generating answer: {e}")
+         return "Sorry, I couldn't generate an answer due to an internal error."
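+ # Note: process_answer re-chunks, re-embeds and re-indexes the PDF on every
+ # question; caching the vectorstore per uploaded file would avoid that rework.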
+
+ # --- UI Layout ---
+ with st.sidebar:
+     st.header("📄 Upload PDF")
+     uploaded_file = st.file_uploader("Choose a PDF", type=["pdf"])
+
+ # --- Main Interface ---
+ if uploaded_file:
+     st.success(f"You uploaded: {uploaded_file.name}")
+     full_text = read_pdf(uploaded_file)
+
+     if full_text:
+         st.subheader("📑 PDF Preview")
+         with st.expander("View Extracted Text"):
+             st.write(full_text[:3000] + ("..." if len(full_text) > 3000 else ""))
+
+         st.subheader("💬 Ask a Question")
+         user_question = st.text_input("Type your question about the PDF content")
+
+         if user_question:
+             with st.spinner("Thinking..."):
+                 answer = process_answer(user_question, full_text)
+             st.markdown("### 🤖 Answer")
+             st.write(answer)
+
+         with st.sidebar:
+             st.markdown("---")
+             st.markdown("**💡 Suggestions:**")
+             st.caption('Try: "Summarize this document" or "What is the key idea?"')
+
+     else:
+         st.error("⚠️ No text could be extracted from the PDF. Try another file.")

else:
+     st.info("Upload a PDF to begin.")