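"""Streamlit RAG app: summarize an uploaded book (PDF or TXT) and answer
questions about it with a Granite instruct model plus FAISS retrieval.

Run with `streamlit run app.py` (assuming this file is saved as app.py).
Dependencies: streamlit, torch, faiss-cpu (or faiss-gpu), transformers,
sentence-transformers, PyMuPDF, langchain-text-splitters.
"""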
import streamlit as st
import torch
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz  # PyMuPDF for PDF extraction
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Configuration
MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"  # generative model
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"   # embedding model for retrieval
CHUNK_SIZE = 512     # characters per chunk
CHUNK_OVERLAP = 64   # characters shared between adjacent chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Cache the heavyweight models across Streamlit reruns (the script re-executes
# on every user interaction, so loading must be memoized)
@st.cache_resource
def load_models():
    try:
        # Load tokenizer and generative model with trust_remote_code enabled
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main"
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True,
            revision="main",
            low_cpu_mem_usage=True
        ).eval()
        
        # Load embedding model for FAISS
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        return tokenizer, model, embedder
    
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()

tokenizer, model, embedder = load_models()

# Split text into overlapping chunks at natural boundaries (paragraphs, sentences)
def process_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)

# PDF text extraction using PyMuPDF
def extract_pdf_text(uploaded_file):
    try:
        doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
        return "\n".join([page.get_text() for page in doc])
    except Exception as e:
        st.error(f"PDF extraction error: {str(e)}")
        return ""

# Multi-step summarization: summarize each chunk, then merge the partial summaries
def generate_summary(text):
    chunks = process_text(text)[:10]  # use the first 10 chunks for the summary
    summaries = []
    
    for chunk in chunks:
        prompt = f"""<|user|>
Summarize this text section focusing on key themes, characters, and plot points:
{chunk[:2000]}
<|assistant|>
"""
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        outputs = model.generate(**inputs, max_new_tokens=300, temperature=0.3, do_sample=True)
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        summaries.append(tokenizer.decode(new_tokens, skip_special_tokens=True).strip())
    
    # Combine the per-chunk summaries into one comprehensive summary
    combined = "\n".join(summaries)
    final_prompt = f"""<|user|>
Combine these section summaries into a coherent book summary:
{combined}
<|assistant|>
"""
    inputs = tokenizer(final_prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**inputs, max_new_tokens=500, temperature=0.5, do_sample=True)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Retrieval index: inner product over L2-normalized vectors equals cosine similarity
def build_faiss_index(texts):
    embeddings = embedder.encode(texts, show_progress_bar=True)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # exact inner-product search
    faiss.normalize_L2(embeddings)        # normalize in place so IP == cosine
    index.add(embeddings)
    return index
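
# Queries must be embedded and L2-normalized the same way before calling
# index.search(query_embed, k); see the query interface below.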

# Context-aware answer generation
def generate_answer(query, context):
    prompt = f"""<|user|>
Using this context: {context}
Answer the question precisely and truthfully. If unsure, say "I don't know".
Question: {query}
<|assistant|>
"""
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.4,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True
    )
    # Decode only the newly generated tokens so the prompt is not echoed in the answer
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Streamlit UI setup
st.set_page_config(page_title="πŸ“š Smart Book Analyst", layout="wide")
st.title("πŸ“š AI-Powered Book Analysis System")

# File upload
uploaded_file = st.file_uploader("Upload book (PDF or TXT)", type=["pdf", "txt"])

if uploaded_file:
    with st.spinner("πŸ“– Analyzing book content..."):
        try:
            if uploaded_file.type == "application/pdf":
                text = extract_pdf_text(uploaded_file)
            else:
                text = uploaded_file.read().decode("utf-8", errors="replace")
            
            chunks = process_text(text)
            st.session_state.docs = chunks
            st.session_state.index = build_faiss_index(chunks)
            
            with st.expander("πŸ“ Book Summary", expanded=True):
                summary = generate_summary(text)
                st.write(summary)
                
        except Exception as e:
            st.error(f"Processing failed: {str(e)}")

# Query interface
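# st.session_state keeps the chunk list and FAISS index alive across Streamlit
# reruns, so repeated questions do not re-index the book.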
if "index" in st.session_state and st.session_state.index is not None:
    query = st.text_input("Ask about the book:")
    if query:
        with st.spinner("πŸ” Searching for answers..."):
            try:
                # Retrieve top 3 relevant chunks
                query_embed = embedder.encode([query])
                faiss.normalize_L2(query_embed)
                distances, indices = st.session_state.index.search(query_embed, k=3)
                
                context = "\n".join(st.session_state.docs[i] for i in indices[0] if i != -1)
                answer = generate_answer(query, context)
                
                st.subheader("Answer")
                st.markdown(f"```\n{answer}\n```")
                st.caption("Top match cosine similarity: {:.2f}".format(distances[0][0]))
                
            except Exception as e:
                st.error(f"Query failed: {str(e)}")