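"""Streamlit RAG app: upload a PDF/TXT/DOCX document, build a FAISS index over
sentence-transformer embeddings, and use an IBM Granite model to summarize the
document and answer questions about it."""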
import os
import re

import faiss
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from docx import Document
# Initialize models (cached so they are loaded only once per session)
@st.cache_resource
def load_models():
    # Text embedding model
    embed_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Read the Hugging Face access token from the environment rather than
    # hard-coding it in source.
    token = os.environ.get("HF_TOKEN")

    # Summarization and Q&A both use the same IBM Granite checkpoint, so load
    # it once and share the weights instead of loading a 13B model twice.
    granite_tokenizer = AutoTokenizer.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=token,
        trust_remote_code=True
    )
    granite_model = AutoModelForCausalLM.from_pretrained(
        "ibm/granite-13b-instruct-v2",
        token=token,
        trust_remote_code=True
    )
    return embed_model, granite_model, granite_tokenizer, granite_model, granite_tokenizer
def process_file(uploaded_file):
    text = ""
    file_type = uploaded_file.name.split('.')[-1].lower()
    if file_type == 'pdf':
        pdf_reader = PdfReader(uploaded_file)
        for page in pdf_reader.pages:
            text += page.extract_text() or ""
    elif file_type == 'txt':
        text = uploaded_file.read().decode('utf-8')
    elif file_type == 'docx':
        doc = Document(uploaded_file)
        for para in doc.paragraphs:
            text += para.text + "\n"
    return clean_text(text)
def clean_text(text):
    text = re.sub(r'\s+', ' ', text)            # collapse runs of whitespace
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # strip non-ASCII characters
    return text
def split_text(text, chunk_size=500):
    # Naive fixed-width split into chunks of `chunk_size` characters.
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
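# Optional variant (not used above): fixed-width chunks can cut sentences in
# half at chunk boundaries, which hurts retrieval quality. A common remedy is
# to let consecutive chunks overlap. A minimal sketch, assuming an overlap of
# 50 characters is a reasonable default for 500-character chunks; tune both
# numbers for your documents.
def split_text_with_overlap(text, chunk_size=500, overlap=50):
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]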
def create_faiss_index(text_chunks, embed_model):
    embeddings = embed_model.encode(text_chunks)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings).astype('float32'))
    return index
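# Note: IndexFlatL2 ranks by Euclidean distance. sentence-transformers models
# are typically used with cosine similarity; an equivalent setup is an
# inner-product index over L2-normalized vectors. A minimal sketch, assuming
# the same text_chunks/embed_model as above (queries must then be normalized
# the same way before searching):
def create_cosine_index(text_chunks, embed_model):
    embeddings = embed_model.encode(text_chunks, normalize_embeddings=True)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(np.array(embeddings).astype('float32'))
    return index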
def generate_summary(text, model, tokenizer):
    inputs = tokenizer(f"Summarize this document: {text[:3000]}", return_tensors="pt", max_length=4096, truncation=True)
    # max_new_tokens bounds the summary itself; max_length would count the
    # prompt tokens too and can silently truncate generation.
    summary_ids = model.generate(inputs.input_ids, max_new_tokens=500)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(summary_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
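# Note: only the first 3,000 characters of the document reach the summarizer,
# so summaries of long books cover just the opening pages. Summarizing each
# chunk and then summarizing the concatenated chunk summaries (map-reduce
# style) is one way around this, at the cost of extra generate() calls.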
def answer_question(question, index, text_chunks, embed_model, model, tokenizer):
    question_embed = embed_model.encode([question])
    _, indices = index.search(np.array(question_embed).astype('float32'), 3)
    # FAISS pads the result with -1 when the index holds fewer than k vectors.
    context = " ".join([text_chunks[i] for i in indices[0] if i != -1])
    prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True)
    outputs = model.generate(inputs.input_ids, max_new_tokens=500)
    # Return only the generated answer, not the echoed prompt.
    return tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
def main():
    st.title("📚 RAG Book Assistant with IBM Granite")
    embed_model, summary_model, summary_tokenizer, qa_model, qa_tokenizer = load_models()
    uploaded_file = st.file_uploader("Upload a document (PDF/TXT/DOCX)", type=['pdf', 'txt', 'docx'])

    if uploaded_file and 'processed' not in st.session_state:
        with st.spinner("Processing document..."):
            text = process_file(uploaded_file)
            text_chunks = split_text(text)
            st.session_state.text_chunks = text_chunks
            st.session_state.faiss_index = create_faiss_index(text_chunks, embed_model)
            summary = generate_summary(text, summary_model, summary_tokenizer)
            st.session_state.summary = summary
            st.session_state.processed = True

    if 'processed' in st.session_state:
        st.subheader("Document Summary")
        st.write(st.session_state.summary)
        st.divider()
        question = st.text_input("Ask a question about the document:")
        if question:
            answer = answer_question(
                question,
                st.session_state.faiss_index,
                st.session_state.text_chunks,
                embed_model,
                qa_model,
                qa_tokenizer
            )
            st.info(f"Answer: {answer}")

if __name__ == "__main__":
    main()
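# Dependencies (pip package names differ from import names for two of these):
#   streamlit, faiss-cpu, numpy, transformers, sentence-transformers, torch,
#   PyPDF2 (imported as PyPDF2), python-docx (imported as docx)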