import os
import logging
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
# Set up logging
logging.basicConfig(level=logging.INFO)
# Paths and model
PERSIST_DIRECTORY = "db"
UPLOAD_FOLDER = "uploaded_files"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
CHECKPOINT = "MBZUAI/LaMini-T5-738M"

@st.cache_resource
def load_model():
    # Cache the tokenizer and model so Streamlit does not reload the
    # 738M-parameter checkpoint on every rerun of the script.
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    return tokenizer, base_model

# transformers pipeline convention: device=0 is the first GPU, -1 is CPU
device = 0 if torch.cuda.is_available() else -1
def ingest_data():
    """Load every PDF in the upload folder, split it into chunks, and
    index the chunks in a persistent Chroma vector store."""
    try:
        st.info("📥 Ingesting documents...")
        docs = []
        for file_name in os.listdir(UPLOAD_FOLDER):
            if file_name.endswith(".pdf"):
                path = os.path.join(UPLOAD_FOLDER, file_name)
                loader = PDFMinerLoader(path)
                docs.extend(loader.load())
        if not docs:
            st.error("No valid PDFs found.")
            return
        # Overlapping chunks keep sentences intact across chunk boundaries
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        texts = splitter.split_documents(docs)
        embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        db = Chroma.from_documents(texts, embeddings, persist_directory=PERSIST_DIRECTORY)
        db.persist()  # no-op on chromadb >= 0.4, which persists automatically
        st.success("✅ Ingestion successful!")
    except Exception as e:
        logging.error(f"Ingestion error: {str(e)}")
        st.error(f"Ingestion error: {str(e)}")
def get_qa_chain():
    """Build a RetrievalQA chain over the persisted Chroma index."""
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings)
    retriever = vectordb.as_retriever()
    tokenizer, base_model = load_model()
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,  # low temperature keeps answers close to the retrieved text
        top_p=0.95,
        device=device,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    # "stuff" concatenates all retrieved chunks into a single prompt
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
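# Quick smoke test (sketch, assuming documents were already ingested into ./db):
# the chain can also be exercised outside Streamlit, e.g. from a REPL:
#
#   chain = get_qa_chain()
#   out = chain.invoke({"query": "What does the audit report say about inventory?"})
#   print(out["result"])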
def main():
    st.set_page_config(page_title="CA Audit QA Chatbot", layout="wide")
    st.title("📚 CA Audit QA Assistant")
    with st.sidebar:
        st.header("📤 Upload Audit PDFs")
        uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
        if uploaded_file is not None:
            # Save the upload to disk so the ingestion step can read it
            file_path = os.path.join(UPLOAD_FOLDER, uploaded_file.name)
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.success(f"{uploaded_file.name} uploaded.")
            ingest_data()
    query = st.text_input("❓ Ask an audit-related question:")
    if st.button("🔍 Get Answer") and query:
        st.info("Generating answer...")
        qa_chain = get_qa_chain()
        prompt = f"""
        You are an AI assistant helping Chartered Accountants (CAs) in auditing.
        Provide accurate, concise answers based on the uploaded documents.
        Question: {query}
        """
        # .invoke() replaces the deprecated qa_chain({"query": ...}) call style
        result = qa_chain.invoke({"query": prompt})
        st.success("✅ Answer:")
        st.write(result["result"])

if __name__ == "__main__":
    main()
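# To run locally (assuming this file is saved as app.py):
#   pip install streamlit torch transformers langchain langchain-community \
#       chromadb sentence-transformers pdfminer.six
#   streamlit run app.py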