import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2
# Model Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "ibm-granite/granite-3.1-1b-a400m-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the model with a conditional to avoid meta tensor issues on CPU vs GPU
if device == "cpu":
    model = AutoModelForCausalLM.from_pretrained(model_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
model.eval()
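# Note: Streamlit re-runs this script from the top on every interaction, so the
# model and tokenizer above are reloaded on each rerun. A common pattern (a
# sketch only, not wired into the code below) is to move the loading into a
# cached factory:
#
#   @st.cache_resource
#   def load_model():
#       tok = AutoTokenizer.from_pretrained(model_path)
#       mdl = AutoModelForCausalLM.from_pretrained(model_path)
#       mdl.eval()
#       return tok, mdl
#
#   tokenizer, model = load_model()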
# Embedding Model for FAISS
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# FAISS Index
dimension = 384  # Embedding size for MiniLM
index = faiss.IndexFlatL2(dimension)
docs = []  # Store document texts
summary = ""  # Store book summary
# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    reader = PyPDF2.PdfReader(uploaded_file)
    text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
    return text
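# PyPDF2's page.extract_text() can return an empty string for pages with no
# extractable text (e.g. scanned images); the comprehension above simply skips
# those pages rather than failing.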
# Function to process uploaded documents and generate summary
def process_documents(files):
    global docs, index, summary
    docs = []
    index.reset()  # Drop previously indexed vectors so re-uploads don't leave stale entries
    for file in files:
        if file.type == "application/pdf":
            text = extract_text_from_pdf(file)
        else:
            text = file.getvalue().decode("utf-8")
        docs.append(text)
    embeddings = embedding_model.encode(docs)
    index.add(np.array(embeddings, dtype="float32"))  # FAISS expects float32 vectors
    # Generate summary after processing documents
    summary = generate_summary("\n".join(docs))
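# Each uploaded file is embedded as a single vector, so retrieval always returns
# a whole document. For long books, splitting the text into smaller chunks before
# encoding usually gives more focused context. A minimal sketch (hypothetical
# helper, not used above):
#
#   def chunk_text(text, size=1000, overlap=100):
#       step = size - overlap
#       return [text[i:i + size] for i in range(0, len(text), step)]
#
# The chunks would then replace the full texts in `docs` before encoding.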
# Function to generate a book summary
def generate_summary(text):
    chat = [
        {"role": "system", "content": "You are a helpful AI that summarizes books."},
        {"role": "user", "content": f"Summarize this book in a short paragraph:\n{text[:4000]}"},  # Limit input size for summarization
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=300)
    # Decode only the newly generated tokens so the prompt isn't echoed back
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
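# Only the first 4000 characters of the combined text reach the model, so the
# summary reflects the opening of the book rather than the whole work. One way
# around this is to summarize chunk by chunk and then summarize the partial
# summaries (map-reduce style), at the cost of extra generation calls.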
# Function to retrieve relevant context using FAISS
def retrieve_context(query):
    if index.ntotal == 0:
        return "No documents available. Please upload files first."
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k=1)
    if len(indices) == 0 or indices[0][0] >= len(docs):
        return "No relevant context found."
    return docs[indices[0][0]]
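# index.search returns (distances, indices) arrays of shape (n_queries, k).
# Retrieval could be broadened by raising k and joining the hits, e.g. (sketch):
#
#   distances, indices = index.search(np.array(query_embedding), k=3)
#   context = "\n\n".join(docs[i] for i in indices[0] if 0 <= i < len(docs))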
# Function to generate response using IBM Granite model
def generate_response(query, context):
    chat = [
        {"role": "system", "content": "You are a helpful assistant using retrieved knowledge."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {query}\nAnswer based on context:"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=200)
    # Decode only the newly generated tokens so the prompt isn't echoed back
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
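# Unless the model's generation_config specifies otherwise, model.generate uses
# greedy decoding here; more varied answers can be obtained by passing sampling
# options such as do_sample=True and temperature=0.7.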
# Streamlit UI
st.set_page_config(page_title="📚 AI Book Assistant", page_icon="📚")
st.title("📚 AI-Powered Book Assistant")
st.subheader("Upload a book and get its summary or ask questions!")

uploaded_file = st.file_uploader("Upload a book (PDF or TXT)", accept_multiple_files=False)

if uploaded_file:
    with st.spinner("Processing book and generating summary..."):
        process_documents([uploaded_file])
    st.success("Book uploaded and processed!")
    st.markdown("### 📚 Book Summary:")
    st.write(summary)

query = st.text_input("Ask a question about the book:")

if st.button("Get Answer"):
    if index.ntotal == 0:
        st.warning("Please upload a book first!")
    else:
        with st.spinner("Retrieving and generating response..."):
            context = retrieve_context(query)
            response = generate_response(query, context)
        st.markdown("### 🤖 Answer:")
        st.write(response)