File size: 4,014 Bytes
709f6b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import logging
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Setup
logging.basicConfig(level=logging.INFO)
# NOTE(review): `device` is not referenced in the visible code; the pipeline in
# qa_llm() chooses its device independently. Kept for backward compatibility.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

persist_directory = "db"
uploaded_files_dir = "uploaded_files"
os.makedirs(uploaded_files_dir, exist_ok=True)

checkpoint = "MBZUAI/LaMini-T5-738M"


# show_spinner=False: this runs before st.set_page_config(), so it must not
# emit any Streamlit element.
@st.cache_resource(show_spinner=False)
def _load_seq2seq(model_name):
    """Load the tokenizer and seq2seq model for *model_name*, cached across reruns.

    Streamlit re-executes this script on every widget interaction; without the
    cache the model checkpoint would be reloaded from disk on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSeq2SeqLM.from_pretrained(model_name),
    )


tokenizer, base_model = _load_seq2seq(checkpoint)

def data_ingestion():
    """Load every PDF in ``uploaded_files_dir``, chunk it, and store it in Chroma.

    Each chunk is given a deterministic id derived from its ``source`` metadata
    and its text. Re-ingesting the same files therefore writes the same ids
    instead of appending a fresh set of duplicate chunks on every click.

    Success or failure is reported to the user via ``st.success`` / ``st.error``.
    Failures are also logged. Returns ``None`` in all cases.
    """
    import hashlib

    try:
        documents = []
        # sorted() makes ingestion order (and hence chunk order) deterministic.
        for filename in sorted(os.listdir(uploaded_files_dir)):
            # Case-insensitive so files named e.g. "Report.PDF" are not skipped.
            if not filename.lower().endswith(".pdf"):
                continue
            loader = PDFMinerLoader(os.path.join(uploaded_files_dir, filename))
            documents.extend(
                doc for doc in loader.load()
                if getattr(doc, "page_content", "").strip()
            )

        if not documents:
            st.error("No valid text extracted from uploaded PDFs.")
            return

        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = splitter.split_documents(documents)

        # Deterministic ids keep re-ingestion from piling up duplicates.
        # Identical chunks from the same source collapse to one entry, so no id
        # repeats within this single write.
        # NOTE(review): relies on the Chroma wrapper writing by id (upsert / skip
        # existing) — verify for the installed langchain_community/chromadb.
        unique_texts, ids, seen = [], [], set()
        for chunk in texts:
            key = f"{chunk.metadata.get('source', '')}\x00{chunk.page_content}"
            chunk_id = hashlib.sha256(key.encode("utf-8")).hexdigest()
            if chunk_id in seen:
                continue
            seen.add(chunk_id)
            unique_texts.append(chunk)
            ids.append(chunk_id)

        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

        db = Chroma.from_documents(
            unique_texts,
            embeddings,
            ids=ids,
            persist_directory=persist_directory,
        )
        db.persist()
        st.success("Document ingested and stored successfully.")

    except Exception as e:
        # UI boundary: surface the error to the user, but keep the traceback in logs.
        logging.exception("Data ingestion failed")
        st.error(f"Error during data ingestion: {str(e)}")

@st.cache_resource(show_spinner=False)
def _load_query_embeddings(model_name="all-MiniLM-L6-v2"):
    """Return the sentence-transformer embedding model, cached across reruns.

    qa_llm() runs once per question, and Streamlit reruns the script on every
    interaction. Without caching, the embedding model would be reloaded for
    every query.
    """
    return SentenceTransformerEmbeddings(model_name=model_name)


def qa_llm():
    """Build a RetrievalQA "stuff" chain over the persisted Chroma store.

    Generation uses the module-level ``base_model``/``tokenizer`` wrapped in a
    ``text2text-generation`` pipeline. Retrieval uses the Chroma collection in
    ``persist_directory`` with the default retriever settings.

    Returns:
        A ``RetrievalQA`` chain with ``return_source_documents=True``.
    """
    pipe = pipeline(
        'text2text-generation',
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        device=0 if torch.cuda.is_available() else -1
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    # The store itself is re-opened on each call (cheap) rather than cached.
    # This avoids holding a handle that predates a later ingestion.
    db = Chroma(persist_directory=persist_directory, embedding_function=_load_query_embeddings())
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    return qa

def process_query(query):
    """Answer *query* using the retrieval-QA chain built by ``qa_llm()``.

    Retrieval is performed with the user's bare question. The role
    instructions are added only to the generation prompt, so they steer the
    answer without skewing which chunks are retrieved.

    Args:
        query: The user's free-text question.

    Returns:
        The generated answer string. If any step raises, returns
        ``"Error: <message>"`` instead; the exception is also logged.
    """
    import textwrap

    # Dedent so the source-code indentation is not fed to the model as tokens.
    # .format() parses only this template, so braces inside *query* are safe.
    prompt_template = textwrap.dedent(
        """\
        You are an expert chatbot designed to assist Chartered Accountants (CAs) in the field of audits.
        Your goal is to provide accurate and comprehensive answers to any questions related to audit policies,
        procedures, and accounting standards based on the uploaded PDF documents.

        User question: {query}
        """
    )
    try:
        qa = qa_llm()
        # Embed only the question itself. Prepending the instruction block
        # would dominate the query embedding and pull in less relevant chunks.
        docs = qa.retriever.invoke(query)
        tailored_prompt = prompt_template.format(query=query)
        result = qa.combine_documents_chain.invoke(
            {"input_documents": docs, "question": tailored_prompt}
        )
        return result["output_text"]
    except Exception as e:
        logging.exception("Query processing failed")
        return f"Error: {str(e)}"

# Streamlit UI
st.set_page_config(page_title="CA Audit Chatbot", layout="centered")
st.title("📚 Chartered Accountant Audit Assistant")
st.markdown("Upload a PDF file and ask audit-related questions. This AI assistant will answer based on document content.")

# File uploader
uploaded_file = st.file_uploader("Upload PDF file", type=["pdf"])
if uploaded_file is not None:
    # The file name is supplied by the client. Keep only its final path
    # component so a crafted name like "../x.pdf" cannot escape uploaded_files_dir.
    safe_name = os.path.basename(uploaded_file.name)
    save_path = os.path.join(uploaded_files_dir, safe_name)
    with open(save_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success("PDF uploaded successfully!")
    if st.button("Ingest Document"):
        data_ingestion()

# Query input
user_query = st.text_input("Ask a question about the audit document:")
# Skip whitespace-only input rather than running a full retrieval + generation.
if user_query and user_query.strip():
    response = process_query(user_query)
    st.markdown("### 📌 Answer:")
    st.write(response)