# AIPaperPilot / app.py
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2
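# Overview: this Streamlit app implements a small retrieval-augmented pipeline.
# Uploaded books are converted to text, embedded with a MiniLM sentence encoder,
# indexed in FAISS, and queried with the IBM Granite instruct model to produce
# a summary and answer questions about the book.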
# Model Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "ibm-granite/granite-3.1-1b-a400m-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the model with a conditional to avoid meta tensor issues on CPU vs GPU
if device == "cpu":
    model = AutoModelForCausalLM.from_pretrained(model_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
model.eval()
# Embedding Model for FAISS
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# FAISS Index
dimension = 384 # Embedding size for MiniLM
index = faiss.IndexFlatL2(dimension)
docs = [] # Store document texts
summary = "" # Store book summary
# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    reader = PyPDF2.PdfReader(uploaded_file)
    text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
    return text
# Function to process uploaded documents and generate summary
def process_documents(files):
    global docs, index, summary
    docs = []
    index.reset()  # Clear previously indexed embeddings so re-uploads don't accumulate stale vectors
    for file in files:
        if file.type == "application/pdf":
            text = extract_text_from_pdf(file)
        else:
            text = file.getvalue().decode("utf-8")
        docs.append(text)
    embeddings = embedding_model.encode(docs)
    index.add(np.array(embeddings, dtype="float32"))  # FAISS expects float32 vectors
    # Generate summary after processing documents
    summary = generate_summary("\n".join(docs))
# Function to generate a book summary
def generate_summary(text):
    chat = [
        {"role": "system", "content": "You are a helpful AI that summarizes books."},
        {"role": "user", "content": f"Summarize this book in a short paragraph:\n{text[:4000]}"},  # Limit input size for summarization
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=300)
    # Decode only the newly generated tokens so the prompt is not echoed back
    return tokenizer.decode(output[0][input_tokens["input_ids"].shape[1]:], skip_special_tokens=True)
# Function to retrieve relevant context using FAISS
def retrieve_context(query):
    if index.ntotal == 0:
        return "No documents available. Please upload files first."
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding, dtype="float32"), k=1)
    best_match = indices[0][0]
    if best_match < 0 or best_match >= len(docs):  # FAISS returns -1 when no neighbour is found
        return "No relevant context found."
    return docs[best_match]
# Function to generate a response using the IBM Granite model
def generate_response(query, context):
    chat = [
        {"role": "system", "content": "You are a helpful assistant using retrieved knowledge."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {query}\nAnswer based on context:"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=200)
    # Decode only the newly generated tokens so the prompt is not echoed back
    return tokenizer.decode(output[0][input_tokens["input_ids"].shape[1]:], skip_special_tokens=True)
# Streamlit UI
st.set_page_config(page_title="📖 AI Book Assistant", page_icon="📚")
st.title("📖 AI-Powered Book Assistant")
st.subheader("Upload a book and get its summary or ask questions!")

uploaded_file = st.file_uploader("Upload a book (PDF or TXT)", accept_multiple_files=False)

if uploaded_file:
    with st.spinner("Processing book and generating summary..."):
        process_documents([uploaded_file])
    st.success("Book uploaded and processed!")
    st.markdown("### 📚 Book Summary:")
    st.write(summary)

query = st.text_input("Ask a question about the book:")
if st.button("Get Answer"):
    if index.ntotal == 0:
        st.warning("Please upload a book first!")
    else:
        with st.spinner("Retrieving and generating response..."):
            context = retrieve_context(query)
            response = generate_response(query, context)
        st.markdown("### 🤖 Answer:")
        st.write(response)
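# To run the app locally, the usual Streamlit entry point applies:
#   streamlit run app.py
# On a Hugging Face Space using the Streamlit SDK, the app is typically launched automatically.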