Hyma Roshini Gompa committed on
Commit
4bf0eb2
·
1 Parent(s): 0e553a7

Add Streamlit app files

Browse files
Files changed (1) hide show
  1. app.py +449 -0
app.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import shelve
3
+ import docx2txt
4
+ import PyPDF2
5
+ import time # Used to simulate typing effect
6
+ import nltk
7
+ import re
8
+ import os
9
+ import time # already imported in your code
10
+ import requests
11
+ from dotenv import load_dotenv
12
+ import torch
13
+ from sentence_transformers import SentenceTransformer, util
14
+ nltk.download('punkt')
15
+ import hashlib
16
+ from nltk import sent_tokenize
17
+ nltk.download('punkt_tab')
18
+ from transformers import LEDTokenizer, LEDForConditionalGeneration
19
+ from transformers import pipeline
20
+ import asyncio
21
+ import sys
22
+ # Fix for RuntimeError: no running event loop on Windows
23
+ if sys.platform.startswith("win"):
24
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
25
+
26
+
27
+ st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
28
+
29
+ st.title("πŸ“„ Legal Document Summarizer (stage 4 )")
30
+
31
+ USER_AVATAR = "πŸ‘€"
32
+ BOT_AVATAR = "πŸ€–"
33
+
34
+ # Load chat history
35
def load_chat_history():
    """Return the message list stored in the ``chat_history`` shelf.

    Falls back to an empty list when the shelf has no ``"messages"`` entry.
    """
    with shelve.open("chat_history") as store:
        history = store.get("messages", [])
    return history
38
+
39
+ # Save chat history
40
def save_chat_history(messages):
    """Store ``messages`` under the ``"messages"`` key of the ``chat_history`` shelf."""
    store = shelve.open("chat_history")
    with store:
        store["messages"] = messages
43
+
44
+ # Function to limit text preview to 500 words
45
def limit_text(text, word_limit=500):
    """Return the first ``word_limit`` whitespace-separated words of ``text``.

    Words are re-joined with single spaces; ``"..."`` is appended only when
    the text contained more than ``word_limit`` words.
    """
    tokens = text.split()
    preview = " ".join(tokens[:word_limit])
    if len(tokens) > word_limit:
        preview += "..."
    return preview
48
+
49
+
50
+ # CLEAN AND NORMALIZE TEXT
51
+
52
+
53
def clean_text(text):
    """Strip layout artifacts from ``text`` and normalize its whitespace.

    Removes page markers such as "Page 1 of 10" (case-insensitive), runs of
    five or more underscores or hyphens, and runs of four or more dots.
    All whitespace (including newlines) is then collapsed to single spaces
    and the result is trimmed.

    Returns:
        The cleaned, single-line string.
    """
    # Remove page number markers like "Page 1 of 10"
    text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)

    # Remove long dashed or underscored lines
    text = re.sub(r'[_]{5,}', '', text)  # Lines with underscores: _____
    text = re.sub(r'[-]{5,}', '', text)  # Lines with hyphens: -----

    # Remove long dotted separators
    text = re.sub(r'[.]{4,}', '', text)  # Dots like "......" or ".............."

    # Collapse whitespace last: stripping an artifact that sat between two
    # spaces would otherwise leave a double space in the output.
    text = re.sub(r'\s+', ' ', text)

    return text.strip()
72
+
73
+
74
+ #######################################################################################################################
75
+
76
+
77
+ # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
78
+
79
+ # Load token from .env file
80
+ load_dotenv()
81
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
82
+
83
+
84
+ # Load once at the top (cache for performance)
85
@st.cache_resource
def load_local_zero_shot_classifier():
    """Build the zero-shot classification pipeline used for sectioning.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused.
    """
    classifier = pipeline(
        "zero-shot-classification",
        model="typeform/distilbert-base-uncased-mnli",
    )
    return classifier


local_classifier = load_local_zero_shot_classifier()
90
+
91
+
92
# Candidate labels offered to the zero-shot classifier.
SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]


def classify_chunk(text):
    """Return the first entry of the classifier's ``"labels"`` output for ``text``.

    The classifier is asked to choose among ``SECTION_LABELS``.
    """
    prediction = local_classifier(text, candidate_labels=SECTION_LABELS)
    ranked_labels = prediction["labels"]
    return ranked_labels[0]
97
+
98
+
99
+ # NEW: NLP-based sectioning using zero-shot classification
100
def section_by_zero_shot(text):
    """Split ``text`` into sections by classifying groups of three sentences.

    Sentences are grouped three at a time (the last group may be shorter).
    Each group is labelled with ``classify_chunk`` and appended, followed by
    a newline, to the matching section. Labels that are not section names
    after capitalization go to "Other".

    Returns:
        A dict with keys "Facts", "Arguments", "Judgment" and "Other",
        each mapping to the accumulated text of that section.
    """
    sections = {"Facts": "", "Arguments": "", "Judgment": "", "Other": ""}
    sentences = sent_tokenize(text)

    for start in range(0, len(sentences), 3):
        # Every sentence is followed by a space, including the last one.
        chunk = "".join(sentence + " " for sentence in sentences[start:start + 3])

        predicted = classify_chunk(chunk.strip())
        print(f"πŸ”Ž Chunk: {chunk[:60]}...\nπŸ”– Predicted Label: {predicted}")

        # Normalize the label and fall back to "Other" for unknown ones
        section_name = predicted.capitalize()
        if section_name not in sections:
            section_name = "Other"
        sections[section_name] += chunk + "\n"

    return sections
118
+
119
+ #######################################################################################################################
120
+
121
+
122
+
123
+ # EXTRACTING TEXT FROM UPLOADED FILES
124
+
125
+ # Function to extract text from uploaded file
126
def extract_text(file):
    """Extract the full text of an uploaded PDF, DOCX or TXT file.

    The file type is chosen from the extension of ``file.name``, compared
    case-insensitively so that e.g. ``REPORT.PDF`` is handled like
    ``report.pdf``.

    Args:
        file: A file-like object exposing ``name`` and ``read()``.

    Returns:
        The extracted text, or the string "Unsupported file type." when the
        extension is not one of ``.pdf``, ``.docx`` or ``.txt``.
    """
    filename = file.name.lower()

    if filename.endswith(".pdf"):
        reader = PyPDF2.PdfReader(file)
        # extract_text() may yield nothing for a page; treat that as empty.
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    if filename.endswith(".docx"):
        return docx2txt.process(file)

    if filename.endswith(".txt"):
        # Replace undecodable bytes instead of raising, so a text file saved
        # in another encoding still yields a usable document.
        return file.read().decode("utf-8", errors="replace")

    return "Unsupported file type."
138
+
139
+
140
+ #######################################################################################################################
141
+
142
+ # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
143
+
144
+
145
@st.cache_resource
def load_legalbert():
    """Create the ``nlpaueb/legal-bert-base-uncased`` SentenceTransformer.

    Decorated with ``st.cache_resource`` so the model object is created once
    and reused.
    """
    encoder = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
    return encoder


legalbert_model = load_legalbert()
151
+
152
@st.cache_resource
def load_led():
    """Load the LED tokenizer and generation model.

    Both come from the ``allenai/led-base-16384`` checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "allenai/led-base-16384"
    return (
        LEDTokenizer.from_pretrained(checkpoint),
        LEDForConditionalGeneration.from_pretrained(checkpoint),
    )


tokenizer_led, model_led = load_led()
159
+
160
+
161
def legalbert_extractive_summary(text, top_ratio=0.2):
    """Pick the sentences of ``text`` closest to the mean sentence embedding.

    The number of sentences kept is ``max(3, int(n * top_ratio))``. When the
    text has no more sentences than that, it is returned unchanged.

    Returns:
        The selected sentences joined with spaces, in their original order.
    """
    sents = sent_tokenize(text)
    keep = max(3, int(len(sents) * top_ratio))

    # Nothing to trim: hand the input back untouched.
    if len(sents) <= keep:
        return text

    # Score each sentence against the mean of all sentence embeddings
    embeddings = legalbert_model.encode(sents, convert_to_tensor=True)
    centroid = torch.mean(embeddings, dim=0)
    scores = util.pytorch_cos_sim(centroid, embeddings)[0]
    best = torch.topk(scores, k=keep)

    # Sort the winning indices so the summary follows document order
    ordered_indices = sorted(best.indices.tolist())
    return " ".join(sents[idx] for idx in ordered_indices)
177
+
178
+
179
+
180
+ # Add LED Abstractive Summarization
181
+
182
+
183
def led_abstractive_summary(text, max_length=512, min_length=100):
    """Generate an abstractive summary of ``text`` with the LED model.

    The input is tokenized with truncation and ``max_length`` padding at
    4096 tokens, and global attention is enabled on the first token only.

    Args:
        text: Text to summarize.
        max_length: Passed to ``generate`` as ``max_length``.
        min_length: Passed to ``generate`` as ``min_length``.

    Returns:
        The decoded first generated sequence, without special tokens.
    """
    encoded = tokenizer_led(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=4096,
    )
    input_ids = encoded["input_ids"]

    # Global attention mask: zero everywhere except the first position.
    global_mask = torch.zeros_like(input_ids)
    global_mask[:, 0] = 1

    generation_options = {
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": 4,  # Use beam search
        "repetition_penalty": 2.0,  # Penalize repetition
        "length_penalty": 1.0,
        "early_stopping": True,
        "no_repeat_ngram_size": 4,  # Prevent repeated phrases
    }
    generated = model_led.generate(
        input_ids,
        attention_mask=encoded["attention_mask"],
        global_attention_mask=global_mask,
        **generation_options,
    )

    return tokenizer_led.decode(generated[0], skip_special_tokens=True)
205
+
206
+
207
+
208
def led_abstractive_summary_chunked(text, max_tokens=3000):
    """Summarize ``text`` in sentence-aligned chunks of at most ``max_tokens``.

    Sentences are packed greedily into a chunk until adding the next one
    would push the tokenized chunk over ``max_tokens``. Each chunk is then
    summarized with ``led_abstractive_summary``.

    A single sentence longer than ``max_tokens`` becomes a chunk on its own;
    no empty chunk is ever produced or summarized.

    Returns:
        The per-chunk summaries joined with spaces ("" for empty input).
    """
    chunks = []
    current = []  # sentences of the chunk being built

    for sent in sent_tokenize(text):
        # Measure the chunk exactly as it will be joined, separator included.
        candidate = " ".join(current + [sent])
        over_budget = len(tokenizer_led(candidate)["input_ids"]) > max_tokens

        # Only close a chunk that already holds something; otherwise an
        # oversized first sentence would emit an empty chunk.
        if current and over_budget:
            chunks.append(" ".join(current))
            current = [sent]
        else:
            current.append(sent)

    if current:
        chunks.append(" ".join(current))

    return " ".join(led_abstractive_summary(chunk) for chunk in chunks)
226
+
227
+
228
+
229
def hybrid_summary_hierarchical(text, top_ratio=0.8):
    """Build a per-section extractive and abstractive summary of ``text``.

    The text is cleaned and split into sections. Every non-blank section is
    reduced with the extractive summarizer, and that result is passed to the
    chunked abstractive summarizer.

    Returns:
        A dict mapping each non-blank section name to
        ``{"extractive": ..., "abstractive": ...}``.
    """
    sections = section_by_zero_shot(clean_text(text))

    structured_summary = {}
    for section_name, section_text in sections.items():
        # Sections that received no content are left out of the result.
        if not section_text.strip():
            continue

        extractive_part = legalbert_extractive_summary(section_text, top_ratio)
        abstractive_part = led_abstractive_summary_chunked(extractive_part)

        structured_summary[section_name] = {
            "extractive": extractive_part,
            "abstractive": abstractive_part,
        }

    return structured_summary
250
+
251
+
252
+ #######################################################################################################################
253
+
254
+
255
+ # STREAMLIT APP INTERFACE CODE
256
+
257
+ # Initialize or load chat history
258
+ if "messages" not in st.session_state:
259
+ st.session_state.messages = load_chat_history()
260
+
261
+ # Initialize last_uploaded if not set
262
+ if "last_uploaded" not in st.session_state:
263
+ st.session_state.last_uploaded = None
264
+
265
+ # Sidebar with a button to delete chat history
266
+ with st.sidebar:
267
+ st.subheader("βš™οΈ Options")
268
+ if st.button("Delete Chat History"):
269
+ st.session_state.messages = []
270
+ st.session_state.last_uploaded = None
271
+ save_chat_history([])
272
+
273
+ # Display chat messages with a typing effect
274
def display_with_typing_effect(text, speed=0.005):
    """Render ``text`` as markdown in a placeholder, optionally typed out.

    Args:
        text: Markdown text to display.
        speed: Delay in seconds after each character. A value of zero or
            less renders the whole text in a single update.

    Returns:
        The text that was displayed.
    """
    placeholder = st.empty()

    # With no delay there is nothing to animate: one render replaces one
    # re-render of the growing text per character, and a negative value
    # never reaches time.sleep().
    if speed <= 0:
        placeholder.markdown(text)
        return text

    displayed_text = ""
    for char in text:
        displayed_text += char
        placeholder.markdown(displayed_text)
        time.sleep(speed)
    return displayed_text
282
+
283
+ # Show existing chat messages
284
+ for message in st.session_state.messages:
285
+ avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
286
+ with st.chat_message(message["role"], avatar=avatar):
287
+ st.markdown(message["content"])
288
+
289
+
290
+ # Standard chat input field
291
+ prompt = st.chat_input("Type a message...")
292
+
293
+
294
+ # Place uploader before the chat so it's always visible
295
+ with st.container():
296
+ st.subheader("πŸ“Ž Upload a Legal Document")
297
+ uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
298
+ reprocess_btn = st.button("πŸ”„ Reprocess Last Uploaded File")
299
+
300
+
301
+ # Hashing logic
302
def get_file_hash(file):
    """Return the MD5 hex digest of the entire content of ``file``.

    The file is read from the start, and its position is reset to the start
    afterwards.
    """
    file.seek(0)
    digest = hashlib.md5(file.read())
    file.seek(0)
    return digest.hexdigest()
307
+
308
+
309
+ ##############################################################################################################
310
+
311
+ user_role = st.sidebar.selectbox(
312
+ "🎭 Select Your Role for Custom Summary",
313
+ ["General", "Judge", "Lawyer", "Student"]
314
+ )
315
+
316
+
317
def role_based_filter(section, summary, role):
    """Return ``summary`` if ``role`` may see ``section``, else an empty summary.

    "General" sees every section. "Judge" sees Judgment and Facts, "Lawyer"
    sees Arguments and Facts, and "Student" sees Facts only. Any other
    combination yields a dict whose "extractive" and "abstractive" values
    are empty strings.
    """
    if role == "General":
        return summary

    visibility_rules = (
        ("Judge", ["Judgment", "Facts"]),
        ("Lawyer", ["Arguments", "Facts"]),
        ("Student", ["Facts"]),
    )
    for allowed_role, visible_sections in visibility_rules:
        if role == allowed_role and section in visible_sections:
            return summary

    return {"extractive": "", "abstractive": ""}
334
+
335
+
336
+
337
# Summarize the uploaded file when its content hash is new or when the
# reprocess button was pressed.
if uploaded_file:
    file_hash = get_file_hash(uploaded_file)

    is_new_file = file_hash != st.session_state.get("last_uploaded_hash")
    if is_new_file or reprocess_btn:

        start_time = time.time()  # Start the timer

        raw_text = extract_text(uploaded_file)
        summary_dict = hybrid_summary_hierarchical(raw_text)

        st.session_state.messages.append(
            {"role": "user", "content": f"πŸ“€ Uploaded **{uploaded_file.name}**"}
        )

        # Assemble the preview piece by piece, then join once
        preview_parts = [f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"]

        for section in ("Facts", "Arguments", "Judgment", "Other"):
            if section not in summary_dict:
                continue

            filtered = role_based_filter(section, summary_dict[section], user_role)
            extractive = filtered.get("extractive", "").strip()
            abstractive = filtered.get("abstractive", "").strip()

            # Skip sections that are empty after role filtering
            if not (extractive or abstractive):
                continue

            preview_parts.append(f"### πŸ“˜ {section} Section\n")
            preview_parts.append(
                f"πŸ“Œ **Extractive Summary:**\n{extractive or '_No content extracted._'}\n\n"
            )
            preview_parts.append(
                f"πŸ” **Abstractive Summary:**\n{abstractive or '_No summary generated._'}\n\n"
            )

        preview_text = "".join(preview_parts)
        rendered_preview = clean_text(preview_text)

        # Display in chat
        with st.chat_message("assistant", avatar=BOT_AVATAR):
            display_with_typing_effect(rendered_preview, speed=0)

        # Record and show the processing time after the summary
        processing_time = round(time.time() - start_time, 2)
        st.session_state["last_response_time"] = processing_time
        st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")

        st.session_state.messages.append(
            {"role": "assistant", "content": rendered_preview}
        )

        # The hash is stored only for a run not triggered by the reprocess button
        if not reprocess_btn:
            st.session_state.last_uploaded_hash = file_hash

        save_chat_history(st.session_state.messages)

        st.rerun()
398
+
399
+
400
+ # Handle chat input and return hybrid summary
401
# Summarize text typed into the chat input and append the exchange to the
# chat history.
if prompt:
    raw_text = prompt
    start_time = time.time()

    summary_dict = hybrid_summary_hierarchical(raw_text)

    st.session_state.messages.append({
        "role": "user",
        "content": prompt
    })

    # Start building preview. The summary is of the typed message, so the
    # header must not read uploaded_file.name: uploaded_file may be empty
    # here (it is guarded with `if uploaded_file:` elsewhere), and even when
    # a file is present it is not what is being summarized.
    preview_text = "🧾 **Hybrid Summary of your message:**\n\n"

    for section in ["Facts", "Arguments", "Judgment", "Other"]:
        if section in summary_dict:

            filtered = role_based_filter(section, summary_dict[section], user_role)

            extractive = filtered.get("extractive", "").strip()
            abstractive = filtered.get("abstractive", "").strip()

            if not extractive and not abstractive:
                continue  # Skip if empty after filtering

            preview_text += f"### πŸ“˜ {section} Section\n"
            preview_text += f"πŸ“Œ **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
            preview_text += f"πŸ” **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"

    # Display in chat
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        display_with_typing_effect(clean_text(preview_text), speed=0)

    # Show processing time after the summary
    processing_time = round(time.time() - start_time, 2)
    st.session_state["last_response_time"] = processing_time
    st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")

    st.session_state.messages.append({
        "role": "assistant",
        "content": clean_text(preview_text)
    })

    save_chat_history(st.session_state.messages)

    st.rerun()