import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import PyPDF2

# Model Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "ibm-granite/granite-3.1-1b-a400m-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the model with a conditional to avoid meta tensor issues on CPU vs GPU
if device == "cpu":
    model = AutoModelForCausalLM.from_pretrained(model_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
model.eval()

# Embedding Model for FAISS
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# FAISS Index
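# IndexFlatL2 performs exact (brute-force) L2 search, which is fine for the small number of documents handled here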
dimension = 384  # Embedding size for MiniLM
index = faiss.IndexFlatL2(dimension)
docs = []  # Store document texts
summary = ""  # Store book summary

# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    reader = PyPDF2.PdfReader(uploaded_file)
    # Extract each page once and skip pages with no extractable text
    page_texts = [page.extract_text() for page in reader.pages]
    return "\n".join(text for text in page_texts if text)

# Function to process uploaded documents and generate summary
def process_documents(files):
    global docs, index, summary
    docs = []
    index.reset()  # Drop any previously indexed vectors so FAISS rows stay aligned with docs
    
    for file in files:
        if file.type == "application/pdf":
            text = extract_text_from_pdf(file)
        else:
            text = file.getvalue().decode("utf-8")
        docs.append(text)
    
    # Embed each full document as one vector; FAISS row i corresponds to docs[i]
    embeddings = embedding_model.encode(docs)
    index.add(np.array(embeddings, dtype="float32"))  # FAISS expects float32 vectors
    
    # Generate summary after processing documents
    summary = generate_summary("\n".join(docs))

# Function to generate a book summary
def generate_summary(text):
    chat = [
        {"role": "system", "content": "You are a helpful AI that summarizes books."},
        {"role": "user", "content": f"Summarize this book in a short paragraph:\n{text[:4000]}"}  # Limiting input size for summarization
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=300)
    # Decode only the newly generated tokens so the prompt is not echoed in the summary
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

# Function to retrieve relevant context using FAISS
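# With k=1, the single nearest whole document is returned as the answer context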
def retrieve_context(query):
    if index.ntotal == 0:
        return "No documents available. Please upload files first."
    
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k=1)
    
    nearest = int(indices[0][0])
    # FAISS returns -1 when it cannot fill the requested k results
    if nearest < 0 or nearest >= len(docs):
        return "No relevant context found."
    return docs[nearest]

# Function to generate response using IBM Granite model
def generate_response(query, context):
    chat = [
        {"role": "system", "content": "You are a helpful assistant using retrieved knowledge."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {query}\nAnswer based on context:"},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**input_tokens, max_new_tokens=200)
    # Decode only the newly generated tokens so the prompt is not echoed in the answer
    new_tokens = output[:, input_tokens["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

# Streamlit UI
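# Note: Streamlit re-executes this entire script on every interaction, so the model, index, and docs are rebuilt on each rerun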
st.set_page_config(page_title="📖 AI Book Assistant", page_icon="📚")
st.title("📖 AI-Powered Book Assistant")
st.subheader("Upload a book and get its summary or ask questions!")

uploaded_file = st.file_uploader("Upload a book (PDF or TXT)", accept_multiple_files=False)

if uploaded_file:
    with st.spinner("Processing book and generating summary..."):
        process_documents([uploaded_file])
    st.success("Book uploaded and processed!")
    st.markdown("### 📚 Book Summary:")
    st.write(summary)

query = st.text_input("Ask a question about the book:")
if st.button("Get Answer"):
    if index.ntotal == 0:
        st.warning("Please upload a book first!")
    elif not query.strip():
        st.warning("Please enter a question first!")
    else:
        with st.spinner("Retrieving and generating response..."):
            context = retrieve_context(query)
            response = generate_response(query, context)
            st.markdown("### 🤖 Answer:")
            st.write(response)