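"""RAG Book Assistant (Streamlit app).

Upload a PDF/TXT/DOCX file; the app chunks the text, indexes the chunks in
FAISS using sentence-transformer embeddings, summarizes the document with an
IBM Granite model, and answers questions against the retrieved chunks.
"""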
import streamlit as st
import os
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from docx import Document
import re
# Initialize models, cached so they load only once across Streamlit reruns
@st.cache_resource
def load_models():
    # Text embedding model used to index and retrieve chunks
    embed_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Hugging Face access token: read from the environment (HF_TOKEN here)
    # rather than hard-coding a secret in the source
    TOKEN = os.environ.get("HF_TOKEN")

    # IBM Granite model with proper token and trust_remote_code settings.
    # The same checkpoint serves both summarization and Q&A, so load it once
    # and return it for both roles instead of loading a 13B model twice.
    tokenizer = AutoTokenizer.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=TOKEN,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=TOKEN,
        trust_remote_code=True
    )
    return embed_model, model, tokenizer, model, tokenizer
# Extract raw text from an uploaded PDF, TXT, or DOCX file
def process_file(uploaded_file):
    text = ""
    file_type = uploaded_file.name.split('.')[-1].lower()
    if file_type == 'pdf':
        pdf_reader = PdfReader(uploaded_file)
        for page in pdf_reader.pages:
            text += page.extract_text() or ""
    elif file_type == 'txt':
        text = uploaded_file.read().decode('utf-8', errors='ignore')
    elif file_type == 'docx':
        doc = Document(uploaded_file)
        for para in doc.paragraphs:
            text += para.text + "\n"
    return clean_text(text)
# Normalize whitespace and strip non-ASCII characters
def clean_text(text):
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    return text
# Split text into fixed-size character chunks for embedding
def split_text(text, chunk_size=500):
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
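# Example: split_text("abcdef", chunk_size=4) == ["abcd", "ef"].
# A more elaborate splitter could overlap chunks so sentences are not cut
# mid-thought; fixed-size chunks keep this sketch simple.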
# Embed all chunks and store them in a flat L2 (exact-search) FAISS index
def create_faiss_index(text_chunks, embed_model):
    embeddings = embed_model.encode(text_chunks)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings).astype('float32'))  # FAISS expects float32
    return index
# Summarize the first ~3000 characters of the document with the Granite model
def generate_summary(text, model, tokenizer):
    inputs = tokenizer(f"Summarize this document: {text[:3000]}", return_tensors="pt", max_length=4096, truncation=True)
    # max_new_tokens bounds only the generated tokens; max_length would count
    # the prompt too and could silently truncate the summary to nothing
    summary_ids = model.generate(**inputs, max_new_tokens=500)
    # Decode only the tokens generated after the prompt, not the echoed prompt
    return tokenizer.decode(summary_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# Retrieve the 3 chunks nearest the question and answer from that context
def answer_question(question, index, text_chunks, embed_model, model, tokenizer):
    question_embed = embed_model.encode([question])
    _, indices = index.search(question_embed.astype('float32'), 3)
    context = " ".join([text_chunks[i] for i in indices[0]])
    prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=500)
    # Decode only the generated answer, skipping the prompt tokens
    return tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
def main():
    st.title("📚 RAG Book Assistant with IBM Granite")
    embed_model, summary_model, summary_tokenizer, qa_model, qa_tokenizer = load_models()

    uploaded_file = st.file_uploader("Upload a document (PDF/TXT/DOCX)", type=['pdf', 'txt', 'docx'])

    # Process the upload once per session: extract, chunk, index, and summarize
    if uploaded_file and 'processed' not in st.session_state:
        with st.spinner("Processing document..."):
            text = process_file(uploaded_file)
            text_chunks = split_text(text)
            st.session_state.text_chunks = text_chunks
            st.session_state.faiss_index = create_faiss_index(text_chunks, embed_model)
            summary = generate_summary(text, summary_model, summary_tokenizer)
            st.session_state.summary = summary
            st.session_state.processed = True

    if 'processed' in st.session_state:
        st.subheader("Document Summary")
        st.write(st.session_state.summary)
        st.divider()

        question = st.text_input("Ask a question about the document:")
        if question:
            answer = answer_question(
                question,
                st.session_state.faiss_index,
                st.session_state.text_chunks,
                embed_model,
                qa_model,
                qa_tokenizer
            )
            st.info(f"Answer: {answer}")

if __name__ == "__main__":
    main()
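# To try the app locally (assuming this file is saved as app.py):
#   pip install streamlit faiss-cpu numpy torch transformers sentence-transformers PyPDF2 python-docx
#   export HF_TOKEN=<your Hugging Face access token>
#   streamlit run app.py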