import os
import logging
import math
import streamlit as st
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.schema import Document
# Set up logging
logging.basicConfig(level=logging.INFO)

# Define global variables
device = 'cpu'
persist_directory = "db"
uploaded_files_dir = "uploaded_files"

# Streamlit app configuration
st.set_page_config(page_title="Audit Assistant", layout="wide")
st.title("Audit Assistant")

# Load the model
checkpoint = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
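# Note: inference runs on CPU here. As a hedged sketch (torch is already a
# transformers dependency), the device could instead be chosen dynamically:
#
#   import torch
#   device = 0 if torch.cuda.is_available() else 'cpu'
#
# transformers' pipeline() accepts either a CUDA device index or the string 'cpu'.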
# Helper Functions
def extract_text_from_pdf(file_path):
    """Extract text from a PDF using PyMuPDF (fitz). Returns None on failure."""
    try:
        doc = fitz.open(file_path)
        text = ""
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            text += page.get_text("text")
        return text
    except Exception as e:
        logging.error(f"Error reading PDF {file_path}: {e}")
        return None
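# Hypothetical usage of the helper above (the path is illustrative only; the
# ingestion flow below loads PDFs via PDFMinerLoader instead, so this function
# is a standalone utility):
#
#   raw_text = extract_text_from_pdf("uploaded_files/example.pdf")
#   if raw_text:
#       print(raw_text[:200])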
def data_ingestion():
    """Load PDFs from the upload folder and persist their embeddings to Chroma."""
    try:
        logging.info("Starting data ingestion")

        if not os.path.exists(uploaded_files_dir):
            os.makedirs(uploaded_files_dir)

        documents = []
        for filename in os.listdir(uploaded_files_dir):
            if filename.endswith(".pdf"):
                file_path = os.path.join(uploaded_files_dir, filename)
                logging.info(f"Processing file: {file_path}")

                loader = PDFMinerLoader(file_path)
                loaded_docs = loader.load()

                # Normalize each loaded item: PDFMinerLoader returns Document
                # objects, but guard against dict-shaped entries as well.
                for doc in loaded_docs:
                    if isinstance(doc, dict):
                        if 'content' in doc:
                            # Wrap raw dict content in a Document so the text
                            # splitter below can handle it.
                            doc = Document(page_content=doc['content'],
                                           metadata=doc.get('metadata', {}))
                        else:
                            logging.warning(f"Skipping invalid document structure in {file_path}")
                            continue
                    elif not hasattr(doc, 'page_content'):
                        logging.warning(f"Skipping invalid document structure in {file_path}")
                        continue

                    # Keep only documents with non-empty text content.
                    if doc.page_content and doc.page_content.strip():
                        documents.append(doc)
                    else:
                        logging.warning(f"Skipping empty document in {file_path}")

        if not documents:
            logging.error("No valid documents found to process.")
            return

        logging.info(f"Total valid documents: {len(documents)}")

        # Split documents into smaller chunks
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = text_splitter.split_documents(documents)
        logging.info(f"Total text chunks created: {len(texts)}")

        if not texts:
            logging.error("No valid text chunks to create embeddings.")
            return

        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

        # Chroma caps how many records can be added in a single call (5461 in
        # some builds, a limit inherited from SQLite), so embed in batches.
        MAX_BATCH_SIZE = 5461
        total_batches = math.ceil(len(texts) / MAX_BATCH_SIZE)
        logging.info(f"Processing {len(texts)} text chunks in {total_batches} batches...")

        db = None
        for i in range(total_batches):
            batch_start = i * MAX_BATCH_SIZE
            batch_end = min((i + 1) * MAX_BATCH_SIZE, len(texts))
            text_batch = texts[batch_start:batch_end]
            logging.info(f"Processing batch {i + 1}/{total_batches}, size: {len(text_batch)}")

            if db is None:
                db = Chroma.from_documents(text_batch, embeddings, persist_directory=persist_directory)
            else:
                db.add_documents(text_batch)

        db.persist()
        logging.info("Data ingestion completed successfully")
    except Exception as e:
        logging.error(f"Error during data ingestion: {str(e)}")
        raise
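# A quick way to sanity-check the persisted index after ingestion (a hedged
# sketch; "sample query" is an arbitrary probe string, not from the app):
#
#   db = Chroma(persist_directory=persist_directory,
#               embedding_function=SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"))
#   print(db.similarity_search("sample query", k=2))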
def llm_pipeline():
    """Set up the language model pipeline."""
    logging.info("Setting up LLM pipeline")
    pipe = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        device=device
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    logging.info("LLM pipeline setup complete")
    return local_llm
def qa_llm():
    """Set up the question-answering chain."""
    logging.info("Setting up QA model")
    llm = llm_pipeline()
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    retriever = db.as_retriever()  # Set up the retriever for the vector store
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    logging.info("QA model setup complete")
    return qa
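# The "stuff" chain type simply concatenates ("stuffs") all retrieved chunks
# into a single prompt for the LLM. With a 256-token generation limit and a
# small model, capping the number of retrieved chunks is one option (a hedged
# sketch; the k value is an assumption, not tuned):
#
#   retriever = db.as_retriever(search_kwargs={"k": 3})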
def process_answer(user_question):
    """Generate an answer to the user's question."""
    try:
        logging.info("Processing user question")
        qa = qa_llm()

        tailored_prompt = f"""
        You are an expert chatbot designed to assist Chartered Accountants (CAs) in the field of audits.
        Your goal is to provide accurate and comprehensive answers to any questions related to audit policies,
        procedures, and accounting standards based on the provided PDF documents.
        Please respond effectively and refer to the relevant standards and policies whenever applicable.

        User question: {user_question}
        """

        generated_text = qa({"query": tailored_prompt})
        answer = generated_text['result']

        # Crude heuristic: treat answers that hedge about missing information
        # as "insufficient information" responses.
        if "not provide" in answer or "no information" in answer:
            return "The document does not provide sufficient information to answer your question."

        logging.info("Answer generated successfully")
        return answer
    except Exception as e:
        logging.error(f"Error during answer generation: {str(e)}")
        return "Error processing the question."
# Streamlit UI Setup
st.sidebar.header("File Upload")
uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type=["pdf"], accept_multiple_files=True)

if uploaded_files:
    # Save uploaded files
    if not os.path.exists(uploaded_files_dir):
        os.makedirs(uploaded_files_dir)

    for uploaded_file in uploaded_files:
        file_path = os.path.join(uploaded_files_dir, uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
    st.sidebar.success(f"Uploaded {len(uploaded_files)} file(s) successfully!")

    # Run data ingestion when files are uploaded. Note that Streamlit reruns
    # the script on every interaction, so ingestion repeats while files remain
    # in the uploader; delete the "db" folder to rebuild the index from scratch.
    data_ingestion()

    # Display UI for Q&A
    st.header("Ask a Question")
    user_question = st.text_input("Enter your question here:")

    if user_question:
        answer = process_answer(user_question)
        st.write(answer)
else:
    st.sidebar.info("Upload PDF files to get started!")
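# To try the app locally (assuming this file is saved as app.py; the filename
# is an assumption, not given by the source):
#
#   pip install streamlit transformers langchain langchain-community \
#       chromadb sentence-transformers pymupdf pdfminer.six
#   streamlit run app.py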