Hyma Roshini Gompa committed
Commit 11cb5b8 · 1 Parent(s): 4bf0eb2

Uploading correct updated app.py

Files changed (7)
  1. .idea/.gitignore +8 -0
  2. .idea/LegalDoc.iml +9 -0
  3. .idea/misc.xml +6 -0
  4. .idea/modules.xml +8 -0
  5. .idea/vcs.xml +6 -0
  6. app.py +823 -449
  7. stage_4.py +0 -449
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/LegalDoc.iml ADDED
@@ -0,0 +1,9 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="JAVA_MODULE" version="4">
+   <component name="NewModuleRootManager" inherit-compiler-output="true">
+     <exclude-output />
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/misc.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" languageLevel="JDK_17" default="true" project-jdk-name="23" project-jdk-type="JavaSDK">
+     <output url="file://$PROJECT_DIR$/out" />
+   </component>
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/LegalDoc.iml" filepath="$PROJECT_DIR$/.idea/LegalDoc.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
app.py CHANGED
@@ -1,449 +1,823 @@
- import streamlit as st
- import shelve
- import docx2txt
- import PyPDF2
- import time  # Used to simulate typing effect
- import nltk
- import re
- import os
- import time  # already imported in your code
- import requests
- from dotenv import load_dotenv
- import torch
- from sentence_transformers import SentenceTransformer, util
- nltk.download('punkt')
- import hashlib
- from nltk import sent_tokenize
- nltk.download('punkt_tab')
- from transformers import LEDTokenizer, LEDForConditionalGeneration
- from transformers import pipeline
- import asyncio
- import sys
- # Fix for RuntimeError: no running event loop on Windows
- if sys.platform.startswith("win"):
-     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-
- st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
-
- st.title("📄 Legal Document Summarizer (stage 4 )")
-
- USER_AVATAR = "👤"
- BOT_AVATAR = "🤖"
-
- # Load chat history
- def load_chat_history():
-     with shelve.open("chat_history") as db:
-         return db.get("messages", [])
-
- # Save chat history
- def save_chat_history(messages):
-     with shelve.open("chat_history") as db:
-         db["messages"] = messages
-
- # Function to limit text preview to 500 words
- def limit_text(text, word_limit=500):
-     words = text.split()
-     return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
-
-
- # CLEAN AND NORMALIZE TEXT
-
-
- def clean_text(text):
-     # Remove newlines and extra spaces
-     text = text.replace('\r\n', ' ').replace('\n', ' ')
-     text = re.sub(r'\s+', ' ', text)
-
-     # Remove page number markers like "Page 1 of 10"
-     text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)
-
-     # Remove long dashed or underscored lines
-     text = re.sub(r'[_]{5,}', '', text)  # Lines with underscores: _____
-     text = re.sub(r'[-]{5,}', '', text)  # Lines with hyphens: -----
-
-     # Remove long dotted separators
-     text = re.sub(r'[.]{4,}', '', text)  # Dots like "......" or ".............."
-
-     # Trim final leading/trailing whitespace
-     text = text.strip()
-
-     return text
-
-
- #######################################################################################################################
-
-
- # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
-
- # Load token from .env file
- load_dotenv()
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
-
- # Load once at the top (cache for performance)
- @st.cache_resource
- def load_local_zero_shot_classifier():
-     return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
-
- local_classifier = load_local_zero_shot_classifier()
-
-
- SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]
-
- def classify_chunk(text):
-     result = local_classifier(text, candidate_labels=SECTION_LABELS)
-     return result["labels"][0]
-
-
- # NEW: NLP-based sectioning using zero-shot classification
- def section_by_zero_shot(text):
-     sections = {"Facts": "", "Arguments": "", "Judgment": "", "Other": ""}
-     sentences = sent_tokenize(text)
-     chunk = ""
-
-     for i, sent in enumerate(sentences):
-         chunk += sent + " "
-         if (i + 1) % 3 == 0 or i == len(sentences) - 1:
-             label = classify_chunk(chunk.strip())
-             print(f"🔎 Chunk: {chunk[:60]}...\n🔖 Predicted Label: {label}")
-             # 👇 Normalize label (title case and fallback)
-             label = label.capitalize()
-             if label not in sections:
-                 label = "Other"
-             sections[label] += chunk + "\n"
-             chunk = ""
-
-     return sections
-
- #######################################################################################################################
-
-
-
- # EXTRACTING TEXT FROM UPLOADED FILES
-
- # Function to extract text from uploaded file
- def extract_text(file):
-     if file.name.endswith(".pdf"):
-         reader = PyPDF2.PdfReader(file)
-         full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
-     elif file.name.endswith(".docx"):
-         full_text = docx2txt.process(file)
-     elif file.name.endswith(".txt"):
-         full_text = file.read().decode("utf-8")
-     else:
-         return "Unsupported file type."
-
-     return full_text  # Full text is needed for summarization
-
-
- #######################################################################################################################
-
- # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
-
-
- @st.cache_resource
- def load_legalbert():
-     return SentenceTransformer("nlpaueb/legal-bert-base-uncased")
-
-
- legalbert_model = load_legalbert()
-
- @st.cache_resource
- def load_led():
-     tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
-     model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
-     return tokenizer, model
-
- tokenizer_led, model_led = load_led()
-
-
- def legalbert_extractive_summary(text, top_ratio=0.2):
-     sentences = sent_tokenize(text)
-     top_k = max(3, int(len(sentences) * top_ratio))
-
-     if len(sentences) <= top_k:
-         return text
-
-     # Embeddings & scoring
-     sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
-     doc_embedding = torch.mean(sentence_embeddings, dim=0)
-     cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
-     top_results = torch.topk(cosine_scores, k=top_k)
-
-     # Preserve original order
-     selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
-     return " ".join(selected_sentences)
-
-
-
- # Add LED Abstractive Summarization
-
-
- def led_abstractive_summary(text, max_length=512, min_length=100):
-     inputs = tokenizer_led(
-         text, return_tensors="pt", padding="max_length",
-         truncation=True, max_length=4096
-     )
-     global_attention_mask = torch.zeros_like(inputs["input_ids"])
-     global_attention_mask[:, 0] = 1
-
-     outputs = model_led.generate(
-         inputs["input_ids"],
-         attention_mask=inputs["attention_mask"],
-         global_attention_mask=global_attention_mask,
-         max_length=max_length,
-         min_length=min_length,
-         num_beams=4,              # Use beam search
-         repetition_penalty=2.0,   # Penalize repetition
-         length_penalty=1.0,
-         early_stopping=True,
-         no_repeat_ngram_size=4    # Prevent repeated phrases
-     )
-
-     return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
-
-
-
- def led_abstractive_summary_chunked(text, max_tokens=3000):
-     sentences = sent_tokenize(text)
-     current_chunk = ""
-     chunks = []
-     for sent in sentences:
-         if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
-             chunks.append(current_chunk)
-             current_chunk = sent
-         else:
-             current_chunk += " " + sent
-     if current_chunk:
-         chunks.append(current_chunk)
-
-     summaries = []
-     for chunk in chunks:
-         summaries.append(led_abstractive_summary(chunk))  # Call your LED summary function here
-
-     return " ".join(summaries)
-
-
-
- def hybrid_summary_hierarchical(text, top_ratio=0.8):
-     cleaned_text = clean_text(text)
-     sections = section_by_zero_shot(cleaned_text)
-
-     structured_summary = {}  # <-- hierarchical summary here
-
-     for name, content in sections.items():
-         if content.strip():
-             # Extractive summary
-             extractive = legalbert_extractive_summary(content, top_ratio)
-
-             # Abstractive summary
-             abstractive = led_abstractive_summary_chunked(extractive)
-
-             # Store in dictionary (hierarchical structure)
-             structured_summary[name] = {
-                 "extractive": extractive,
-                 "abstractive": abstractive
-             }
-
-     return structured_summary
-
-
- #######################################################################################################################
-
-
- # STREAMLIT APP INTERFACE CODE
-
- # Initialize or load chat history
- if "messages" not in st.session_state:
-     st.session_state.messages = load_chat_history()
-
- # Initialize last_uploaded if not set
- if "last_uploaded" not in st.session_state:
-     st.session_state.last_uploaded = None
-
- # Sidebar with a button to delete chat history
- with st.sidebar:
-     st.subheader("⚙️ Options")
-     if st.button("Delete Chat History"):
-         st.session_state.messages = []
-         st.session_state.last_uploaded = None
-         save_chat_history([])
-
- # Display chat messages with a typing effect
- def display_with_typing_effect(text, speed=0.005):
-     placeholder = st.empty()
-     displayed_text = ""
-     for char in text:
-         displayed_text += char
-         placeholder.markdown(displayed_text)
-         time.sleep(speed)
-     return displayed_text
-
- # Show existing chat messages
- for message in st.session_state.messages:
-     avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
-     with st.chat_message(message["role"], avatar=avatar):
-         st.markdown(message["content"])
-
-
- # Standard chat input field
- prompt = st.chat_input("Type a message...")
-
-
- # Place uploader before the chat so it's always visible
- with st.container():
-     st.subheader("📎 Upload a Legal Document")
-     uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
-     reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")
-
-
- # Hashing logic
- def get_file_hash(file):
-     file.seek(0)
-     content = file.read()
-     file.seek(0)
-     return hashlib.md5(content).hexdigest()
-
-
- ##############################################################################################################
-
- user_role = st.sidebar.selectbox(
-     "🎭 Select Your Role for Custom Summary",
-     ["General", "Judge", "Lawyer", "Student"]
- )
-
-
- def role_based_filter(section, summary, role):
-     if role == "General":
-         return summary
-
-     filtered_summary = {
-         "extractive": "",
-         "abstractive": ""
-     }
-
-     if role == "Judge" and section in ["Judgment", "Facts"]:
-         filtered_summary = summary
-     elif role == "Lawyer" and section in ["Arguments", "Facts"]:
-         filtered_summary = summary
-     elif role == "Student" and section in ["Facts"]:
-         filtered_summary = summary
-
-     return filtered_summary
-
-
-
- if uploaded_file:
-     file_hash = get_file_hash(uploaded_file)
-
-     # Check if file is new OR reprocess is triggered
-     if file_hash != st.session_state.get("last_uploaded_hash") or reprocess_btn:
-
-         start_time = time.time()  # Start the timer
-
-         raw_text = extract_text(uploaded_file)
-
-         summary_dict = hybrid_summary_hierarchical(raw_text)
-
-         st.session_state.messages.append({
-             "role": "user",
-             "content": f"📤 Uploaded **{uploaded_file.name}**"
-         })
-
-
-         # Start building preview
-         preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"
-
-
-         for section in ["Facts", "Arguments", "Judgment", "Other"]:
-             if section in summary_dict:
-
-                 filtered = role_based_filter(section, summary_dict[section], user_role)
-
-                 extractive = filtered.get("extractive", "").strip()
-                 abstractive = filtered.get("abstractive", "").strip()
-
-                 if not extractive and not abstractive:
-                     continue  # Skip if empty after filtering
-
-                 preview_text += f"### 📘 {section} Section\n"
-                 preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
-                 preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
-
-
-         # Display in chat
-         with st.chat_message("assistant", avatar=BOT_AVATAR):
-             display_with_typing_effect(clean_text(preview_text), speed=0)
-
-         # Show processing time after the summary
-         processing_time = round(time.time() - start_time, 2)
-         st.session_state["last_response_time"] = processing_time
-
-         if "last_response_time" in st.session_state:
-             st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
-
-         st.session_state.messages.append({
-             "role": "assistant",
-             "content": clean_text(preview_text)
-         })
-
-         # Save this file hash only if it’s a new upload (avoid overwriting during reprocess)
-         if not reprocess_btn:
-             st.session_state.last_uploaded_hash = file_hash
-
-         save_chat_history(st.session_state.messages)
-
-         st.rerun()
-
-
- # Handle chat input and return hybrid summary
- if prompt:
-     raw_text = prompt
-     start_time = time.time()
-
-     summary_dict = hybrid_summary_hierarchical(raw_text)
-
-     st.session_state.messages.append({
-         "role": "user",
-         "content": prompt
-     })
-
-     # Start building preview
-     preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"
-
-     for section in ["Facts", "Arguments", "Judgment", "Other"]:
-         if section in summary_dict:
-
-             filtered = role_based_filter(section, summary_dict[section], user_role)
-
-             extractive = filtered.get("extractive", "").strip()
-             abstractive = filtered.get("abstractive", "").strip()
-
-             if not extractive and not abstractive:
-                 continue  # Skip if empty after filtering
-
-             preview_text += f"### 📘 {section} Section\n"
-             preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
-             preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
-
-
-     # Display in chat
-     with st.chat_message("assistant", avatar=BOT_AVATAR):
-         display_with_typing_effect(clean_text(preview_text), speed=0)
-
-     # Show processing time after the summary
-     processing_time = round(time.time() - start_time, 2)
-     st.session_state["last_response_time"] = processing_time
-
-     if "last_response_time" in st.session_state:
-         st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
-
-     st.session_state.messages.append({
-         "role": "assistant",
-         "content": clean_text(preview_text)
-     })
-
-     save_chat_history(st.session_state.messages)
-
-     st.rerun()
+ import streamlit as st
+ import shelve
+ import docx2txt
+ import PyPDF2
+ import time  # Used to simulate typing effect
+ import nltk
+ import re
+ import os
+ import time  # already imported in your code
+ from dotenv import load_dotenv
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+ nltk.download('punkt')
+ import hashlib
+ from nltk import sent_tokenize
+ nltk.download('punkt_tab')
+ from transformers import LEDTokenizer, LEDForConditionalGeneration
+ from transformers import pipeline
+ import asyncio
+ import dateutil.parser
+ from datetime import datetime
+ import sys
+
+ from openai import OpenAI
+ import numpy as np
+
+
+ # Fix for RuntimeError: no running event loop on Windows
+ if sys.platform.startswith("win"):
+     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+ st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
+
+ if "processed" not in st.session_state:
+     st.session_state.processed = False
+ if "last_uploaded_hash" not in st.session_state:
+     st.session_state.last_uploaded_hash = None
+ if "chat_prompt_processed" not in st.session_state:
+     st.session_state.chat_prompt_processed = False
+
+ if "embedding_text" not in st.session_state:
+     st.session_state.embedding_text = None
+
+ if "document_context" not in st.session_state:
+     st.session_state.document_context = None
+
+ if "last_prompt_hash" not in st.session_state:
+     st.session_state.last_prompt_hash = None
+
+
+ st.title("📄 Legal Document Summarizer (Simple RAG with evaluation results)")
+
+ USER_AVATAR = "👤"
+ BOT_AVATAR = "🤖"
+
+ # Load chat history
+ def load_chat_history():
+     with shelve.open("chat_history") as db:
+         return db.get("messages", [])
+
+ # Save chat history
+ def save_chat_history(messages):
+     with shelve.open("chat_history") as db:
+         db["messages"] = messages
+
+ # Function to limit text preview to 500 words
+ def limit_text(text, word_limit=500):
+     words = text.split()
+     return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
+
+
+ # CLEAN AND NORMALIZE TEXT
+
+
+ def clean_text(text):
+     # Remove newlines and extra spaces
+     text = text.replace('\r\n', ' ').replace('\n', ' ')
+     text = re.sub(r'\s+', ' ', text)
+
+     # Remove page number markers like "Page 1 of 10"
+     text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)
+
+     # Remove long dashed or underscored lines
+     text = re.sub(r'[_]{5,}', '', text)  # Lines with underscores: _____
+     text = re.sub(r'[-]{5,}', '', text)  # Lines with hyphens: -----
+
+     # Remove long dotted separators
+     text = re.sub(r'[.]{4,}', '', text)  # Dots like "......" or ".............."
+
+     # Trim final leading/trailing whitespace
+     text = text.strip()
+
+     return text
+
+
+ #######################################################################################################################
+
+
+ # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
+
+ # Load token from .env file
+ load_dotenv()
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+
+ client = OpenAI(
+     base_url="https://api.studio.nebius.com/v1/",
+     api_key=os.getenv("OPENAI_API_KEY")
+ )
+
+ # print("API Key:", os.getenv("OPENAI_API_KEY"))  # Temporary for debugging
+
+
+ # Load once at the top (cache for performance)
+ @st.cache_resource
+ def load_local_zero_shot_classifier():
+     return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
+
+ local_classifier = load_local_zero_shot_classifier()
+
+
+ SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]
+
+ def classify_chunk(text):
+     result = local_classifier(text, candidate_labels=SECTION_LABELS)
+     return result["labels"][0]
+
+
+ # NEW: NLP-based sectioning using zero-shot classification
+ def section_by_zero_shot(text):
+     sections = {"Facts": "", "Arguments": "", "Judgment": "", "Others": ""}
+     sentences = sent_tokenize(text)
+     chunk = ""
+
+     for i, sent in enumerate(sentences):
+         chunk += sent + " "
+         if (i + 1) % 3 == 0 or i == len(sentences) - 1:
+             label = classify_chunk(chunk.strip())
+             print(f"🔎 Chunk: {chunk[:60]}...\n🔖 Predicted Label: {label}")
+             # 👇 Normalize label (title case and fallback)
+             label = label.capitalize()
+             if label not in sections:
+                 label = "Others"
+             sections[label] += chunk + "\n"
+             chunk = ""
+
+     return sections
+
+ #######################################################################################################################
+
+
+
+ # EXTRACTING TEXT FROM UPLOADED FILES
+
+ # Function to extract text from uploaded file
+ def extract_text(file):
+     if file.name.endswith(".pdf"):
+         reader = PyPDF2.PdfReader(file)
+         full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
+     elif file.name.endswith(".docx"):
+         full_text = docx2txt.process(file)
+     elif file.name.endswith(".txt"):
+         full_text = file.read().decode("utf-8")
+     else:
+         return "Unsupported file type."
+
+     return full_text  # Full text is needed for summarization
+
+
+ #######################################################################################################################
+
+ # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
+
+
+ @st.cache_resource
+ def load_legalbert():
+     return SentenceTransformer("nlpaueb/legal-bert-base-uncased")
+
+
+ legalbert_model = load_legalbert()
+
+ @st.cache_resource
+ def load_led():
+     tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
+     model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
+     return tokenizer, model
+
+ tokenizer_led, model_led = load_led()
+
+
+ def legalbert_extractive_summary(text, top_ratio=0.5):
+     sentences = sent_tokenize(text)
+     top_k = max(3, int(len(sentences) * top_ratio))
+     if len(sentences) <= top_k:
+         return text
+     sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
+     doc_embedding = torch.mean(sentence_embeddings, dim=0)
+     cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
+     top_results = torch.topk(cosine_scores, k=top_k)
+     selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
+     return " ".join(selected_sentences)
+
+ # Add LED Abstractive Summarization
+
+
+ def led_abstractive_summary(text, max_length=512, min_length=100):
+     inputs = tokenizer_led(
+         text, return_tensors="pt", padding="max_length",
+         truncation=True, max_length=4096
+     )
+     global_attention_mask = torch.zeros_like(inputs["input_ids"])
+     global_attention_mask[:, 0] = 1
+
+     outputs = model_led.generate(
+         inputs["input_ids"],
+         attention_mask=inputs["attention_mask"],
+         global_attention_mask=global_attention_mask,
+         max_length=max_length,
+         min_length=min_length,
+         num_beams=4,              # Use beam search
+         repetition_penalty=2.0,   # Penalize repetition
+         length_penalty=1.0,
+         early_stopping=True,
+         no_repeat_ngram_size=4    # Prevent repeated phrases
+     )
+
+     return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
+
+
+
+ def led_abstractive_summary_chunked(text, max_tokens=3000):
+     sentences = sent_tokenize(text)
+     current_chunk, chunks, summaries = "", [], []
+     for sent in sentences:
+         if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
+             chunks.append(current_chunk)
+             current_chunk = sent
+         else:
+             current_chunk += " " + sent
+     if current_chunk:
+         chunks.append(current_chunk)
+     for chunk in chunks:
+         inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
+         global_attention_mask = torch.zeros_like(inputs["input_ids"])
+         global_attention_mask[:, 0] = 1
+         output = model_led.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             global_attention_mask=global_attention_mask,
+             max_length=512,
+             min_length=100,
+             num_beams=4,
+             repetition_penalty=2.0,
+             length_penalty=1.0,
+             early_stopping=True,
+             no_repeat_ngram_size=4,
+         )
+         summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))
+     return " ".join(summaries)
+
+
+
+ def extract_timeline(text):
+     sentences = sent_tokenize(text)
+     timeline = []
+
+     for sentence in sentences:
+         try:
+             # Try fuzzy parsing on the sentence
+             parsed = dateutil.parser.parse(sentence, fuzzy=True)
+
+             # Validate year: exclude years before 1950 unless explicitly whitelisted
+             current_year = datetime.now().year
+             if 1900 <= parsed.year <= current_year + 5:
+                 # Additional filtering: discard misleading past years unless contextually valid
+                 if parsed.year < 1950 and parsed.year not in [2020, 2022, 2023]:
+                     continue
+
+                 # Further validation: ignore obviously wrong patterns like years starting with 0
+                 if re.match(r"^0\d{3}$", str(parsed.year)):
+                     continue
+
+                 # Passed all checks
+                 timeline.append((parsed.date(), sentence.strip()))
+         except Exception:
+             continue
+
+     # Remove duplicates and sort
+     unique_timeline = list(set(timeline))
+     return sorted(unique_timeline, key=lambda x: x[0])
+
+
+
+ def format_timeline_for_chat(timeline_data):
+     if not timeline_data:
+         return "_No significant timeline events detected._"
+
+     formatted = "🗓️ **Timeline of Events**\n\n"
+     for date, event in timeline_data:
+         formatted += f"**{date.strftime('%Y-%m-%d')}**: {event}\n\n"
+     return formatted.strip()
+
+
+
+ def hybrid_summary_hierarchical(text, top_ratio=0.8):
+     cleaned_text = clean_text(text)
+     sections = section_by_zero_shot(cleaned_text)
+
+     structured_summary = {}  # <-- hierarchical summary here
+
+     for name, content in sections.items():
+         if content.strip():
+             # Extractive summary
+             extractive = legalbert_extractive_summary(content, top_ratio)
+
+             # Abstractive summary
+             abstractive = led_abstractive_summary_chunked(extractive)
+
+             # Store in dictionary (hierarchical structure)
+             structured_summary[name] = {
+                 "extractive": extractive,
+                 "abstractive": abstractive
+             }
+
+     return structured_summary
+
+
+ from sentence_transformers import SentenceTransformer
+
+ @st.cache_resource
+ def load_embedder():
+     return SentenceTransformer("all-MiniLM-L6-v2")
+
+ embedder = load_embedder()
+
+ # import faiss
+ import numpy as np
+
+
+ # def build_faiss_index(chunks):
+ #     embedder = load_embedder()
+ #     embeddings = embedder.encode(chunks, convert_to_tensor=False)
+ #     dimension = embeddings[0].shape[0]
+ #     index = faiss.IndexFlatL2(dimension)
+ #     index.add(np.array(embeddings).astype("float32"))
+ #     st.session_state["embedder"] = embedder
+ #     return index, chunks  # ✅ Return both
+
+
+ def retrieve_top_k(query, chunks, index, k=3):
+     query_vec = embedder.encode([query])
+     D, I = index.search(np.array(query_vec).astype("float32"), k)
+     return [chunks[i] for i in I[0]]
+
+
+ def chunk_text_custom(text, n=1000, overlap=200):
+     chunks = []
+     for i in range(0, len(text), n - overlap):
+         chunks.append(text[i:i + n])
+     return chunks
+
+ def create_embeddings(text_chunks, model="BAAI/bge-en-icl"):
+     response = client.embeddings.create(
+         model=model,
+         input=text_chunks
+     )
+     return response.data
+
+ def cosine_similarity(vec1, vec2):
+     return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+
+ def semantic_search(query, text_chunks, chunk_embeddings, k=7):
+     query_embedding = create_embeddings([query])[0].embedding
+     scores = [(i, cosine_similarity(np.array(query_embedding), np.array(emb.embedding))) for i, emb in enumerate(chunk_embeddings)]
+     top_indices = [idx for idx, _ in sorted(scores, key=lambda x: x[1], reverse=True)[:k]]
+     return [text_chunks[i] for i in top_indices]
+
+
+
+ def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
+     return client.chat.completions.create(
+         model=model,
+         temperature=0,
+         messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}]
+     ).choices[0].message.content
+
+
+ def rag_query_response(prompt, embedding_text):
+     chunks = chunk_text_custom(embedding_text)
+     chunk_embeddings = create_embeddings(chunks)
+     top_chunks = semantic_search(prompt, chunks, chunk_embeddings, k=5)
+     context_block = "\n\n".join([f"Context {i+1}:\n{chunk}" for i, chunk in enumerate(top_chunks)])
+     user_prompt = f"{context_block}\n\nQuestion: {prompt}"
+     system_instruction = (
+         "You are an AI assistant that strictly answers based on the given context. "
+         "If the answer cannot be derived directly from the context, respond: 'I do not have enough information to answer that.'"
+     )
+     return generate_response(system_instruction, user_prompt)
+
+
+
+
+ #######################################################################################################################
+
+
+ # STREAMLIT APP INTERFACE CODE
+
+ # Initialize or load chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = load_chat_history()
+
+ # Initialize last_uploaded if not set
+ if "last_uploaded" not in st.session_state:
+     st.session_state.last_uploaded = None
+
+
+
+ # Sidebar with a button to delete chat history
+ with st.sidebar:
+     st.subheader("⚙️ Options")
+     if st.button("Delete Chat History"):
+         st.session_state.messages = []
+         st.session_state.last_uploaded = None
+         st.session_state.processed = False
+         st.session_state.chat_prompt_processed = False
+         save_chat_history([])
+
+
+ # Display chat messages with a typing effect
+ def display_with_typing_effect(text, speed=0.005):
+     placeholder = st.empty()
+     displayed_text = ""
+     for char in text:
+         displayed_text += char
+         placeholder.markdown(displayed_text)
+         time.sleep(speed)
+     return displayed_text
+
+ # Show existing chat messages
+ for message in st.session_state.messages:
+     avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
+     with st.chat_message(message["role"], avatar=avatar):
+         st.markdown(message["content"])
+
+
+ # Standard chat input field
+ prompt = st.chat_input("Type a message...")
+
+
+ # Place uploader before the chat so it's always visible
+ with st.container():
+     st.subheader("📎 Upload a Legal Document")
+     uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
+     reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")
+
+
+
+ # Hashing logic
+ def get_file_hash(file):
+     file.seek(0)
+     content = file.read()
+     file.seek(0)
+     return hashlib.md5(content).hexdigest()
+
+ # Function to prepare text for embedding
+ # This function combines the extractive and abstractive summaries into a single string for embedding
+ def prepare_text_for_embedding(summary_dict, timeline_data):
+     combined_chunks = []
+
+     for section, content in summary_dict.items():
+         ext = content.get("extractive", "").strip()
+         abs = content.get("abstractive", "").strip()
+         if ext:
+             combined_chunks.append(f"{section} - Extractive Summary:\n{ext}")
+         if abs:
+             combined_chunks.append(f"{section} - Abstractive Summary:\n{abs}")
+
+     if timeline_data:
+
+         combined_chunks.append("Timeline of Events:\n")
+         for date, event in timeline_data:
+             combined_chunks.append(f"{date.strftime('%Y-%m-%d')}: {event.strip()}")
+
+     return "\n\n".join(combined_chunks)
+
+
+ ###################################################################################################################
+
+ # Store cleaned text and FAISS index only when document is processed
+
+ # Embedding for chunking
+
+
+ def chunk_text(text, max_tokens=100):
+     sentences = sent_tokenize(text)
+     chunks, current_chunk = [], ""
+
+     for sentence in sentences:
+         if len(current_chunk.split()) + len(sentence.split()) > max_tokens:
+             chunks.append(current_chunk.strip())
+             current_chunk = sentence
+         else:
+             current_chunk += " " + sentence
+     if current_chunk:
+         chunks.append(current_chunk.strip())
+
+     return chunks
+
+
+
+ ##############################################################################################################
+
+ user_role = st.sidebar.selectbox(
+     "🎭 Select Your Role for Custom Summary",
+     ["General", "Judge", "Lawyer", "Student"]
+ )
+
+
+ def role_based_filter(section, summary, role):
+     if role == "General":
+         return summary
+
+     filtered_summary = {
+         "extractive": "",
+         "abstractive": ""
+     }
+
+     if role == "Judge" and section in ["Judgement", "Facts"]:
+         filtered_summary = summary
+     elif role == "Lawyer" and section in ["Arguments", "Facts"]:
+         filtered_summary = summary
+     elif role == "Student" and section in ["Facts"]:
+         filtered_summary = summary
+
+     return filtered_summary
+
+
+
+
+
+ #########################################################################################################################
+
+
+ if uploaded_file:
+     file_hash = get_file_hash(uploaded_file)
+     if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
+         st.session_state.processed = False
+
+     # if is_new_file or reprocess_btn:
+     #     st.session_state.processed = False
+
+     if not st.session_state.processed:
+         start_time = time.time()
+         raw_text = extract_text(uploaded_file)
+         summary_dict = hybrid_summary_hierarchical(raw_text)
+         timeline_data = extract_timeline(clean_text(raw_text))
+         embedding_text = prepare_text_for_embedding(summary_dict, timeline_data)
+
+         # Generate and display RAG-based summary
+
+         st.session_state.document_context = embedding_text
+
+         role_specific_prompt = f"As a {user_role}, summarize the legal document focusing on the most relevant aspects such as facts, arguments, and judgments tailored for your role. Include key legal reasoning and timeline of events where necessary."
+         rag_summary = rag_query_response(role_specific_prompt, embedding_text)
+
+         st.session_state.generated_summary = rag_summary
+
+
+         st.session_state.messages.append({"role": "user", "content": f"📤 Uploaded **{uploaded_file.name}**"})
+         st.session_state.messages.append({"role": "assistant", "content": rag_summary})
+
+         with st.chat_message("assistant", avatar=BOT_AVATAR):
+             display_with_typing_effect(rag_summary)
+
+         processing_time = round((time.time() - start_time) / 60, 2)
+         st.info(f"⏱️ Response generated in **{processing_time} minutes**.")
+
+         st.session_state.last_uploaded_hash = file_hash
+         st.session_state.processed = True
+         st.session_state.last_prompt_hash = None
+         save_chat_history(st.session_state.messages)
+
+
+ # if prompt:
+ #     word_count = len(prompt.split())
+ #     # Document ingestion if long and not yet processed
+ #     if word_count > 30 and not st.session_state.processed:
+ #         raw_text = prompt
+ #         start_time = time.time()
+ #         summary_dict = hybrid_summary_hierarchical(raw_text)
+ #         timeline_data = extract_timeline(clean_text(raw_text))
+ #         embedding_text = prepare_text_for_embedding(summary_dict, timeline_data)
+
+ #         # Save document context for future queries
+ #         st.session_state.document_context = embedding_text
+ #         st.session_state.processed = True
+
+ #         # Initial role-based summary
+ #         role_prompt = f"As a {user_role}, summarize the document focusing on facts, arguments, judgments, plus timeline of events."
+ #         initial_summary = rag_query_response(role_prompt, embedding_text)
+ #         st.session_state.messages.append({"role": "user", "content": "📥 Document ingested"})
+ #         st.session_state.messages.append({"role": "assistant", "content": initial_summary})
+ #         with st.chat_message("assistant", avatar=BOT_AVATAR):
+ #             display_with_typing_effect(initial_summary)
+ #         # Step 10: Show time
+ #         processing_time = round((time.time() - start_time) / 60, 2)
+ #         st.info(f"⏱️ Response generated in **{processing_time} minutes**.")
+ #         save_chat_history(st.session_state.messages)
+
+ #     # Querying phase: use existing document context
+ #     elif st.session_state.processed:
+ #         if not st.session_state.document_context:
+ #             st.warning("⚠️ No document context found. Please upload or paste your document first (30+ words).")
+ #         else:
+ #             answer = rag_query_response(prompt, st.session_state.document_context)
+
+ #             st.session_state.messages.append({"role": "user", "content": prompt})
+ #             st.session_state.messages.append({"role": "assistant", "content": answer})
+ #             with st.chat_message("assistant", avatar=BOT_AVATAR):
+ #                 display_with_typing_effect(answer)
+ #             save_chat_history(st.session_state.messages)
+
+ #     # Prompt too short and no document yet
+ #     else:
+ #         with st.chat_message("assistant", avatar=BOT_AVATAR):
+ #             st.markdown("❗ Please first paste your document (more than 30 words), then ask questions.")
+
+
+ if prompt:
+     words = prompt.split()
+     word_count = len(words)
+
+     # compute a quick hash to detect “new” direct-paste
+     prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
+
+     # --- 1) LONG prompts always re-ingest as a NEW doc ---
+     if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
+         # mark this as our new “last prompt”
+         st.session_state.last_prompt_hash = prompt_hash
+
+         # ingest exactly like you do for an uploaded file
+         raw_text = prompt
+         start_time = time.time()
+
+         summary_dict = hybrid_summary_hierarchical(raw_text)
+         timeline_data = extract_timeline(clean_text(raw_text))
+         emb_text = prepare_text_for_embedding(summary_dict, timeline_data)
+
+         # overwrite context
+         st.session_state.document_context = emb_text
+         st.session_state.processed = True
+
+         # produce your initial summary
+         role_prompt = (
+             f"As a {user_role}, summarize the document focusing on facts, "
+             "arguments, judgments, plus timeline of events."
+         )
+         initial_summary = rag_query_response(role_prompt, emb_text)
+
+         st.session_state.messages.append({"role": "user", "content": "📥 Document ingested"})
+         st.session_state.messages.append({"role": "assistant", "content": initial_summary})
+         with st.chat_message("assistant", avatar=BOT_AVATAR):
+             display_with_typing_effect(initial_summary)
+
+         st.info(f"⏱️ Summary generated in {round((time.time()-start_time)/60,2)} minutes")
+         save_chat_history(st.session_state.messages)
+
+
+     # --- 2) SHORT prompts are queries against the last context ---
+     elif word_count <= 30 and st.session_state.processed:
+         answer = rag_query_response(prompt, st.session_state.document_context)
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         st.session_state.messages.append({"role": "assistant", "content": answer})
+         with st.chat_message("assistant", avatar=BOT_AVATAR):
+             display_with_typing_effect(answer)
+         save_chat_history(st.session_state.messages)
+
+     # --- 3) anything else: ask them to paste something first ---
+     else:
+         with st.chat_message("assistant", avatar=BOT_AVATAR):
+             st.markdown("❗ Paste at least 30 words of your document to ingest it first.")
+
+
+
+ #########################################################################################################################
+ # --- Evaluation Code Starts Here ---
+
+ import evaluate
+
+ # Load evaluators
+ rouge = evaluate.load("rouge")
+ bertscore = evaluate.load("bertscore")
+
+
+ def evaluate_summary(generated_summary, ground_truth_summary):
+     """Evaluate model-generated summary against ground truth."""
+     # Compute ROUGE
+     rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
+
+     # Compute BERTScore
+     bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
+
+     return rouge_result, bert_result
+
+
+ # 🛑 Upload ground truth (fix file uploader text)
+ ground_truth_summary_file = st.file_uploader("📄 Upload Ground Truth Summary (.txt)", type=["txt"])
+
+ if ground_truth_summary_file:
+     ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()
+
+     # ⚡ Make sure you have generated_summary available
+     if "generated_summary" in st.session_state and st.session_state.generated_summary:
+
+         # Perform evaluation
+         rouge_result, bert_result = evaluate_summary(st.session_state.generated_summary, ground_truth_summary)
+
+         # Display Results
+         st.subheader("📊 Evaluation Results")
+
+         st.write("🔹 ROUGE Scores:")
+         st.json(rouge_result)
+
+         st.write("🔹 BERTScore:")
+         st.json(bert_result)
+
+     else:
+         st.warning("")
+
+
+
+
+ ######################################################################################################################
+
+
+ # Run this along with streamlit run app.py to evaluate the model's performance on a test set
+ # Otherwise, comment the below code
+
+ # ⇒ EVALUATION HOOK: after the very first summary, fire off evaluate.main() once
+
+ # import json
+ # import pandas as pd
+ # import threading
+ #
+ #
+ # def run_eval(doc_context):
+ #
+ #     with open("test_case2.json", "r", encoding="utf-8") as f:
+ #         gt_data = json.load(f)
+ #
+ #     # 2) map document_id → local file
+ #     doc_paths = {
+ #         "case2": "case2.pdf",
+ #         # add more if you have more documents
+ #     }
+ #
+ #     records = []
+ #     for entry in gt_data:
+ #         doc_id = entry["document_id"]
+ #         query = entry["query"]
+ #         gt_ans = entry["ground_truth_answer"]
+ #
+ #
+ #         # model_ans = rag_query_response(query, emb_text)
+ #         model_ans = rag_query_response(query, doc_context)
+ #
+ #         records.append({
+ #             "document_id": doc_id,
+ #             "query": query,
+ #             "ground_truth_answer": gt_ans,
+ #             "model_answer": model_ans
+ #         })
+ #         print(f"✅ Done {doc_id} / “{query}”")
+ #
+ #     # 3) push to DataFrame + CSV
+ #     df = pd.DataFrame(records)
+ #     out = "evaluation_results.csv"
+ #     df.to_csv(out, index=False, encoding="utf-8")
+ #     print(f"\n📝 Saved {len(df)} rows to {out}")
+ #
+ #
+ # # you could log this somewhere
+ # def _run_evaluation():
+ #     try:
+ #         run_eval()
+ #     except Exception as e:
+ #         print("‼️ Evaluation script error:", e)
+ #
+ # if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
+ #     st.session_state.evaluation_launched = True
+ #
+ #     # inform user
+ #     st.sidebar.info("🔬 Starting background evaluation run…")
+ #
+ #     # *capture* the context
+ #     doc_ctx = st.session_state.document_context
+ #
+ #     # spawn the thread, passing doc_ctx in
+ #     threading.Thread(
+ #         target=lambda: run_eval(doc_ctx),
+ #         daemon=True
+ #     ).start()
+ #
+ #     st.sidebar.success("✅ Evaluation launched — check evaluation_results.csv when done.")
+ #
+ #     # check for file existence & show download button
+ #     eval_path = os.path.abspath("evaluation_results.csv")
+ #     if os.path.exists(eval_path):
+ #         st.sidebar.success(f"✅ Results saved to:\n`{eval_path}`")
+ #         # load it into a small dataframe (optional)
+ #         df_eval = pd.read_csv(eval_path)
+ #         # add a download button
+ #         st.sidebar.download_button(
+ #             label="⬇️ Download evaluation_results.csv",
+ #             data=df_eval.to_csv(index=False).encode("utf-8"),
+ #             file_name="evaluation_results.csv",
+ #             mime="text/csv"
+ #         )
+ #     else:
+ #         # if you want, display the cwd so you can inspect it
+ #         st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")
+ #
+ #
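The retrieval path added to app.py above (chunk_text_custom → create_embeddings → semantic_search → generate_response) can also be exercised outside Streamlit. The sketch below is illustrative and not part of the committed code; it assumes the same Nebius OpenAI-compatible endpoint, the OPENAI_API_KEY environment variable, and the BAAI/bge-en-icl and meta-llama/Llama-3.2-3B-Instruct models that app.py uses, and the helper names (chunk, embed, answer) are hypothetical.

# Illustrative standalone sketch of the RAG path in app.py (assumptions noted above).
import os
import numpy as np
from openai import OpenAI

client = OpenAI(base_url="https://api.studio.nebius.com/v1/", api_key=os.getenv("OPENAI_API_KEY"))

def chunk(text, n=1000, overlap=200):
    # Fixed-size character windows with overlap, mirroring chunk_text_custom().
    return [text[i:i + n] for i in range(0, len(text), n - overlap)]

def embed(texts, model="BAAI/bge-en-icl"):
    # One embedding vector per input string.
    return [d.embedding for d in client.embeddings.create(model=model, input=texts).data]

def answer(question, document, k=5):
    chunks = chunk(document)
    vecs = np.array(embed(chunks))
    q = np.array(embed([question])[0])
    # Cosine similarity of the query against every chunk; keep the top-k as context.
    scores = vecs @ q / (np.linalg.norm(vecs, axis=1) * np.linalg.norm(q))
    top = [chunks[i] for i in np.argsort(scores)[::-1][:k]]
    context = "\n\n".join(f"Context {i + 1}:\n{c}" for i, c in enumerate(top))
    return client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        temperature=0,
        messages=[
            {"role": "system", "content": "Answer strictly from the given context."},
            {"role": "user", "content": f"{context}\n\nQuestion: {question}"},
        ],
    ).choices[0].message.content

# Example (hypothetical file): print(answer("When was the judgment delivered?", open("case.txt").read()))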
stage_4.py DELETED
@@ -1,449 +0,0 @@
- import streamlit as st
- import shelve
- import docx2txt
- import PyPDF2
- import time  # Used to simulate typing effect
- import nltk
- import re
- import os
- import time  # already imported in your code
- import requests
- from dotenv import load_dotenv
- import torch
- from sentence_transformers import SentenceTransformer, util
- nltk.download('punkt')
- import hashlib
- from nltk import sent_tokenize
- nltk.download('punkt_tab')
- from transformers import LEDTokenizer, LEDForConditionalGeneration
- from transformers import pipeline
- import asyncio
- import sys
- # Fix for RuntimeError: no running event loop on Windows
- if sys.platform.startswith("win"):
-     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-
- st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
-
- st.title("📄 Legal Document Summarizer (stage 4 )")
-
- USER_AVATAR = "👤"
- BOT_AVATAR = "🤖"
-
- # Load chat history
- def load_chat_history():
-     with shelve.open("chat_history") as db:
-         return db.get("messages", [])
-
- # Save chat history
- def save_chat_history(messages):
-     with shelve.open("chat_history") as db:
-         db["messages"] = messages
-
- # Function to limit text preview to 500 words
- def limit_text(text, word_limit=500):
-     words = text.split()
-     return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
-
-
- # CLEAN AND NORMALIZE TEXT
-
-
- def clean_text(text):
-     # Remove newlines and extra spaces
-     text = text.replace('\r\n', ' ').replace('\n', ' ')
-     text = re.sub(r'\s+', ' ', text)
-
-     # Remove page number markers like "Page 1 of 10"
-     text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)
-
-     # Remove long dashed or underscored lines
-     text = re.sub(r'[_]{5,}', '', text)  # Lines with underscores: _____
-     text = re.sub(r'[-]{5,}', '', text)  # Lines with hyphens: -----
-
-     # Remove long dotted separators
-     text = re.sub(r'[.]{4,}', '', text)  # Dots like "......" or ".............."
-
-     # Trim final leading/trailing whitespace
-     text = text.strip()
-
-     return text
-
-
- #######################################################################################################################
-
-
- # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
-
- # Load token from .env file
- load_dotenv()
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
-
- # Load once at the top (cache for performance)
- @st.cache_resource
- def load_local_zero_shot_classifier():
-     return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
-
- local_classifier = load_local_zero_shot_classifier()
-
-
- SECTION_LABELS = ["Facts", "Arguments", "Judgment", "Other"]
-
- def classify_chunk(text):
-     result = local_classifier(text, candidate_labels=SECTION_LABELS)
-     return result["labels"][0]
-
-
- # NEW: NLP-based sectioning using zero-shot classification
- def section_by_zero_shot(text):
-     sections = {"Facts": "", "Arguments": "", "Judgment": "", "Other": ""}
-     sentences = sent_tokenize(text)
-     chunk = ""
-
-     for i, sent in enumerate(sentences):
-         chunk += sent + " "
-         if (i + 1) % 3 == 0 or i == len(sentences) - 1:
-             label = classify_chunk(chunk.strip())
-             print(f"🔎 Chunk: {chunk[:60]}...\n🔖 Predicted Label: {label}")
-             # 👇 Normalize label (title case and fallback)
-             label = label.capitalize()
-             if label not in sections:
-                 label = "Other"
-             sections[label] += chunk + "\n"
-             chunk = ""
-
-     return sections
-
- #######################################################################################################################
-
-
-
- # EXTRACTING TEXT FROM UPLOADED FILES
-
- # Function to extract text from uploaded file
- def extract_text(file):
-     if file.name.endswith(".pdf"):
-         reader = PyPDF2.PdfReader(file)
-         full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
-     elif file.name.endswith(".docx"):
-         full_text = docx2txt.process(file)
-     elif file.name.endswith(".txt"):
-         full_text = file.read().decode("utf-8")
-     else:
-         return "Unsupported file type."
-
-     return full_text  # Full text is needed for summarization
-
-
- #######################################################################################################################
-
- # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
-
-
- @st.cache_resource
- def load_legalbert():
-     return SentenceTransformer("nlpaueb/legal-bert-base-uncased")
-
-
- legalbert_model = load_legalbert()
-
- @st.cache_resource
- def load_led():
-     tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
-     model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
-     return tokenizer, model
-
- tokenizer_led, model_led = load_led()
-
-
- def legalbert_extractive_summary(text, top_ratio=0.2):
-     sentences = sent_tokenize(text)
-     top_k = max(3, int(len(sentences) * top_ratio))
-
-     if len(sentences) <= top_k:
-         return text
-
-     # Embeddings & scoring
-     sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
-     doc_embedding = torch.mean(sentence_embeddings, dim=0)
-     cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
-     top_results = torch.topk(cosine_scores, k=top_k)
-
-     # Preserve original order
-     selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
-     return " ".join(selected_sentences)
-
-
-
- # Add LED Abstractive Summarization
-
-
- def led_abstractive_summary(text, max_length=512, min_length=100):
-     inputs = tokenizer_led(
-         text, return_tensors="pt", padding="max_length",
-         truncation=True, max_length=4096
-     )
-     global_attention_mask = torch.zeros_like(inputs["input_ids"])
-     global_attention_mask[:, 0] = 1
-
-     outputs = model_led.generate(
-         inputs["input_ids"],
-         attention_mask=inputs["attention_mask"],
-         global_attention_mask=global_attention_mask,
-         max_length=max_length,
-         min_length=min_length,
-         num_beams=4,              # Use beam search
-         repetition_penalty=2.0,   # Penalize repetition
-         length_penalty=1.0,
-         early_stopping=True,
-         no_repeat_ngram_size=4    # Prevent repeated phrases
-     )
-
-     return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
-
-
-
- def led_abstractive_summary_chunked(text, max_tokens=3000):
-     sentences = sent_tokenize(text)
-     current_chunk = ""
-     chunks = []
-     for sent in sentences:
-         if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
-             chunks.append(current_chunk)
-             current_chunk = sent
-         else:
-             current_chunk += " " + sent
-     if current_chunk:
-         chunks.append(current_chunk)
-
-     summaries = []
-     for chunk in chunks:
-         summaries.append(led_abstractive_summary(chunk))  # Call your LED summary function here
-
-     return " ".join(summaries)
-
-
-
- def hybrid_summary_hierarchical(text, top_ratio=0.8):
-     cleaned_text = clean_text(text)
-     sections = section_by_zero_shot(cleaned_text)
-
-     structured_summary = {}  # <-- hierarchical summary here
-
-     for name, content in sections.items():
-         if content.strip():
-             # Extractive summary
-             extractive = legalbert_extractive_summary(content, top_ratio)
-
-             # Abstractive summary
-             abstractive = led_abstractive_summary_chunked(extractive)
-
-             # Store in dictionary (hierarchical structure)
-             structured_summary[name] = {
-                 "extractive": extractive,
-                 "abstractive": abstractive
-             }
-
-     return structured_summary
-
-
- #######################################################################################################################
-
-
- # STREAMLIT APP INTERFACE CODE
-
- # Initialize or load chat history
- if "messages" not in st.session_state:
-     st.session_state.messages = load_chat_history()
-
- # Initialize last_uploaded if not set
- if "last_uploaded" not in st.session_state:
-     st.session_state.last_uploaded = None
-
- # Sidebar with a button to delete chat history
- with st.sidebar:
-     st.subheader("⚙️ Options")
-     if st.button("Delete Chat History"):
-         st.session_state.messages = []
-         st.session_state.last_uploaded = None
-         save_chat_history([])
-
- # Display chat messages with a typing effect
- def display_with_typing_effect(text, speed=0.005):
-     placeholder = st.empty()
-     displayed_text = ""
-     for char in text:
-         displayed_text += char
-         placeholder.markdown(displayed_text)
-         time.sleep(speed)
-     return displayed_text
-
- # Show existing chat messages
- for message in st.session_state.messages:
-     avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
-     with st.chat_message(message["role"], avatar=avatar):
-         st.markdown(message["content"])
-
-
- # Standard chat input field
- prompt = st.chat_input("Type a message...")
-
-
- # Place uploader before the chat so it's always visible
- with st.container():
-     st.subheader("📎 Upload a Legal Document")
-     uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
-     reprocess_btn = st.button("🔄 Reprocess Last Uploaded File")
-
-
- # Hashing logic
- def get_file_hash(file):
-     file.seek(0)
-     content = file.read()
-     file.seek(0)
-     return hashlib.md5(content).hexdigest()
-
-
- ##############################################################################################################
-
- user_role = st.sidebar.selectbox(
-     "🎭 Select Your Role for Custom Summary",
-     ["General", "Judge", "Lawyer", "Student"]
- )
-
-
- def role_based_filter(section, summary, role):
-     if role == "General":
-         return summary
-
-     filtered_summary = {
-         "extractive": "",
-         "abstractive": ""
-     }
-
-     if role == "Judge" and section in ["Judgment", "Facts"]:
-         filtered_summary = summary
-     elif role == "Lawyer" and section in ["Arguments", "Facts"]:
-         filtered_summary = summary
-     elif role == "Student" and section in ["Facts"]:
-         filtered_summary = summary
-
-     return filtered_summary
-
-
-
- if uploaded_file:
-     file_hash = get_file_hash(uploaded_file)
-
-     # Check if file is new OR reprocess is triggered
-     if file_hash != st.session_state.get("last_uploaded_hash") or reprocess_btn:
-
-         start_time = time.time()  # Start the timer
-
-         raw_text = extract_text(uploaded_file)
-
-         summary_dict = hybrid_summary_hierarchical(raw_text)
-
-         st.session_state.messages.append({
-             "role": "user",
-             "content": f"📤 Uploaded **{uploaded_file.name}**"
-         })
-
-
-         # Start building preview
-         preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"
-
-
-         for section in ["Facts", "Arguments", "Judgment", "Other"]:
-             if section in summary_dict:
-
-                 filtered = role_based_filter(section, summary_dict[section], user_role)
-
-                 extractive = filtered.get("extractive", "").strip()
-                 abstractive = filtered.get("abstractive", "").strip()
-
-                 if not extractive and not abstractive:
-                     continue  # Skip if empty after filtering
-
-                 preview_text += f"### 📘 {section} Section\n"
-                 preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
-                 preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
-
-
-         # Display in chat
-         with st.chat_message("assistant", avatar=BOT_AVATAR):
-             display_with_typing_effect(clean_text(preview_text), speed=0)
-
-         # Show processing time after the summary
-         processing_time = round(time.time() - start_time, 2)
-         st.session_state["last_response_time"] = processing_time
-
-         if "last_response_time" in st.session_state:
-             st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
-
-         st.session_state.messages.append({
-             "role": "assistant",
-             "content": clean_text(preview_text)
-         })
-
-         # Save this file hash only if it’s a new upload (avoid overwriting during reprocess)
-         if not reprocess_btn:
-             st.session_state.last_uploaded_hash = file_hash
-
-         save_chat_history(st.session_state.messages)
-
-         st.rerun()
-
-
- # Handle chat input and return hybrid summary
- if prompt:
-     raw_text = prompt
-     start_time = time.time()
-
-     summary_dict = hybrid_summary_hierarchical(raw_text)
-
-     st.session_state.messages.append({
-         "role": "user",
-         "content": prompt
-     })
-
-     # Start building preview
-     preview_text = f"🧾 **Hybrid Summary of {uploaded_file.name}:**\n\n"
-
-     for section in ["Facts", "Arguments", "Judgment", "Other"]:
-         if section in summary_dict:
-
-             filtered = role_based_filter(section, summary_dict[section], user_role)
-
-             extractive = filtered.get("extractive", "").strip()
-             abstractive = filtered.get("abstractive", "").strip()
-
-             if not extractive and not abstractive:
-                 continue  # Skip if empty after filtering
-
-             preview_text += f"### 📘 {section} Section\n"
-             preview_text += f"📌 **Extractive Summary:**\n{extractive if extractive else '_No content extracted._'}\n\n"
-             preview_text += f"🔍 **Abstractive Summary:**\n{abstractive if abstractive else '_No summary generated._'}\n\n"
-
-
-     # Display in chat
-     with st.chat_message("assistant", avatar=BOT_AVATAR):
-         display_with_typing_effect(clean_text(preview_text), speed=0)
-
-     # Show processing time after the summary
-     processing_time = round(time.time() - start_time, 2)
-     st.session_state["last_response_time"] = processing_time
-
-     if "last_response_time" in st.session_state:
-         st.info(f"⏱️ Response generated in **{st.session_state['last_response_time']} seconds**.")
-
-     st.session_state.messages.append({
-         "role": "assistant",
-         "content": clean_text(preview_text)
-     })
-
-     save_chat_history(st.session_state.messages)
-
-     st.rerun()