import os
import logging

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Set up logging
logging.basicConfig(level=logging.INFO)

# Paths and model
PERSIST_DIRECTORY = "db"
UPLOAD_FOLDER = "uploaded_files"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

CHECKPOINT = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
device = 0 if torch.cuda.is_available() else -1  # GPU index if available, else CPU


def ingest_data():
    """Load every PDF in UPLOAD_FOLDER, split it into chunks, and index it in Chroma."""
    try:
        st.info("📚 Ingesting documents...")
        docs = []
        for file_name in os.listdir(UPLOAD_FOLDER):
            if file_name.endswith(".pdf"):
                path = os.path.join(UPLOAD_FOLDER, file_name)
                loader = PDFMinerLoader(path)
                docs.extend(loader.load())

        if not docs:
            st.error("No valid PDFs found.")
            return

        # Overlapping chunks preserve context that would otherwise be cut at chunk boundaries.
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = splitter.split_documents(docs)

        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
        db.persist()  # Redundant on recent Chroma releases (persistence is automatic); kept for older versions.
        st.success("✅ Ingestion successful!")
    except Exception as e:
        logging.error(f"Ingestion error: {e}")
        st.error(f"Ingestion error: {e}")


def get_qa_chain():
    """Build a RetrievalQA chain over the persisted Chroma index."""
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
    retriever = vectordb.as_retriever()

    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        device=device,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def main():
    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
    st.title("📄 CA Audit QA Assistant")

    with st.sidebar:
        st.header("📤 Upload Audit PDFs")
        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
        if uploaded_file is not None:
            file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.success(f"{uploaded_file.name} uploaded.")
            ingest_data()

    query = st.text_input("❓ Ask an audit-related question:")
    if st.button("🔍 Get Answer") and query:
        st.info("Generating answer...")
        qa_chain = get_qa_chain()
        # Note: RetrievalQA embeds the whole "query" string for retrieval, so these
        # instructions are embedded along with the question; passing only `query`
        # would give cleaner retrieval matches.
        prompt = f"""You are an AI assistant helping Chartered Accountants (CAs) in auditing.
Provide accurate, concise answers based on the uploaded documents.

Question: {query}"""
        result = qa_chain.invoke({"query": prompt})  # .invoke() replaces the deprecated __call__ API
        st.success("✅ Answer:")
        st.write(result["result"])


if __name__ == "__main__":
    main()
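
# Optional: the chain is built with return_source_documents=True, but the app
# never shows the retrieved chunks. A minimal sketch for displaying them inside
# main(), after `result` is computed (not part of the original app):
#
#   for doc in result["source_documents"]:
#       st.caption(doc.metadata.get("source", "unknown"))
#       st.write(doc.page_content)
#
# Showing sources lets auditors verify which document passages an answer is based on.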
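
# Usage (a sketch; the file name app.py is an assumption, not part of the original):
#
#   pip install streamlit torch transformers langchain langchain-community \
#       chromadb sentence-transformers pdfminer.six
#   streamlit run app.py
#
# Upload a PDF in the sidebar to trigger ingestion, then ask a question in the main pane.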