pradeepsengarr committed · Commit 3875c87 · verified · 1 Parent(s): d52e65a

Update app.py

Files changed (1):
  app.py +272 -68
app.py CHANGED
@@ -1,7 +1,112 @@
 import os
 import logging
 import streamlit as st
-import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from langchain_community.document_loaders import PDFMinerLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -13,90 +118,189 @@ from langchain.chains import RetrievalQA
 # Set up logging
 logging.basicConfig(level=logging.INFO)

-# Paths and model
-PERSIST_DIRECTORY = "db"
-UPLOAD_FOLDER = "uploaded_files"
-os.makedirs(UPLOAD_FOLDER, exist_ok=True)

-CHECKPOINT = "MBZUAI/LaMini-T5-738M"
-tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
-base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
-device = 0 if torch.cuda.is_available() else -1

-def ingest_data():
     try:
-        st.info("📚 Ingesting documents...")
-
-        docs = []
-        for file_name in os.listdir(UPLOAD_FOLDER):
-            if file_name.endswith(".pdf"):
-                path = os.path.join(UPLOAD_FOLDER, file_name)
-                loader = PDFMinerLoader(path)
-                loaded_docs = loader.load()
-                docs.extend(loaded_docs)
-
-        if not docs:
-            st.error("No valid PDFs found.")
             return

-        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-        texts = splitter.split_documents(docs)

         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-        db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
         db.persist()
-        st.success("✅ Ingestion successful!")
     except Exception as e:
-        logging.error(f"Ingestion error: {str(e)}")
-        st.error(f"Ingestion error: {str(e)}")
-
-def get_qa_chain():
-    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
-    vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
-    retriever = vectordb.as_retriever()

     pipe = pipeline(
-        "text2text-generation",
         model=base_model,
         tokenizer=tokenizer,
         max_length=256,
         do_sample=True,
         temperature=0.3,
         top_p=0.95,
-        device=device,
     )
-    llm = HuggingFacePipeline(pipeline=pipe)
-
-    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
-    return qa_chain
-
-def main():
-    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
-    st.title("📄 CA Audit QA Assistant")
-
-    with st.sidebar:
-        st.header("📤 Upload Audit PDFs")
-        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
-
-    if uploaded_file is not None:
-        file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
-        with open(file_path, "wb") as f:
-            f.write(uploaded_file.getbuffer())
-        st.success(f"{uploaded_file.name} uploaded.")
-        ingest_data()
-
-    query = st.text_input("❓ Ask an audit-related question:")
-    if st.button("🔍 Get Answer") and query:
-        st.info("Generating answer...")
-        qa_chain = get_qa_chain()
-        prompt = f"""
-        You are an AI assistant helping Chartered Accountants (CAs) in auditing.
-        Provide accurate, concise answers based on the uploaded documents.
-        Question: {query}
         """
-        result = qa_chain({"query": prompt})
-        st.success("✅ Answer:")
-        st.write(result["result"])

-if __name__ == "__main__":
-    main()
+# import os
+# import logging
+# import streamlit as st
+# import torch
+# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+# from langchain_community.document_loaders import PDFMinerLoader
+# from langchain.text_splitter import RecursiveCharacterTextSplitter
+# from langchain_community.embeddings import SentenceTransformerEmbeddings
+# from langchain_community.vectorstores import Chroma
+# from langchain_community.llms import HuggingFacePipeline
+# from langchain.chains import RetrievalQA
+
+# # Set up logging
+# logging.basicConfig(level=logging.INFO)
+
+# # Paths and model
+# PERSIST_DIRECTORY = "db"
+# UPLOAD_FOLDER = "uploaded_files"
+# os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+# CHECKPOINT = "MBZUAI/LaMini-T5-738M"
+# tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+# base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
+# device = 0 if torch.cuda.is_available() else -1
+
+# def ingest_data():
+#     try:
+#         st.info("📚 Ingesting documents...")
+
+#         docs = []
+#         for file_name in os.listdir(UPLOAD_FOLDER):
+#             if file_name.endswith(".pdf"):
+#                 path = os.path.join(UPLOAD_FOLDER, file_name)
+#                 loader = PDFMinerLoader(path)
+#                 loaded_docs = loader.load()
+#                 docs.extend(loaded_docs)
+
+#         if not docs:
+#             st.error("No valid PDFs found.")
+#             return
+
+#         splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+#         texts = splitter.split_documents(docs)
+
+#         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+#         db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
+#         db.persist()
+#         st.success("✅ Ingestion successful!")
+#     except Exception as e:
+#         logging.error(f"Ingestion error: {str(e)}")
+#         st.error(f"Ingestion error: {str(e)}")
+
+# def get_qa_chain():
+#     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+#     vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
+#     retriever = vectordb.as_retriever()
+
+#     pipe = pipeline(
+#         "text2text-generation",
+#         model=base_model,
+#         tokenizer=tokenizer,
+#         max_length=256,
+#         do_sample=True,
+#         temperature=0.3,
+#         top_p=0.95,
+#         device=device,
+#     )
+#     llm = HuggingFacePipeline(pipeline=pipe)
+
+#     qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
+#     return qa_chain
+
+# def main():
+#     st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
+#     st.title("📄 CA Audit QA Assistant")
+
+#     with st.sidebar:
+#         st.header("📤 Upload Audit PDFs")
+#         uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
+
+#     if uploaded_file is not None:
+#         file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
+#         with open(file_path, "wb") as f:
+#             f.write(uploaded_file.getbuffer())
+#         st.success(f"{uploaded_file.name} uploaded.")
+#         ingest_data()
+
+#     query = st.text_input("❓ Ask an audit-related question:")
+#     if st.button("🔍 Get Answer") and query:
+#         st.info("Generating answer...")
+#         qa_chain = get_qa_chain()
+#         prompt = f"""
+#         You are an AI assistant helping Chartered Accountants (CAs) in auditing.
+#         Provide accurate, concise answers based on the uploaded documents.
+#         Question: {query}
+#         """
+#         result = qa_chain({"query": prompt})
+#         st.success("✅ Answer:")
+#         st.write(result["result"])
+
+# if __name__ == "__main__":
+#     main()
+
+
 import os
+import PyPDF2
 import logging
+import math
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from langchain_community.document_loaders import PDFMinerLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.docstore.document import Document  # needed to wrap extracted text for the splitter
 # Set up logging
 logging.basicConfig(level=logging.INFO)

+# Define global variables
+device = 'cpu'
+persist_directory = "db"
+uploaded_files_dir = "uploaded_files"
+
+# Streamlit app configuration
+st.set_page_config(page_title="Audit Assistant", layout="wide")
+st.title("Audit Assistant")
+
+# Load the model
+checkpoint = "MBZUAI/LaMini-T5-738M"
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+base_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+
+# Helper Functions
+def extract_text_from_pdf(file_path):
+    """Extract text from a PDF using PyPDF2."""
+    try:
+        with open(file_path, 'rb') as file:
+            reader = PyPDF2.PdfReader(file)
+            text = ""
+            for page in range(len(reader.pages)):
+                # extract_text() may return None for image-only pages
+                text += reader.pages[page].extract_text() or ""
+            return text
+    except Exception as e:
+        logging.error(f"Error reading PDF {file_path}: {e}")
+        return None
+
+def data_ingestion():
+    """Load PDFs and create embeddings."""
     try:
+        logging.info("Starting data ingestion")
+
+        if not os.path.exists(uploaded_files_dir):
+            os.makedirs(uploaded_files_dir)
+
+        documents = []
+        for filename in os.listdir(uploaded_files_dir):
+            if filename.endswith(".pdf"):
+                file_path = os.path.join(uploaded_files_dir, filename)
+                logging.info(f"Processing file: {file_path}")
+
+                # Extract text using PyPDF2
+                text = extract_text_from_pdf(file_path)
+
+                if text:
+                    # Wrap the raw text in a Document so split_documents() can
+                    # read .page_content and .metadata (a plain dict would fail)
+                    documents.append(Document(page_content=text, metadata={"source": file_path}))
+                else:
+                    logging.warning(f"Skipping file due to extraction error: {file_path}")
+
+        if not documents:
+            logging.error("No valid documents found to process.")
             return

+        logging.info(f"Total valid documents: {len(documents)}")
+
+        # Split the documents into chunks
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+        texts = text_splitter.split_documents(documents)
+
+        logging.info(f"Total text chunks created: {len(texts)}")
+
+        if not texts:
+            logging.error("No valid text chunks to create embeddings.")
+            return

         embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        # Embed and persist the chunks in batches; Chroma enforces a maximum
+        # insert batch size, which this constant is sized to respect
+        MAX_BATCH_SIZE = 5461
+        total_batches = math.ceil(len(texts) / MAX_BATCH_SIZE)
+
+        logging.info(f"Processing {len(texts)} text chunks in {total_batches} batches...")
+
+        db = None
+        for i in range(total_batches):
+            batch_start = i * MAX_BATCH_SIZE
+            batch_end = min((i + 1) * MAX_BATCH_SIZE, len(texts))
+            text_batch = texts[batch_start:batch_end]
+
+            logging.info(f"Processing batch {i + 1}/{total_batches}, size: {len(text_batch)}")
+
+            # Create the store on the first batch, then append to it
+            if db is None:
+                db = Chroma.from_documents(text_batch, embeddings, persist_directory=persist_directory)
+            else:
+                db.add_documents(text_batch)
+
         db.persist()
+        logging.info("Data ingestion completed successfully")
+
     except Exception as e:
+        logging.error(f"Error during data ingestion: {str(e)}")
+        raise

+def llm_pipeline():
+    """Set up the language model pipeline."""
+    logging.info("Setting up LLM pipeline")
     pipe = pipeline(
+        'text2text-generation',
         model=base_model,
         tokenizer=tokenizer,
         max_length=256,
         do_sample=True,
         temperature=0.3,
         top_p=0.95,
+        device=device
+    )
+    local_llm = HuggingFacePipeline(pipeline=pipe)
+    logging.info("LLM pipeline setup complete")
+    return local_llm
+
+def qa_llm():
+    """Set up the question-answering chain."""
+    logging.info("Setting up QA model")
+    llm = llm_pipeline()
+    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+    retriever = db.as_retriever()  # Set up the retriever for the vector store
+    qa = RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True
     )
+    logging.info("QA model setup complete")
+    return qa
+
+def process_answer(user_question):
+    """Generate an answer to the user's question."""
+    try:
+        logging.info("Processing user question")
+        qa = qa_llm()
+
+        tailored_prompt = f"""
+        You are an expert chatbot designed to assist Chartered Accountants (CAs) in the field of audits.
+        Your goal is to provide accurate and comprehensive answers to any questions related to audit policies, procedures,
+        and accounting standards based on the provided PDF documents.
+        Please respond effectively and refer to the relevant standards and policies whenever applicable.
+
+        User question: {user_question}
         """

+        generated_text = qa({"query": tailored_prompt})
+        answer = generated_text['result']
+
+        # Crude refusal check on the model's wording
+        if "not provide" in answer or "no information" in answer:
+            return "The document does not provide sufficient information to answer your question."
+
+        logging.info("Answer generated successfully")
+        return answer
+
+    except Exception as e:
+        logging.error(f"Error during answer generation: {str(e)}")
+        return "Error processing the question."
+
+# Streamlit UI Setup
+st.sidebar.header("File Upload")
+uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type=["pdf"], accept_multiple_files=True)
+
+if uploaded_files:
+    # Save uploaded files
+    if not os.path.exists(uploaded_files_dir):
+        os.makedirs(uploaded_files_dir)
+
+    for uploaded_file in uploaded_files:
+        file_path = os.path.join(uploaded_files_dir, uploaded_file.name)
+        with open(file_path, "wb") as f:
+            f.write(uploaded_file.getbuffer())
+
+    st.sidebar.success(f"Uploaded {len(uploaded_files)} file(s) successfully!")
+
+    # Run data ingestion when files are uploaded
+    data_ingestion()
+
+    # Display UI for Q&A
+    st.header("Ask a Question")
+    user_question = st.text_input("Enter your question here:")
+
+    if user_question:
+        answer = process_answer(user_question)
+        st.write(answer)
+
+else:
+    st.sidebar.info("Upload PDF files to get started!")
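
A quick way to verify what data_ingestion() produced is to query the persisted Chroma index directly. The following is a minimal sketch, not part of the commit: it assumes ingestion has already populated the "db" directory, reuses the same embedding model as app.py, and the sample question is hypothetical.

    from langchain_community.embeddings import SentenceTransformerEmbeddings
    from langchain_community.vectorstores import Chroma

    # Reopen the index that data_ingestion() persisted under "db"
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma(persist_directory="db", embedding_function=embeddings)

    # Fetch the top-3 chunks for a sample (hypothetical) audit question
    for doc in db.similarity_search("What are the auditor's responsibilities?", k=3):
        print(doc.metadata.get("source"), "->", doc.page_content[:100])

With the dependencies installed (streamlit, transformers, langchain, langchain-community, chromadb, sentence-transformers, PyPDF2), the app itself starts with: streamlit run app.py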