sunbal7 committed
Commit ab46633 · verified · 1 Parent(s): c7d20f6

Update app.py

Files changed (1)
  1. app.py +100 -279
app.py CHANGED
@@ -1,332 +1,153 @@
  import streamlit as st
  import torch
  import numpy as np
  import faiss
- import time
- import re
- from typing import List, Tuple
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer
- import fitz # PyMuPDF
- import docx2txt
  from langchain_text_splitters import RecursiveCharacterTextSplitter
- from io import BytesIO

  # ------------------------
  # Configuration
  # ------------------------
  MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
- CHUNK_SIZE = 1024 # Increased for better context
- CHUNK_OVERLAP = 128
- MAX_FILE_SIZE_MB = 10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  # ------------------------
- # Model Loading with Quantization
  # ------------------------
  @st.cache_resource
  def load_models():
      try:
-         # Configure quantization for CPU deployment
-         quant_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_use_double_quant=True,
-         ) if DEVICE == "cpu" else None
-
          tokenizer = AutoTokenizer.from_pretrained(
              MODEL_NAME,
              trust_remote_code=True,
              revision="main"
          )
-
          model = AutoModelForCausalLM.from_pretrained(
              MODEL_NAME,
              trust_remote_code=True,
              revision="main",
-             device_map="auto",
-             quantization_config=quant_config,
              torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
              low_cpu_mem_usage=True
          ).eval()
-
-         # Load embedding model with FP16 optimization
-         embedder = SentenceTransformer(
-             EMBED_MODEL,
-             device=DEVICE,
-             device_kwargs={"keep_all_models": True}
-         )
-         if DEVICE == "cuda":
-             embedder = embedder.half()
-
          return tokenizer, model, embedder
      except Exception as e:
          st.error(f"Model loading failed: {str(e)}")
          st.stop()
  # ------------------------
- # Enhanced Text Processing
  # ------------------------
- def clean_text(text: str) -> str:
-     """Advanced text cleaning with multiple normalization steps"""
-     text = re.sub(r'\s+', ' ', text) # Remove extra whitespace
-     text = re.sub(r'[^\x00-\x7F]+', ' ', text) # Remove non-ASCII
-     text = re.sub(r'\bPage \d+\b', '', text) # Remove page numbers
-     text = re.sub(r'http\S+', '', text) # Remove URLs
-     text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '', text) # Remove emails
-     return text.strip()
-
- def extract_text(file: BytesIO) -> Tuple[str, List[str]]:
-     """Improved text extraction with format-specific handling"""
-     try:
-         if file.size > MAX_FILE_SIZE_MB * 1024 * 1024:
-             raise ValueError(f"File size exceeds {MAX_FILE_SIZE_MB}MB limit")
-
-         file_type = file.type
-         text = ""
-
-         if file_type == "application/pdf":
-             doc = fitz.open(stream=file.read(), filetype="pdf")
-             text = "\n".join([page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE) for page in doc])
-             # Extract images metadata for future multimodal expansion
-             images = [img for page in doc for img in page.get_images()]
-             if images:
-                 st.session_state.images = images
-         elif file_type == "text/plain":
-             text = file.read().decode("utf-8", errors="replace")
-         elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
-             text = docx2txt.process(file)
-         else:
-             raise ValueError("Unsupported file type")
-
-         return clean_text(text)
-     except Exception as e:
-         st.error(f"Text extraction failed: {str(e)}")
-         st.stop()
-
- def semantic_chunking(text: str) -> List[str]:
-     """Context-aware text splitting with metadata tracking"""
      splitter = RecursiveCharacterTextSplitter(
          chunk_size=CHUNK_SIZE,
          chunk_overlap=CHUNK_OVERLAP,
-         length_function=len,
-         add_start_index=True
      )
-     chunks = splitter.split_text(text)
-     return chunks

- # ------------------------
- # Enhanced Vector Indexing
- # ------------------------
- def build_faiss_index(chunks: List[str], embedder) -> faiss.Index:
-     """Build optimized FAISS index with error handling"""
-     try:
-         embeddings = embedder.encode(
-             chunks,
-             batch_size=32,
-             show_progress_bar=True,
-             convert_to_tensor=True
-         )
-         if DEVICE == "cuda":
-             embeddings = embeddings.cpu().numpy()
-         else:
-             embeddings = embeddings.numpy()
-
-         dimension = embeddings.shape[1]
-         index = faiss.IndexFlatIP(dimension)
-         faiss.normalize_L2(embeddings)
-         index.add(embeddings)
-         return index
-     except Exception as e:
-         st.error(f"Index creation failed: {str(e)}")
-         st.stop()
  # ------------------------
- # Improved Generation Functions
  # ------------------------
- def format_prompt(system_prompt: str, user_input: str) -> str:
-     """Structured prompt formatting for better model performance"""
-     return f"""<|system|>
- {system_prompt}
- <|user|>
- {user_input}
- <|assistant|>
- """
-
- def generate_summary(text: str, tokenizer, model) -> str:
-     """Hierarchical summarization with chunk processing"""
-     try:
-         # First-stage summary
-         chunks = [text[i:i+3000] for i in range(0, len(text), 3000)]
-         summaries = []
-
-         for chunk in chunks:
-             prompt = format_prompt(
-                 "Generate a detailed summary of this text excerpt:",
-                 chunk[:2500]
-             )
-             inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=300,
-                 temperature=0.3,
-                 do_sample=True
-             )
-             summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
-             summaries.append(summary.split("<|assistant|>")[-1].strip())
-
-         # Final synthesis
-         final_prompt = format_prompt(
-             "Synthesize these summaries into a comprehensive overview:",
-             "\n".join(summaries)
-         )
-         inputs = tokenizer(final_prompt, return_tensors="pt").to(DEVICE)
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=500,
-             temperature=0.4,
-             do_sample=True
-         )
-         return tokenizer.decode(outputs[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
-     except Exception as e:
-         st.error(f"Summarization failed: {str(e)}")
-         return "Summary generation failed"
-
- def retrieve_context(query: str, index, chunks: List[str], embedder, top_k: int = 3) -> str:
-     """Enhanced retrieval with score thresholding"""
-     query_embed = embedder.encode([query], convert_to_tensor=True)
-     if DEVICE == "cuda":
-         query_embed = query_embed.cpu().numpy()
-     else:
-         query_embed = query_embed.numpy()
-
-     faiss.normalize_L2(query_embed)
-     scores, indices = index.search(query_embed, top_k*2) # Retrieve extra for filtering
-
-     # Apply similarity threshold
-     valid_indices = [i for i, score in zip(indices[0], scores[0]) if score > 0.35]
-     return " ".join([chunks[i] for i in valid_indices[:top_k]])
  # ------------------------
- # Streamlit UI Improvements
  # ------------------------
- def main():
-     st.set_page_config(
-         page_title="RAG Book Analyzer Pro",
-         layout="wide",
-         initial_sidebar_state="expanded"
-     )
-
-     # Initialize session state
-     if "processed" not in st.session_state:
-         st.session_state.processed = False
-     if "index" not in st.session_state:
-         st.session_state.index = None
-
-     # Load models once
-     tokenizer, model, embedder = load_models()
-
-     # Sidebar controls
-     with st.sidebar:
-         st.header("Settings")
-         top_k = st.slider("Number of context passages", 1, 5, 3)
-         temp = st.slider("Generation Temperature", 0.1, 1.0, 0.4)
-
-     # Main interface
-     st.title("📚 Advanced Book Analyzer")
-     st.write("Upload technical manuals, research papers, or books for deep analysis")
-
-     uploaded_file = st.file_uploader(
-         "Choose a document",
-         type=["pdf", "txt", "docx"],
-         accept_multiple_files=False
-     )
-
-     if uploaded_file and not st.session_state.processed:
-         with st.spinner("Analyzing document..."):
-             start_time = time.time()
-
-             # Process document
-             text = extract_text(uploaded_file)
-             chunks = semantic_chunking(text)
-             index = build_faiss_index(chunks, embedder)
-
-             # Store in session state
-             st.session_state.update({
-                 "chunks": chunks,
-                 "index": index,
-                 "processed": True,
-                 "text": text
-             })
-
-             st.success(f"Processed {len(chunks)} chunks in {time.time()-start_time:.1f}s")
-
-     if st.session_state.processed:
-         # Summary section
-         with st.expander("Document Summary", expanded=True):
-             summary = generate_summary(st.session_state.text, tokenizer, model)
-             st.markdown(summary)

-         # Q&A Section
-         st.divider()
-         col1, col2 = st.columns([3, 1])
-         with col1:
-             query = st.text_input("Ask about the document:", placeholder="What are the key findings...")
-         with col2:
-             show_context = st.checkbox("Show context sources")
          if query:
-             with st.spinner("Searching document..."):
-                 context = retrieve_context(
-                     query,
-                     st.session_state.index,
-                     st.session_state.chunks,
-                     embedder,
-                     top_k=top_k
-                 )
-
-                 if not context:
-                     st.warning("No relevant context found in document")
-                     return
-
-                 with st.expander("Generated Answer", expanded=True):
-                     answer = generate_answer(query, context, tokenizer, model, temp)
-                     st.markdown(answer)
-
-                 if show_context:
-                     st.divider()
-                     st.subheader("Source Context")
-                     st.write(context)
-
- def generate_answer(query: str, context: str, tokenizer, model, temp: float) -> str:
-     """Improved answer generation with context validation"""
-     try:
-         prompt = format_prompt(
-             f"""Answer the question using only the provided context.
-             Follow these rules:
-             1. Be precise and factual
-             2. If unsure, say 'The document does not specify'
-             3. Use bullet points when listing items
-             4. Keep answers under 3 sentences
-
-             Context: {context[:2000]}""",
-             query
-         )
-         inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=400,
-             temperature=temp,
-             top_p=0.9,
-             repetition_penalty=1.2,
-             do_sample=True
-         )
-         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return answer.split("<|assistant|>")[-1].strip()
-     except Exception as e:
-         st.error(f"Generation failed: {str(e)}")
-         return "Unable to generate answer"
-
- if __name__ == "__main__":
-     main()
  import streamlit as st
+ st.set_page_config(page_title="RAG Book Analyzer", layout="wide") # Must be the first Streamlit command
+
  import torch
  import numpy as np
  import faiss
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  from sentence_transformers import SentenceTransformer
+ import fitz # PyMuPDF for PDF extraction
+ import docx2txt # For DOCX extraction
  from langchain_text_splitters import RecursiveCharacterTextSplitter

  # ------------------------
  # Configuration
  # ------------------------
  MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ CHUNK_SIZE = 512
+ CHUNK_OVERLAP = 64
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  # ------------------------
+ # Model Loading with Caching
  # ------------------------
  @st.cache_resource
  def load_models():
      try:
          tokenizer = AutoTokenizer.from_pretrained(
              MODEL_NAME,
              trust_remote_code=True,
              revision="main"
          )
          model = AutoModelForCausalLM.from_pretrained(
              MODEL_NAME,
              trust_remote_code=True,
              revision="main",
+             device_map="auto" if DEVICE == "cuda" else None,
              torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
              low_cpu_mem_usage=True
          ).eval()
+         embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
          return tokenizer, model, embedder
      except Exception as e:
          st.error(f"Model loading failed: {str(e)}")
          st.stop()

+ tokenizer, model, embedder = load_models()
+
  # ------------------------
+ # Text Processing Functions
  # ------------------------
+ def split_text(text):
      splitter = RecursiveCharacterTextSplitter(
          chunk_size=CHUNK_SIZE,
          chunk_overlap=CHUNK_OVERLAP,
+         length_function=len
      )
+     return splitter.split_text(text)

+ def extract_text(file):
+     file_type = file.type
+     if file_type == "application/pdf":
+         try:
+             doc = fitz.open(stream=file.read(), filetype="pdf")
+             return "\n".join([page.get_text() for page in doc])
+         except Exception as e:
+             st.error("Error processing PDF: " + str(e))
+             return ""
+     elif file_type == "text/plain":
+         return file.read().decode("utf-8")
+     elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+         try:
+             return docx2txt.process(file)
+         except Exception as e:
+             st.error("Error processing DOCX: " + str(e))
+             return ""
+     else:
+         st.error("Unsupported file type: " + file_type)
+         return ""
+
+ def build_index(chunks):
+     embeddings = embedder.encode(chunks, show_progress_bar=True)
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatIP(dimension)
+     faiss.normalize_L2(embeddings)
+     index.add(embeddings)
+     return index
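A note on build_index: because the embeddings are L2-normalized before being added to an IndexFlatIP, the inner-product scores returned by the index are cosine similarities. A minimal, self-contained sketch of the same retrieval pattern (the document strings and variable names below are illustrative, not taken from app.py):

import faiss
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
docs = ["FAISS stores dense vectors.", "Streamlit builds data apps.", "Granite is an IBM language model."]

vectors = embedder.encode(docs)        # float32 array of shape (3, 768)
faiss.normalize_L2(vectors)            # unit-length rows, so inner product == cosine similarity
index = faiss.IndexFlatIP(vectors.shape[1])
index.add(vectors)

query = embedder.encode(["Which library indexes vectors?"])
faiss.normalize_L2(query)
scores, ids = index.search(query, 2)   # top-2 most similar documents
print([docs[i] for i in ids[0]], scores[0])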

  # ------------------------
+ # Summarization and Q&A Functions
  # ------------------------
+ def generate_summary(text):
+     # Limit input text to avoid long sequences
+     prompt = f"<|user|>\nSummarize the following book in a concise and informative paragraph:\n\n{text[:4000]}\n<|assistant|>\n"
+     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+     outputs = model.generate(**inputs, max_new_tokens=300, temperature=0.5)
+     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Remove any markers and extra lines; return the first non-empty paragraph.
+     summary = summary.replace("<|assistant|>", "").strip()
+     paragraphs = [p.strip() for p in summary.split("\n") if p.strip()]
+     return paragraphs[0] if paragraphs else summary

+ def generate_answer(query, context):
+     prompt = f"<|user|>\nUsing the context below, answer the following question precisely. If unsure, say 'I don't know'.\n\nContext: {context}\n\nQuestion: {query}\n<|assistant|>\n"
+     inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=300,
+         temperature=0.4,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         do_sample=True
+     )
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     answer = answer.replace("<|assistant|>", "").strip()
+     paragraphs = [p.strip() for p in answer.split("\n") if p.strip()]
+     return paragraphs[0] if paragraphs else answer
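Both functions decode the full output sequence and then strip the prompt text by string replacement. An alternative sketch, assuming the tokenizer, model, and DEVICE loaded above, is to decode only the tokens generated after the prompt (the helper name here is hypothetical, not part of the commit):

def decode_new_tokens(prompt):
    # Hypothetical helper: decode only tokens produced after the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.4)
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()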
 
  # ------------------------
+ # Streamlit UI
  # ------------------------
+ st.title("RAG-Based Book Analyzer")
+ st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")
+
+ uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
+
+ if uploaded_file:
+     text = extract_text(uploaded_file)
+     if text:
+         st.success("File successfully processed!")
+         st.write("Generating summary...")
+         summary = generate_summary(text)
+         st.markdown("### Book Summary")
+         st.write(summary)

+         # Process text into chunks and build FAISS index
+         chunks = split_text(text)
+         index = build_index(chunks)
+         st.session_state.chunks = chunks
+         st.session_state.index = index

+         st.markdown("### Ask a Question about the Book:")
+         query = st.text_input("Your Question:")
          if query:
+             # Retrieve top 3 relevant chunks as context
+             query_embedding = embedder.encode([query])
+             faiss.normalize_L2(query_embedding)
+             distances, indices = st.session_state.index.search(query_embedding, k=3)
+             retrieved_chunks = [chunks[i] for i in indices[0] if i < len(chunks)]
+             context = "\n".join(retrieved_chunks)
+             answer = generate_answer(query, context)
+             st.markdown("### Answer")
+             st.write(answer)
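For reference, the revised chunking settings (CHUNK_SIZE = 512, CHUNK_OVERLAP = 64) drive RecursiveCharacterTextSplitter as in this small, self-contained sketch; the sample text is invented purely for illustration:

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,      # maximum characters per chunk (length_function=len)
    chunk_overlap=64,    # characters shared between neighbouring chunks
    length_function=len,
)

sample = "A very long document about retrieval augmented generation. " * 200  # illustrative placeholder text
chunks = splitter.split_text(sample)
print(len(chunks), "chunks; first chunk length:", len(chunks[0]))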