import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2
# Model Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "ibm-granite/granite-3.1-1b-a400m-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the model with a conditional to avoid meta tensor issues on CPU vs GPU
if device == "cpu":
    model = AutoModelForCausalLM.from_pretrained(model_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
model.eval()
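
# Note (suggestion, not part of the original app): Streamlit reruns this whole
# script on every interaction, so the model and tokenizer above are reloaded each
# time. Assuming a Streamlit version that provides st.cache_resource, a cached
# loader would avoid that. Sketch only:
#
# @st.cache_resource
# def load_model(path):
#     tok = AutoTokenizer.from_pretrained(path)
#     mdl = AutoModelForCausalLM.from_pretrained(path)
#     mdl.eval()
#     return tok, mdl
#
# tokenizer, model = load_model(model_path)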
# Embedding Model for FAISS
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# FAISS Index
dimension = 384 # Embedding size for MiniLM
index = faiss.IndexFlatL2(dimension)
docs = [] # Store document texts
summary = "" # Store book summary
# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    reader = PyPDF2.PdfReader(uploaded_file)
    text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
    return text
# Function to process uploaded documents and generate summary
def process_documents(files):
    global docs, index, summary
    docs = []
    for file in files:
        if file.type == "application/pdf":
            text = extract_text_from_pdf(file)
        else:
            text = file.getvalue().decode("utf-8")
        docs.append(text)
    embeddings = embedding_model.encode(docs)
    index.add(np.array(embeddings))
    # Generate summary after processing documents
    summary = generate_summary("\n".join(docs))
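
# Note (suggestion, not part of the original app): each uploaded file is embedded
# as a single document above, so retrieval can only ever return the whole book as
# context. A common refinement is to split the text into overlapping chunks before
# encoding; chunk_size and overlap below are illustrative values. Sketch only:
#
# def chunk_text(text, chunk_size=1000, overlap=200):
#     step = chunk_size - overlap
#     return [text[i:i + chunk_size] for i in range(0, len(text), step)]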
# Function to generate a book summary
def generate_summary(text):
    chat = [
        {"role": "system", "content": "You are a helpful AI that summarizes books."},
        {"role": "user", "content": f"Summarize this book in a short paragraph:\n{text[:4000]}"},  # Limit input size for summarization
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**input_tokens, max_new_tokens=300)
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
# Function to retrieve relevant context using FAISS
def retrieve_context(query):
    if index.ntotal == 0:
        return "No documents available. Please upload files first."
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k=1)
    # FAISS returns -1 for empty slots; also guard against indices past the docs list
    if indices[0][0] < 0 or indices[0][0] >= len(docs):
        return "No relevant context found."
    return docs[indices[0][0]]
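
# Note (suggestion, not part of the original app): only the single nearest
# document is used as context above. If documents were chunked, searching with
# k > 1 and joining the hits would give a richer context window. Sketch only:
#
# distances, indices = index.search(np.array(query_embedding), k=3)
# context = "\n\n".join(docs[i] for i in indices[0] if 0 <= i < len(docs))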
# Function to generate response using IBM Granite model
def generate_response(query, context):
    chat = [
        {"role": "system", "content": "You are a helpful assistant using retrieved knowledge."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {query}\nAnswer based on context:"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**input_tokens, max_new_tokens=200)
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
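
# Note (suggestion, not part of the original app): generate() above uses greedy
# decoding by default. Sampling parameters can make answers less repetitive; the
# values below are illustrative, not tuned. Sketch only:
#
# output = model.generate(**input_tokens, max_new_tokens=200,
#                         do_sample=True, temperature=0.7, top_p=0.9)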
# Streamlit UI
st.set_page_config(page_title="📖 AI Book Assistant", page_icon="📖")
st.title("📖 AI-Powered Book Assistant")
st.subheader("Upload a book and get its summary or ask questions!")
uploaded_file = st.file_uploader("Upload a book (PDF or TXT)", accept_multiple_files=False)
if uploaded_file:
    with st.spinner("Processing book and generating summary..."):
        process_documents([uploaded_file])
    st.success("Book uploaded and processed!")
    st.markdown("### 📖 Book Summary:")
    st.write(summary)
query = st.text_input("Ask a question about the book:")
if st.button("Get Answer"):
    if index.ntotal == 0:
        st.warning("Please upload a book first!")
    else:
        with st.spinner("Retrieving and generating response..."):
            context = retrieve_context(query)
            response = generate_response(query, context)
        st.markdown("### 🤖 Answer:")
        st.write(response)