Bot_RAG / app.py
pradeepsengarr's picture
Update app.py
28c38fd verified
raw
history blame
3.6 kB
import os
import logging
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
# Set up logging
logging.basicConfig(level=logging.INFO)

# Paths and model
PERSIST_DIRECTORY = "db"  # directory handed to Chroma as its persist_directory
UPLOAD_FOLDER = "uploaded_files"  # uploaded PDFs are written here before ingestion
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
CHECKPOINT = "MBZUAI/LaMini-T5-738M"


# Streamlit re-executes this script from the top on every widget interaction.
# Loading the checkpoint directly at module level therefore reloads it on each
# rerun; st.cache_resource keeps a single loaded copy and hands it back instead.
# show_spinner=False keeps the loader from drawing anything, since it runs
# before main() calls st.set_page_config.
@st.cache_resource(show_spinner=False)
def _load_model(checkpoint):
    """Load and return ``(tokenizer, model)`` for the seq2seq *checkpoint*."""
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


tokenizer, base_model = _load_model(CHECKPOINT)
# Passed to transformers.pipeline as `device`: 0 when CUDA is available, else -1.
device = 0 if torch.cuda.is_available() else -1
def _stable_chunk_ids(chunks):
    """Return one deterministic ID per chunk.

    Each ID is a SHA-1 over the chunk's ``source`` metadata, its position among
    the chunks of that same source, and its text. Identical input therefore
    yields identical IDs on every run, and the position component keeps two
    chunks with equal text from colliding.
    """
    import hashlib  # local import so this block does not depend on file-level imports

    next_position = {}
    ids = []
    for chunk in chunks:
        source = str(chunk.metadata.get("source", ""))
        position = next_position.get(source, 0)
        next_position[source] = position + 1
        payload = f"{source}\x00{position}\x00{chunk.page_content}"
        ids.append(hashlib.sha1(payload.encode("utf-8")).hexdigest())
    return ids


def ingest_data():
    """Load every PDF in UPLOAD_FOLDER, chunk it, embed it and store it in Chroma.

    Progress and failures are reported through Streamlit messages. Any
    exception is logged with its traceback and shown to the user rather than
    propagated. Returns None in all cases.
    """
    try:
        st.info("πŸ“š Ingesting documents...")
        docs = []
        # sorted() gives a stable processing order regardless of the filesystem.
        for file_name in sorted(os.listdir(UPLOAD_FOLDER)):
            # Compare case-insensitively so "REPORT.PDF" is not silently skipped.
            if not file_name.lower().endswith(".pdf"):
                continue
            pdf_path = os.path.join(UPLOAD_FOLDER, file_name)
            docs.extend(PDFMinerLoader(pdf_path).load())
        if not docs:
            st.error("No valid PDFs found.")
            return
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = splitter.split_documents(docs)
        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        # Every call re-processes the whole upload folder. Without explicit IDs
        # each call would store the same chunks again under new random IDs.
        # NOTE(review): relies on the Chroma wrapper writing by ID (upsert), so
        # repeated IDs overwrite rather than duplicate - confirm for the
        # installed langchain_community version.
        db = Chroma.from_documents(
            texts,
            embeddings,
            ids=_stable_chunk_ids(texts),
            persist_directory=PERSIST_DIRECTORY,
        )
        db.persist()
        st.success("βœ… Ingestion successful!")
    except Exception as e:
        # UI boundary: keep the app alive, but record the full traceback.
        logging.exception("Ingestion error: %s", e)
        st.error(f"Ingestion error: {str(e)}")
def get_qa_chain():
    """Assemble and return a RetrievalQA chain over the persisted Chroma store.

    The retriever reads from PERSIST_DIRECTORY using the all-MiniLM-L6-v2
    sentence-transformer embeddings. The LLM is a text2text-generation pipeline
    around the module-level model and tokenizer. The chain is created with
    chain_type "stuff" and return_source_documents=True.
    """
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    store = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedder)
    doc_retriever = store.as_retriever()

    # Generation settings handed to the transformers pipeline.
    generation_settings = {
        "max_length": 256,
        "do_sample": True,
        "temperature": 0.3,
        "top_p": 0.95,
    }
    text_generator = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        device=device,
        **generation_settings,
    )

    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=text_generator),
        chain_type="stuff",
        retriever=doc_retriever,
        return_source_documents=True,
    )
def main():
    """Render the Streamlit page: PDF upload in the sidebar, Q&A in the main area.

    A newly uploaded PDF is saved under UPLOAD_FOLDER and ingested once. When
    the button is pressed with a non-blank question, a QA chain is built, the
    question is wrapped in an instruction prompt, and the chain's "result" is
    displayed.
    """
    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
    st.title("πŸ“„ CA Audit QA Assistant")

    with st.sidebar:
        st.header("πŸ“€ Upload Audit PDFs")
        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
        if uploaded_file is not None:
            # The file name comes from the client; keep only its final path
            # component so it cannot point outside UPLOAD_FOLDER.
            safe_name = os.path.basename(uploaded_file.name)
            file_path = os.path.join(UPLOAD_FOLDER, safe_name)

            # Streamlit reruns this script on every interaction while the
            # uploader still holds the file. Remember which upload was already
            # handled so pressing the button does not rewrite and re-ingest it.
            # file_id (when the Streamlit version provides it) distinguishes a
            # fresh upload of the same file, which allows a deliberate retry.
            upload_key = (
                getattr(uploaded_file, "file_id", None),
                uploaded_file.name,
                uploaded_file.size,
            )
            is_new_upload = st.session_state.get("_ingested_upload") != upload_key

            if is_new_upload:
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
            st.success(f"{uploaded_file.name} uploaded.")
            if is_new_upload:
                ingest_data()
                st.session_state["_ingested_upload"] = upload_key

    query = st.text_input("❓ Ask an audit-related question:")
    # The button is evaluated first so it is always rendered; blank or
    # whitespace-only questions are ignored.
    if st.button("πŸ” Get Answer") and query.strip():
        st.info("Generating answer...")
        qa_chain = get_qa_chain()
        prompt = (
            "\n"
            "You are an AI assistant helping Chartered Accountants (CAs) in auditing.\n"
            "Provide accurate, concise answers based on the uploaded documents.\n"
            f"Question: {query}\n"
        )
        result = qa_chain({"query": prompt})
        st.success("βœ… Answer:")
        st.write(result["result"])


if __name__ == "__main__":
    main()