pradeepsengarr committed on
Commit 28c38fd · verified · 1 Parent(s): 94f70e7

Update app.py

Files changed (1)
  1. app.py +66 -70
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import logging
-import torch
 import streamlit as st
+import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from langchain_community.document_loaders import PDFMinerLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -10,97 +10,93 @@ from langchain_community.vectorstores import Chroma
 from langchain_community.llms import HuggingFacePipeline
 from langchain.chains import RetrievalQA
 
-# Setup
+# Set up logging
 logging.basicConfig(level=logging.INFO)
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-persist_directory = "db"
-uploaded_files_dir = "uploaded_files"
-os.makedirs(uploaded_files_dir, exist_ok=True)
+# Paths and model
+PERSIST_DIRECTORY = "db"
+UPLOAD_FOLDER = "uploaded_files"
+os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
-checkpoint = "MBZUAI/LaMini-T5-738M"
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+CHECKPOINT = "MBZUAI/LaMini-T5-738M"
+tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
+device = 0 if torch.cuda.is_available() else -1
 
-def data_ingestion():
+def ingest_data():
     try:
-        documents = []
-        for filename in os.listdir(uploaded_files_dir):
-            if filename.endswith(".pdf"):
-                file_path = os.path.join(uploaded_files_dir, filename)
-                loader = PDFMinerLoader(file_path)
-                docs = loader.load()
-                for doc in docs:
-                    if hasattr(doc, 'page_content') and len(doc.page_content.strip()) > 0:
-                        documents.append(doc)
-
-        if not documents:
-            st.error("No valid text extracted from uploaded PDFs.")
+        st.info("📚 Ingesting documents...")
+
+        docs = []
+        for file_name in os.listdir(UPLOAD_FOLDER):
+            if file_name.endswith(".pdf"):
+                path = os.path.join(UPLOAD_FOLDER, file_name)
+                loader = PDFMinerLoader(path)
+                loaded_docs = loader.load()
+                docs.extend(loaded_docs)
+
+        if not docs:
+            st.error("No valid PDFs found.")
             return
 
         splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-        texts = splitter.split_documents(documents)
+        texts = splitter.split_documents(docs)
 
         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-
-        db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
+        db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
         db.persist()
-        st.success("Document ingested and stored successfully.")
-
+        st.success("✅ Ingestion successful!")
     except Exception as e:
-        st.error(f"Error during data ingestion: {str(e)}")
+        logging.error(f"Ingestion error: {str(e)}")
+        st.error(f"Ingestion error: {str(e)}")
+
+def get_qa_chain():
+    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+    vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
+    retriever = vectordb.as_retriever()
 
-def qa_llm():
     pipe = pipeline(
-        'text2text-generation',
+        "text2text-generation",
         model=base_model,
        tokenizer=tokenizer,
         max_length=256,
         do_sample=True,
         temperature=0.3,
         top_p=0.95,
-        device=0 if torch.cuda.is_available() else -1
+        device=device,
     )
     llm = HuggingFacePipeline(pipeline=pipe)
-    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
-    retriever = db.as_retriever()
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
-    return qa
 
-def process_query(query):
-    try:
-        qa = qa_llm()
-        tailored_prompt = f"""
-        You are an expert chatbot designed to assist Chartered Accountants (CAs) in the field of audits.
-        Your goal is to provide accurate and comprehensive answers to any questions related to audit policies,
-        procedures, and accounting standards based on the uploaded PDF documents.
+    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
+    return qa_chain
+
+def main():
+    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
+    st.title("📄 CA Audit QA Assistant")
+
+    with st.sidebar:
+        st.header("📤 Upload Audit PDFs")
+        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
 
-        User question: {query}
+        if uploaded_file is not None:
+            file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
+            with open(file_path, "wb") as f:
+                f.write(uploaded_file.getbuffer())
+            st.success(f"{uploaded_file.name} uploaded.")
+            ingest_data()
+
+    query = st.text_input("❓ Ask an audit-related question:")
+    if st.button("🔍 Get Answer") and query:
+        st.info("Generating answer...")
+        qa_chain = get_qa_chain()
+        prompt = f"""
+        You are an AI assistant helping Chartered Accountants (CAs) in auditing.
+        Provide accurate, concise answers based on the uploaded documents.
+        Question: {query}
         """
-        result = qa({"query": tailored_prompt})
-        return result["result"]
-    except Exception as e:
-        return f"Error: {str(e)}"
-
-# Streamlit UI
-st.set_page_config(page_title="CA Audit Chatbot", layout="centered")
-st.title("📚 Chartered Accountant Audit Assistant")
-st.markdown("Upload a PDF file and ask audit-related questions. This AI assistant will answer based on document content.")
-
-# File uploader
-uploaded_file = st.file_uploader("Upload PDF file", type=["pdf"])
-if uploaded_file is not None:
-    save_path = os.path.join(uploaded_files_dir, uploaded_file.name)
-    with open(save_path, "wb") as f:
-        f.write(uploaded_file.getbuffer())
-    st.success("PDF uploaded successfully!")
-    if st.button("Ingest Document"):
-        data_ingestion()
-
-# Query input
-user_query = st.text_input("Ask a question about the audit document:")
-if user_query:
-    response = process_query(user_query)
-    st.markdown("### 📌 Answer:")
-    st.write(response)
+        result = qa_chain({"query": prompt})
+        st.success("✅ Answer:")
+        st.write(result["result"])
+
+if __name__ == "__main__":
+    main()
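
For readers trying out this commit, here is a minimal sketch (not part of the commit itself) of driving the new get_qa_chain() outside the Streamlit UI. It assumes app.py is importable as `app`, that ingest_data() has already populated the "db" index, and that the question string is purely illustrative:

# Minimal sketch, assuming app.py is on the import path and the Chroma
# index in "db/" was already built by ingest_data(). Importing app loads
# the LaMini-T5 checkpoint at module level; main() does not run, thanks
# to the new `if __name__ == "__main__":` guard.
from app import get_qa_chain

qa_chain = get_qa_chain()
result = qa_chain({"query": "Which audit procedures does the document describe?"})  # illustrative question
print(result["result"])                  # generated answer
for doc in result["source_documents"]:   # present because return_source_documents=True
    print(doc.metadata)

One caveat: get_qa_chain() rebuilds the Hugging Face pipeline and reopens the Chroma store on every call, so each query pays that setup cost; a caller could build the chain once and reuse it.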