import streamlit as st
import os
import re

import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from docx import Document
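# Assumed runtime dependencies, inferred from the imports above (the file
# itself does not pin them): streamlit, faiss-cpu, numpy, torch, transformers,
# sentence-transformers, PyPDF2, python-docx.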
# Initialize models (cached so they are loaded once per Streamlit session)
@st.cache_resource
def load_models():
    # Text embedding model
    embed_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Read the Hugging Face access token from the environment (HF_TOKEN is the
    # conventional secret name on Spaces) rather than hard-coding it.
    token = os.environ.get("HF_TOKEN")

    # IBM Granite model with the access token and trust_remote_code settings.
    # One shared checkpoint serves both summarization and QA; loading the same
    # 13B model twice would double the memory footprint for no benefit.
    tokenizer = AutoTokenizer.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=token,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=token,
        trust_remote_code=True
    )
    return embed_model, model, tokenizer, model, tokenizer
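# Deployment note (an assumption about the hosting environment, not enforced
# anywhere in this file): a 13B checkpoint needs far more memory than a free
# CPU Space provides, so a smaller instruct checkpoint may be the practical
# choice; the rest of the pipeline works the same either way.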
def process_file(uploaded_file):
    text = ""
    file_type = uploaded_file.name.split('.')[-1].lower()
    if file_type == 'pdf':
        pdf_reader = PdfReader(uploaded_file)
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ""
    elif file_type == 'txt':
        text = uploaded_file.read().decode('utf-8')
    elif file_type == 'docx':
        doc = Document(uploaded_file)
        for para in doc.paragraphs:
            text += para.text + "\n"
    return clean_text(text)
def clean_text(text):
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    return text
def split_text(text, chunk_size=500):
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
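# A quick illustration of the fixed-size chunking (hypothetical input, not
# used by the app; the last chunk may be shorter):
#   split_text("abcdefgh", chunk_size=3)  ->  ["abc", "def", "gh"]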
def create_faiss_index(text_chunks, embed_model):
    embeddings = embed_model.encode(text_chunks)
    dimension = embeddings.shape[1]
    # Exact (brute-force) L2 index; FAISS expects float32 vectors.
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings).astype('float32'))
    return index
def generate_summary(text, model, tokenizer):
    inputs = tokenizer(f"Summarize this document: {text[:3000]}",
                       return_tensors="pt", max_length=4096, truncation=True)
    # max_new_tokens bounds only the generated text; max_length would also
    # count the prompt tokens and can cut generation off before it starts.
    summary_ids = model.generate(**inputs, max_new_tokens=500)
    # Causal LMs echo the prompt, so decode only the newly generated tokens.
    return tokenizer.decode(summary_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
def answer_question(question, index, text_chunks, embed_model, model, tokenizer):
    question_embed = embed_model.encode([question])
    # Retrieve the 3 nearest chunks; FAISS pads missing results with -1.
    _, indices = index.search(np.array(question_embed).astype('float32'), 3)
    context = " ".join([text_chunks[i] for i in indices[0] if i != -1])
    prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=500)
    # Decode only the tokens generated after the prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
def main():
    st.title("📖 RAG Book Assistant with IBM Granite")
    embed_model, summary_model, summary_tokenizer, qa_model, qa_tokenizer = load_models()

    uploaded_file = st.file_uploader("Upload a document (PDF/TXT/DOCX)", type=['pdf', 'txt', 'docx'])

    # Index and summarize the document once; session_state keeps the results
    # across Streamlit reruns so asking questions doesn't retrigger processing.
    if uploaded_file and 'processed' not in st.session_state:
        with st.spinner("Processing document..."):
            text = process_file(uploaded_file)
            text_chunks = split_text(text)
            st.session_state.text_chunks = text_chunks
            st.session_state.faiss_index = create_faiss_index(text_chunks, embed_model)
            st.session_state.summary = generate_summary(text, summary_model, summary_tokenizer)
            st.session_state.processed = True

    if 'processed' in st.session_state:
        st.subheader("Document Summary")
        st.write(st.session_state.summary)
        st.divider()

        question = st.text_input("Ask a question about the document:")
        if question:
            answer = answer_question(
                question,
                st.session_state.faiss_index,
                st.session_state.text_chunks,
                embed_model,
                qa_model,
                qa_tokenizer
            )
            st.info(f"Answer: {answer}")
if __name__ == "__main__":
    main()
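# To try the app locally with the standard Streamlit CLI:
#   streamlit run app.py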