import os
import logging
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Set up logging
logging.basicConfig(level=logging.INFO)

# Paths and model
PERSIST_DIRECTORY = "db"
UPLOAD_FOLDER = "uploaded_files"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

CHECKPOINT = "MBZUAI/LaMini-T5-738M"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
# Run generation on the first GPU if one is available, otherwise on CPU (-1)
device = 0 if torch.cuda.is_available() else -1

def ingest_data():
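    """Load each uploaded PDF, split it into chunks, and persist a Chroma vector index."""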
    try:
        st.info("📚 Ingesting documents...")

        docs = []
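        # Load every PDF in the upload folder into LangChain documents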
        for file_name in os.listdir(UPLOAD_FOLDER):
            if file_name.endswith(".pdf"):
                path = os.path.join(UPLOAD_FOLDER, file_name)
                loader = PDFMinerLoader(path)
                loaded_docs = loader.load()
                docs.extend(loaded_docs)

        if not docs:
            st.error("No valid PDFs found.")
            return

        # Split documents into overlapping chunks so retrieval returns focused passages
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = splitter.split_documents(docs)

        # Embed the chunks and persist the vector index to disk
        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
        db.persist()
        st.success("✅ Ingestion successful!")
    except Exception as e:
        logging.error(f"Ingestion error: {str(e)}")
        st.error(f"Ingestion error: {str(e)}")

def get_qa_chain():
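    """Build a RetrievalQA chain over the persisted Chroma index, using LaMini-T5 for generation."""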
    # Re-open the persisted Chroma index with the same embedding model used at ingest time
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
    retriever = vectordb.as_retriever()

    # Wrap the model in a transformers text2text pipeline; low temperature keeps answers focused
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        device=device,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain

def main():
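    """Streamlit UI: upload audit PDFs in the sidebar, then ask questions against them."""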
    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
    st.title("📄 CA Audit QA Assistant")

    with st.sidebar:
        st.header("📤 Upload Audit PDFs")
        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

        if uploaded_file is not None:
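            # Save the upload to disk, then rebuild the vector index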
            file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.success(f"{uploaded_file.name} uploaded.")
            ingest_data()

    query = st.text_input("❓ Ask an audit-related question:")
    if st.button("🔍 Get Answer") and query:
        st.info("Generating answer...")
        qa_chain = get_qa_chain()
        # Prepend a short instruction to the user's question; RetrievalQA uses the
        # combined text both as the retrieval query and as the question for the model.
        prompt = (
            "You are an AI assistant helping Chartered Accountants (CAs) in auditing. "
            "Provide accurate, concise answers based on the uploaded documents.\n"
            f"Question: {query}"
        )
        result = qa_chain.invoke({"query": prompt})
        st.success("✅ Answer:")
        st.write(result["result"])

if __name__ == "__main__":
    main()