hymarog1 committed
Commit
ce7b020 · verified · 1 Parent(s): 1409054

Update app.py

Files changed (1)
  1. app.py +880 -823
app.py CHANGED
@@ -1,823 +1,880 @@
1
- import streamlit as st
2
- import shelve
3
- import docx2txt
4
- import PyPDF2
5
- import time # Used to simulate typing effect
6
- import nltk
7
- import re
8
- import os
9
- import time # already imported in your code
10
- from dotenv import load_dotenv
11
- import torch
12
- from sentence_transformers import SentenceTransformer, util
13
- nltk.download('punkt')
14
- import hashlib
15
- from nltk import sent_tokenize
16
- nltk.download('punkt_tab')
17
- from transformers import LEDTokenizer, LEDForConditionalGeneration
18
- from transformers import pipeline
19
- import asyncio
20
- import dateutil.parser
21
- from datetime import datetime
22
- import sys
23
-
24
- from openai import OpenAI
25
- import numpy as np
26
-
27
-
28
- # Fix for RuntimeError: no running event loop on Windows
29
- if sys.platform.startswith("win"):
30
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
31
-
32
- st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
33
-
34
- if "processed" not in st.session_state:
35
- st.session_state.processed = False
36
- if "last_uploaded_hash" not in st.session_state:
37
- st.session_state.last_uploaded_hash = None
38
- if "chat_prompt_processed" not in st.session_state:
39
- st.session_state.chat_prompt_processed = False
40
-
41
- if "embedding_text" not in st.session_state:
42
- st.session_state.embedding_text = None
43
-
44
- if "document_context" not in st.session_state:
45
- st.session_state.document_context = None
46
-
47
- if "last_prompt_hash" not in st.session_state:
48
- st.session_state.last_prompt_hash = None
49
-
50
-
51
- st.title("πŸ“„ Legal Document Summarizer (Simple RAG with evaluation results)")
52
-
53
- USER_AVATAR = "πŸ‘€"
54
- BOT_AVATAR = "πŸ€–"
55
-
56
- # Load chat history
57
- def load_chat_history():
58
- with shelve.open("chat_history") as db:
59
- return db.get("messages", [])
60
-
61
- # Save chat history
62
- def save_chat_history(messages):
63
- with shelve.open("chat_history") as db:
64
- db["messages"] = messages
65
-
66
- # Function to limit text preview to 500 words
67
- def limit_text(text, word_limit=500):
68
- words = text.split()
69
- return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
70
-
71
-
72
- # CLEAN AND NORMALIZE TEXT
73
-
74
-
75
- def clean_text(text):
76
- # Remove newlines and extra spaces
77
- text = text.replace('\r\n', ' ').replace('\n', ' ')
78
- text = re.sub(r'\s+', ' ', text)
79
-
80
- # Remove page number markers like "Page 1 of 10"
81
- text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)
82
-
83
- # Remove long dashed or underscored lines
84
- text = re.sub(r'[_]{5,}', '', text) # Lines with underscores: _____
85
- text = re.sub(r'[-]{5,}', '', text) # Lines with hyphens: -----
86
-
87
- # Remove long dotted separators
88
- text = re.sub(r'[.]{4,}', '', text) # Dots like "......" or ".............."
89
-
90
- # Trim final leading/trailing whitespace
91
- text = text.strip()
92
-
93
- return text
94
-
95
-
96
- #######################################################################################################################
97
-
98
-
99
- # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
100
-
101
- # Load token from .env file
102
- load_dotenv()
103
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
104
-
105
- client = OpenAI(
106
- base_url="https://api.studio.nebius.com/v1/",
107
- api_key=os.getenv("OPENAI_API_KEY")
108
- )
109
-
110
- # print("API Key:", os.getenv("OPENAI_API_KEY")) # Temporary for debugging
111
-
112
-
113
- # Load once at the top (cache for performance)
114
- @st.cache_resource
115
- def load_local_zero_shot_classifier():
116
- return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
117
-
118
- local_classifier = load_local_zero_shot_classifier()
119
-
120
-
121
- SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]
122
-
123
- def classify_chunk(text):
124
- result = local_classifier(text, candidate_labels=SECTION_LABELS)
125
- return result["labels"][0]
126
-
127
-
128
- # NEW: NLP-based sectioning using zero-shot classification
129
- def section_by_zero_shot(text):
130
- sections = {"Facts": "", "Arguments": "", "Judgment": "", "Others": ""}
131
- sentences = sent_tokenize(text)
132
- chunk = ""
133
-
134
- for i, sent in enumerate(sentences):
135
- chunk += sent + " "
136
- if (i + 1) % 3 == 0 or i == len(sentences) - 1:
137
- label = classify_chunk(chunk.strip())
138
- print(f"πŸ”Ž Chunk: {chunk[:60]}...\nπŸ”– Predicted Label: {label}")
139
- # πŸ‘‡ Normalize label (title case and fallback)
140
- label = label.capitalize()
141
- if label not in sections:
142
- label = "Others"
143
- sections[label] += chunk + "\n"
144
- chunk = ""
145
-
146
- return sections
147
-
148
- #######################################################################################################################
149
-
150
-
151
-
152
- # EXTRACTING TEXT FROM UPLOADED FILES
153
-
154
- # Function to extract text from uploaded file
155
- def extract_text(file):
156
- if file.name.endswith(".pdf"):
157
- reader = PyPDF2.PdfReader(file)
158
- full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
159
- elif file.name.endswith(".docx"):
160
- full_text = docx2txt.process(file)
161
- elif file.name.endswith(".txt"):
162
- full_text = file.read().decode("utf-8")
163
- else:
164
- return "Unsupported file type."
165
-
166
- return full_text # Full text is needed for summarization
167
-
168
-
169
- #######################################################################################################################
170
-
171
- # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
172
-
173
-
174
- @st.cache_resource
175
- def load_legalbert():
176
- return SentenceTransformer("nlpaueb/legal-bert-base-uncased")
177
-
178
-
179
- legalbert_model = load_legalbert()
180
-
181
- @st.cache_resource
182
- def load_led():
183
- tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
184
- model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
185
- return tokenizer, model
186
-
187
- tokenizer_led, model_led = load_led()
188
-
189
-
190
- def legalbert_extractive_summary(text, top_ratio=0.5):
191
- sentences = sent_tokenize(text)
192
- top_k = max(3, int(len(sentences) * top_ratio))
193
- if len(sentences) <= top_k:
194
- return text
195
- sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
196
- doc_embedding = torch.mean(sentence_embeddings, dim=0)
197
- cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
198
- top_results = torch.topk(cosine_scores, k=top_k)
199
- selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
200
- return " ".join(selected_sentences)
201
-
202
- # Add LED Abstractive Summarization
203
-
204
-
205
- def led_abstractive_summary(text, max_length=512, min_length=100):
206
- inputs = tokenizer_led(
207
- text, return_tensors="pt", padding="max_length",
208
- truncation=True, max_length=4096
209
- )
210
- global_attention_mask = torch.zeros_like(inputs["input_ids"])
211
- global_attention_mask[:, 0] = 1
212
-
213
- outputs = model_led.generate(
214
- inputs["input_ids"],
215
- attention_mask=inputs["attention_mask"],
216
- global_attention_mask=global_attention_mask,
217
- max_length=max_length,
218
- min_length=min_length,
219
- num_beams=4, # Use beam search
220
- repetition_penalty=2.0, # Penalize repetition
221
- length_penalty=1.0,
222
- early_stopping=True,
223
- no_repeat_ngram_size=4 # Prevent repeated phrases
224
- )
225
-
226
- return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
227
-
228
-
229
-
230
- def led_abstractive_summary_chunked(text, max_tokens=3000):
231
- sentences = sent_tokenize(text)
232
- current_chunk, chunks, summaries = "", [], []
233
- for sent in sentences:
234
- if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
235
- chunks.append(current_chunk)
236
- current_chunk = sent
237
- else:
238
- current_chunk += " " + sent
239
- if current_chunk:
240
- chunks.append(current_chunk)
241
- for chunk in chunks:
242
- inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
243
- global_attention_mask = torch.zeros_like(inputs["input_ids"])
244
- global_attention_mask[:, 0] = 1
245
- output = model_led.generate(
246
- inputs["input_ids"],
247
- attention_mask=inputs["attention_mask"],
248
- global_attention_mask=global_attention_mask,
249
- max_length=512,
250
- min_length=100,
251
- num_beams=4,
252
- repetition_penalty=2.0,
253
- length_penalty=1.0,
254
- early_stopping=True,
255
- no_repeat_ngram_size=4,
256
- )
257
- summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))
258
- return " ".join(summaries)
259
-
260
-
261
-
262
- def extract_timeline(text):
263
- sentences = sent_tokenize(text)
264
- timeline = []
265
-
266
- for sentence in sentences:
267
- try:
268
- # Try fuzzy parsing on the sentence
269
- parsed = dateutil.parser.parse(sentence, fuzzy=True)
270
-
271
- # Validate year: exclude years before 1950 unless explicitly whitelisted
272
- current_year = datetime.now().year
273
- if 1900 <= parsed.year <= current_year + 5:
274
- # Additional filtering: discard misleading past years unless contextually valid
275
- if parsed.year < 1950 and parsed.year not in [2020, 2022, 2023]:
276
- continue
277
-
278
- # Further validation: ignore obviously wrong patterns like years starting with 0
279
- if re.match(r"^0\d{3}$", str(parsed.year)):
280
- continue
281
-
282
- # Passed all checks
283
- timeline.append((parsed.date(), sentence.strip()))
284
- except Exception:
285
- continue
286
-
287
- # Remove duplicates and sort
288
- unique_timeline = list(set(timeline))
289
- return sorted(unique_timeline, key=lambda x: x[0])
290
-
291
-
292
-
293
- def format_timeline_for_chat(timeline_data):
294
- if not timeline_data:
295
- return "_No significant timeline events detected._"
296
-
297
- formatted = "πŸ—“οΈ **Timeline of Events**\n\n"
298
- for date, event in timeline_data:
299
- formatted += f"**{date.strftime('%Y-%m-%d')}**: {event}\n\n"
300
- return formatted.strip()
301
-
302
-
303
-
304
- def hybrid_summary_hierarchical(text, top_ratio=0.8):
305
- cleaned_text = clean_text(text)
306
- sections = section_by_zero_shot(cleaned_text)
307
-
308
- structured_summary = {} # <-- hierarchical summary here
309
-
310
- for name, content in sections.items():
311
- if content.strip():
312
- # Extractive summary
313
- extractive = legalbert_extractive_summary(content, top_ratio)
314
-
315
- # Abstractive summary
316
- abstractive = led_abstractive_summary_chunked(extractive)
317
-
318
- # Store in dictionary (hierarchical structure)
319
- structured_summary[name] = {
320
- "extractive": extractive,
321
- "abstractive": abstractive
322
- }
323
-
324
- return structured_summary
325
-
326
-
327
- from sentence_transformers import SentenceTransformer
328
-
329
- @st.cache_resource
330
- def load_embedder():
331
- return SentenceTransformer("all-MiniLM-L6-v2")
332
-
333
- embedder = load_embedder()
334
-
335
- # import faiss
336
- import numpy as np
337
-
338
-
339
- # def build_faiss_index(chunks):
340
- # embedder = load_embedder()
341
- # embeddings = embedder.encode(chunks, convert_to_tensor=False)
342
- # dimension = embeddings[0].shape[0]
343
- # index = faiss.IndexFlatL2(dimension)
344
- # index.add(np.array(embeddings).astype("float32"))
345
- # st.session_state["embedder"] = embedder
346
- # return index, chunks # βœ… Return both
347
-
348
-
349
- def retrieve_top_k(query, chunks, index, k=3):
350
- query_vec = embedder.encode([query])
351
- D, I = index.search(np.array(query_vec).astype("float32"), k)
352
- return [chunks[i] for i in I[0]]
353
-
354
-
355
- def chunk_text_custom(text, n=1000, overlap=200):
356
- chunks = []
357
- for i in range(0, len(text), n - overlap):
358
- chunks.append(text[i:i + n])
359
- return chunks
360
-
361
- def create_embeddings(text_chunks, model="BAAI/bge-en-icl"):
362
- response = client.embeddings.create(
363
- model=model,
364
- input=text_chunks
365
- )
366
- return response.data
367
-
368
- def cosine_similarity(vec1, vec2):
369
- return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
370
-
371
-
372
- def semantic_search(query, text_chunks, chunk_embeddings, k=7):
373
- query_embedding = create_embeddings([query])[0].embedding
374
- scores = [(i, cosine_similarity(np.array(query_embedding), np.array(emb.embedding))) for i, emb in enumerate(chunk_embeddings)]
375
- top_indices = [idx for idx, _ in sorted(scores, key=lambda x: x[1], reverse=True)[:k]]
376
- return [text_chunks[i] for i in top_indices]
377
-
378
-
379
-
380
- def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
381
- return client.chat.completions.create(
382
- model=model,
383
- temperature=0,
384
- messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}]
385
- ).choices[0].message.content
386
-
387
-
388
- def rag_query_response(prompt, embedding_text):
389
- chunks = chunk_text_custom(embedding_text)
390
- chunk_embeddings = create_embeddings(chunks)
391
- top_chunks = semantic_search(prompt, chunks, chunk_embeddings, k=5)
392
- context_block = "\n\n".join([f"Context {i+1}:\n{chunk}" for i, chunk in enumerate(top_chunks)])
393
- user_prompt = f"{context_block}\n\nQuestion: {prompt}"
394
- system_instruction = (
395
- "You are an AI assistant that strictly answers based on the given context. "
396
- "If the answer cannot be derived directly from the context, respond: 'I do not have enough information to answer that.'"
397
- )
398
- return generate_response(system_instruction, user_prompt)
399
-
400
-
401
-
402
-
403
- #######################################################################################################################
404
-
405
-
406
- # STREAMLIT APP INTERFACE CODE
407
-
408
- # Initialize or load chat history
409
- if "messages" not in st.session_state:
410
- st.session_state.messages = load_chat_history()
411
-
412
- # Initialize last_uploaded if not set
413
- if "last_uploaded" not in st.session_state:
414
- st.session_state.last_uploaded = None
415
-
416
-
417
-
418
- # Sidebar with a button to delete chat history
419
- with st.sidebar:
420
- st.subheader("βš™οΈ Options")
421
- if st.button("Delete Chat History"):
422
- st.session_state.messages = []
423
- st.session_state.last_uploaded = None
424
- st.session_state.processed = False
425
- st.session_state.chat_prompt_processed = False
426
- save_chat_history([])
427
-
428
-
429
- # Display chat messages with a typing effect
430
- def display_with_typing_effect(text, speed=0.005):
431
- placeholder = st.empty()
432
- displayed_text = ""
433
- for char in text:
434
- displayed_text += char
435
- placeholder.markdown(displayed_text)
436
- time.sleep(speed)
437
- return displayed_text
438
-
439
- # Show existing chat messages
440
- for message in st.session_state.messages:
441
- avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
442
- with st.chat_message(message["role"], avatar=avatar):
443
- st.markdown(message["content"])
444
-
445
-
446
- # Standard chat input field
447
- prompt = st.chat_input("Type a message...")
448
-
449
-
450
- # Place uploader before the chat so it's always visible
451
- with st.container():
452
- st.subheader("πŸ“Ž Upload a Legal Document")
453
- uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
454
- reprocess_btn = st.button("πŸ”„ Reprocess Last Uploaded File")
455
-
456
-
457
-
458
- # Hashing logic
459
- def get_file_hash(file):
460
- file.seek(0)
461
- content = file.read()
462
- file.seek(0)
463
- return hashlib.md5(content).hexdigest()
464
-
465
- # Function to prepare text for embedding
466
- # This function combines the extractive and abstractive summaries into a single string for embedding
467
- def prepare_text_for_embedding(summary_dict, timeline_data):
468
- combined_chunks = []
469
-
470
- for section, content in summary_dict.items():
471
- ext = content.get("extractive", "").strip()
472
- abs = content.get("abstractive", "").strip()
473
- if ext:
474
- combined_chunks.append(f"{section} - Extractive Summary:\n{ext}")
475
- if abs:
476
- combined_chunks.append(f"{section} - Abstractive Summary:\n{abs}")
477
-
478
- if timeline_data:
479
-
480
- combined_chunks.append("Timeline of Events:\n")
481
- for date, event in timeline_data:
482
- combined_chunks.append(f"{date.strftime('%Y-%m-%d')}: {event.strip()}")
483
-
484
- return "\n\n".join(combined_chunks)
485
-
486
-
487
- ###################################################################################################################
488
-
489
- # Store cleaned text and FAISS index only when document is processed
490
-
491
- # Embedding for chunking
492
-
493
-
494
- def chunk_text(text, max_tokens=100):
495
- sentences = sent_tokenize(text)
496
- chunks, current_chunk = [], ""
497
-
498
- for sentence in sentences:
499
- if len(current_chunk.split()) + len(sentence.split()) > max_tokens:
500
- chunks.append(current_chunk.strip())
501
- current_chunk = sentence
502
- else:
503
- current_chunk += " " + sentence
504
- if current_chunk:
505
- chunks.append(current_chunk.strip())
506
-
507
- return chunks
508
-
509
-
510
-
511
- ##############################################################################################################
512
-
513
- user_role = st.sidebar.selectbox(
514
- "🎭 Select Your Role for Custom Summary",
515
- ["General", "Judge", "Lawyer", "Student"]
516
- )
517
-
518
-
519
- def role_based_filter(section, summary, role):
520
- if role == "General":
521
- return summary
522
-
523
- filtered_summary = {
524
- "extractive": "",
525
- "abstractive": ""
526
- }
527
-
528
- if role == "Judge" and section in ["Judgement", "Facts"]:
529
- filtered_summary = summary
530
- elif role == "Lawyer" and section in ["Arguments", "Facts"]:
531
- filtered_summary = summary
532
- elif role == "Student" and section in ["Facts"]:
533
- filtered_summary = summary
534
-
535
- return filtered_summary
536
-
537
-
538
-
539
-
540
-
541
- #########################################################################################################################
542
-
543
-
544
- if uploaded_file:
545
- file_hash = get_file_hash(uploaded_file)
546
- if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
547
- st.session_state.processed = False
548
-
549
- # if is_new_file or reprocess_btn:
550
- # st.session_state.processed = False
551
-
552
- if not st.session_state.processed:
553
- start_time = time.time()
554
- raw_text = extract_text(uploaded_file)
555
- summary_dict = hybrid_summary_hierarchical(raw_text)
556
- timeline_data = extract_timeline(clean_text(raw_text))
557
- embedding_text = prepare_text_for_embedding(summary_dict, timeline_data)
558
-
559
- # Generate and display RAG-based summary
560
-
561
- st.session_state.document_context = embedding_text
562
-
563
- role_specific_prompt = f"As a {user_role}, summarize the legal document focusing on the most relevant aspects such as facts, arguments, and judgments tailored for your role. Include key legal reasoning and timeline of events where necessary."
564
- rag_summary = rag_query_response(role_specific_prompt, embedding_text)
565
-
566
- st.session_state.generated_summary = rag_summary
567
-
568
-
569
- st.session_state.messages.append({"role": "user", "content": f"πŸ“€ Uploaded **{uploaded_file.name}**"})
570
- st.session_state.messages.append({"role": "assistant", "content": rag_summary})
571
-
572
- with st.chat_message("assistant", avatar=BOT_AVATAR):
573
- display_with_typing_effect(rag_summary)
574
-
575
- processing_time = round((time.time() - start_time) / 60, 2)
576
- st.info(f"⏱️ Response generated in **{processing_time} minutes**.")
577
-
578
- st.session_state.last_uploaded_hash = file_hash
579
- st.session_state.processed = True
580
- st.session_state.last_prompt_hash = None
581
- save_chat_history(st.session_state.messages)
582
-
583
-
584
- # if prompt:
585
- # word_count = len(prompt.split())
586
- # # Document ingestion if long and not yet processed
587
- # if word_count > 30 and not st.session_state.processed:
588
- # raw_text = prompt
589
- # start_time = time.time()
590
- # summary_dict = hybrid_summary_hierarchical(raw_text)
591
- # timeline_data = extract_timeline(clean_text(raw_text))
592
- # embedding_text = prepare_text_for_embedding(summary_dict, timeline_data)
593
-
594
- # # Save document context for future queries
595
- # st.session_state.document_context = embedding_text
596
- # st.session_state.processed = True
597
-
598
- # # Initial role-based summary
599
- # role_prompt = f"As a {user_role}, summarize the document focusing on facts, arguments, judgments, plus timeline of events."
600
- # initial_summary = rag_query_response(role_prompt, embedding_text)
601
- # st.session_state.messages.append({"role": "user", "content": "πŸ“₯ Document ingested"})
602
- # st.session_state.messages.append({"role": "assistant", "content": initial_summary})
603
- # with st.chat_message("assistant", avatar=BOT_AVATAR):
604
- # display_with_typing_effect(initial_summary)
605
- # # Step 10: Show time
606
- # processing_time = round((time.time() - start_time) / 60, 2)
607
- # st.info(f"⏱️ Response generated in **{processing_time} minutes**.")
608
- # save_chat_history(st.session_state.messages)
609
-
610
- # # Querying phase: use existing document context
611
- # elif st.session_state.processed:
612
- # if not st.session_state.document_context:
613
- # st.warning("⚠️ No document context found. Please upload or paste your document first (30+ words).")
614
- # else:
615
- # answer = rag_query_response(prompt, st.session_state.document_context)
616
-
617
- # st.session_state.messages.append({"role": "user", "content": prompt})
618
- # st.session_state.messages.append({"role": "assistant", "content": answer})
619
- # with st.chat_message("assistant", avatar=BOT_AVATAR):
620
- # display_with_typing_effect(answer)
621
- # save_chat_history(st.session_state.messages)
622
-
623
- # # Prompt too short and no document yet
624
- # else:
625
- # with st.chat_message("assistant", avatar=BOT_AVATAR):
626
- # st.markdown("❗ Please first paste your document (more than 30 words), then ask questions.")
627
-
628
-
629
- if prompt:
630
- words = prompt.split()
631
- word_count = len(words)
632
-
633
- # compute a quick hash to detect β€œnew” direct-paste
634
- prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
635
-
636
- # --- 1) LONG prompts always re-ingest as a NEW doc ---
637
- if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
638
- # mark this as our new β€œlast prompt”
639
- st.session_state.last_prompt_hash = prompt_hash
640
-
641
- # ingest exactly like you do for an uploaded file
642
- raw_text = prompt
643
- start_time = time.time()
644
-
645
- summary_dict = hybrid_summary_hierarchical(raw_text)
646
- timeline_data = extract_timeline(clean_text(raw_text))
647
- emb_text = prepare_text_for_embedding(summary_dict, timeline_data)
648
-
649
- # overwrite context
650
- st.session_state.document_context = emb_text
651
- st.session_state.processed = True
652
-
653
- # produce your initial summary
654
- role_prompt = (
655
- f"As a {user_role}, summarize the document focusing on facts, "
656
- "arguments, judgments, plus timeline of events."
657
- )
658
- initial_summary = rag_query_response(role_prompt, emb_text)
659
-
660
- st.session_state.messages.append({"role":"user", "content":"πŸ“₯ Document ingested"})
661
- st.session_state.messages.append({"role":"assistant","content":initial_summary})
662
- with st.chat_message("assistant", avatar=BOT_AVATAR):
663
- display_with_typing_effect(initial_summary)
664
-
665
- st.info(f"⏱️ Summary generated in {round((time.time()-start_time)/60,2)} minutes")
666
- save_chat_history(st.session_state.messages)
667
-
668
-
669
- # --- 2) SHORT prompts are queries against the last context ---
670
- elif word_count <= 30 and st.session_state.processed:
671
- answer = rag_query_response(prompt, st.session_state.document_context)
672
- st.session_state.messages.append({"role":"user", "content":prompt})
673
- st.session_state.messages.append({"role":"assistant", "content":answer})
674
- with st.chat_message("assistant", avatar=BOT_AVATAR):
675
- display_with_typing_effect(answer)
676
- save_chat_history(st.session_state.messages)
677
-
678
- # --- 3) anything else: ask them to paste something first ---
679
- else:
680
- with st.chat_message("assistant", avatar=BOT_AVATAR):
681
- st.markdown("❗ Paste at least 30 words of your document to ingest it first.")
682
-
683
-
684
-
685
- ######################################################################################################################### --- Evaluation Code Starts Here ---
686
-
687
- import evaluate
688
-
689
- # Load evaluators
690
- rouge = evaluate.load("rouge")
691
- bertscore = evaluate.load("bertscore")
692
-
693
-
694
- def evaluate_summary(generated_summary, ground_truth_summary):
695
- """Evaluate model-generated summary against ground truth."""
696
- # Compute ROUGE
697
- rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
698
-
699
- # Compute BERTScore
700
- bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
701
-
702
- return rouge_result, bert_result
703
-
704
-
705
- # πŸ›‘ Upload ground truth (fix file uploader text)
706
- ground_truth_summary_file = st.file_uploader("πŸ“„ Upload Ground Truth Summary (.txt)", type=["txt"])
707
-
708
- if ground_truth_summary_file:
709
- ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()
710
-
711
- # ⚑ Make sure you have generated_summary available
712
- if "generated_summary" in st.session_state and st.session_state.generated_summary:
713
-
714
- # Perform evaluation
715
- rouge_result, bert_result = evaluate_summary(st.session_state.generated_summary, ground_truth_summary)
716
-
717
- # Display Results
718
- st.subheader("πŸ“Š Evaluation Results")
719
-
720
- st.write("πŸ”Ή ROUGE Scores:")
721
- st.json(rouge_result)
722
-
723
- st.write("πŸ”Ή BERTScore:")
724
- st.json(bert_result)
725
-
726
- else:
727
- st.warning("")
728
-
729
-
730
-
731
-
732
-
733
- ######################################################################################################################
734
-
735
-
736
- # Run this along with streamlit run app.py to evaluate the model's performance on a test set
737
- # Otherwise, comment the below code
738
-
739
- # β‡’ EVALUATION HOOK: after the very first summary, fire off evaluate.main() once
740
-
741
- # import json
742
- # import pandas as pd
743
- # import threading
744
- #
745
- #
746
- # def run_eval(doc_context):
747
- #
748
- # with open("test_case2.json", "r", encoding="utf-8") as f:
749
- # gt_data = json.load(f)
750
- #
751
- # # 2) map document_id β†’ local file
752
- # doc_paths = {
753
- # "case2": "case2.pdf",
754
- # # add more if you have more documents
755
- # }
756
- #
757
- # records = []
758
- # for entry in gt_data:
759
- # doc_id = entry["document_id"]
760
- # query = entry["query"]
761
- # gt_ans = entry["ground_truth_answer"]
762
- #
763
- #
764
- # # model_ans = rag_query_response(query, emb_text)
765
- # model_ans = rag_query_response(query, doc_context)
766
- #
767
- # records.append({
768
- # "document_id": doc_id,
769
- # "query": query,
770
- # "ground_truth_answer": gt_ans,
771
- # "model_answer": model_ans
772
- # })
773
- # print(f"βœ… Done {doc_id} / β€œ{query}”")
774
- #
775
- # # 3) push to DataFrame + CSV
776
- # df = pd.DataFrame(records)
777
- # out = "evaluation_results.csv"
778
- # df.to_csv(out, index=False, encoding="utf-8")
779
- # print(f"\nπŸ“ Saved {len(df)} rows to {out}")
780
- #
781
- #
782
- # # you could log this somewhere
783
- # def _run_evaluation():
784
- # try:
785
- # run_eval()
786
- # except Exception as e:
787
- # print("‼️ Evaluation script error:", e)
788
- #
789
- # if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
790
- # st.session_state.evaluation_launched = True
791
- #
792
- # # inform user
793
- # st.sidebar.info("πŸ”¬ Starting background evaluation run…")
794
- #
795
- # # *capture* the context
796
- # doc_ctx = st.session_state.document_context
797
- #
798
- # # spawn the thread, passing doc_ctx in
799
- # threading.Thread(
800
- # target=lambda: run_eval(doc_ctx),
801
- # daemon=True
802
- # ).start()
803
- #
804
- # st.sidebar.success("βœ… Evaluation launched β€” check evaluation_results.csv when done.")
805
- #
806
- # # check for file existence & show download button
807
- # eval_path = os.path.abspath("evaluation_results.csv")
808
- # if os.path.exists(eval_path):
809
- # st.sidebar.success(f"βœ… Results saved to:\n`{eval_path}`")
810
- # # load it into a small dataframe (optional)
811
- # df_eval = pd.read_csv(eval_path)
812
- # # add a download button
813
- # st.sidebar.download_button(
814
- # label="⬇️ Download evaluation_results.csv",
815
- # data=df_eval.to_csv(index=False).encode("utf-8"),
816
- # file_name="evaluation_results.csv",
817
- # mime="text/csv"
818
- # )
819
- # else:
820
- # # if you want, display the cwd so you can inspect it
821
- # st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")
822
- #
823
-
1
+ import streamlit as st
2
+ import shelve
3
+ import docx2txt
4
+ import PyPDF2
5
+ import time # Used to simulate typing effect
6
+ import nltk
7
+ import re
8
+ import os
9
+ import time # already imported in your code
10
+ from dotenv import load_dotenv
11
+ import torch
12
+ from sentence_transformers import SentenceTransformer, util
13
+ nltk.download('punkt')
14
+ import hashlib
15
+ from nltk import sent_tokenize
16
+ nltk.download('punkt_tab')
17
+ from transformers import LEDTokenizer, LEDForConditionalGeneration
18
+ from transformers import pipeline
19
+ import asyncio
20
+ import dateutil.parser
21
+ from datetime import datetime
22
+ import sys
23
+
24
+ from openai import OpenAI
25
+ import numpy as np
26
+
27
+
28
+ # Fix for RuntimeError: no running event loop on Windows
29
+ if sys.platform.startswith("win"):
30
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
31
+
32
+ st.set_page_config(page_title="Legal Document Summarizer", layout="wide")
33
+
34
+ if "processed" not in st.session_state:
35
+ st.session_state.processed = False
36
+ if "last_uploaded_hash" not in st.session_state:
37
+ st.session_state.last_uploaded_hash = None
38
+ if "chat_prompt_processed" not in st.session_state:
39
+ st.session_state.chat_prompt_processed = False
40
+
41
+ if "embedding_text" not in st.session_state:
42
+ st.session_state.embedding_text = None
43
+
44
+ if "document_context" not in st.session_state:
45
+ st.session_state.document_context = None
46
+
47
+ if "last_prompt_hash" not in st.session_state:
48
+ st.session_state.last_prompt_hash = None
49
+
50
+
51
+ st.title("πŸ“„ Legal Document Summarizer (Document Augmentation RAG)")
52
+
53
+ USER_AVATAR = "πŸ‘€"
54
+ BOT_AVATAR = "πŸ€–"
55
+
56
+ # Load chat history
57
+ def load_chat_history():
58
+ with shelve.open("chat_history") as db:
59
+ return db.get("messages", [])
60
+
61
+ # Save chat history
62
+ def save_chat_history(messages):
63
+ with shelve.open("chat_history") as db:
64
+ db["messages"] = messages
65
+
66
+ # Function to limit text preview to 500 words
67
+ def limit_text(text, word_limit=500):
68
+ words = text.split()
69
+ return " ".join(words[:word_limit]) + ("..." if len(words) > word_limit else "")
70
+
71
+
72
+ # CLEAN AND NORMALIZE TEXT
73
+
74
+
75
+ def clean_text(text):
76
+ # Remove newlines and extra spaces
77
+ text = text.replace('\r\n', ' ').replace('\n', ' ')
78
+ text = re.sub(r'\s+', ' ', text)
79
+
80
+ # Remove page number markers like "Page 1 of 10"
81
+ text = re.sub(r'Page\s+\d+\s+of\s+\d+', '', text, flags=re.IGNORECASE)
82
+
83
+ # Remove long dashed or underscored lines
84
+ text = re.sub(r'[_]{5,}', '', text) # Lines with underscores: _____
85
+ text = re.sub(r'[-]{5,}', '', text) # Lines with hyphens: -----
86
+
87
+ # Remove long dotted separators
88
+ text = re.sub(r'[.]{4,}', '', text) # Dots like "......" or ".............."
89
+
90
+ # Trim final leading/trailing whitespace
91
+ text = text.strip()
92
+
93
+ return text
94
+
95
+
96
+ #######################################################################################################################
97
+
98
+
99
+ # LOADING MODELS FOR DIVIDING TEXT INTO SECTIONS
100
+
101
+ # Load token from .env file
102
+ load_dotenv()
103
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
104
+
105
+ client = OpenAI(
106
+ base_url="https://api.studio.nebius.com/v1/",
107
+ api_key=os.getenv("OPENAI_API_KEY")
108
+ )
109
+
110
+ # print("API Key:", os.getenv("OPENAI_API_KEY")) # Temporary for debugging
111
+
112
+
113
+ # Load once at the top (cache for performance)
114
+ @st.cache_resource
115
+ def load_local_zero_shot_classifier():
116
+ return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
117
+
118
+ local_classifier = load_local_zero_shot_classifier()
119
+
120
+
121
+ SECTION_LABELS = ["Facts", "Arguments", "Judgement", "Others"]
122
+
123
+ def classify_chunk(text):
124
+ result = local_classifier(text, candidate_labels=SECTION_LABELS)
125
+ return result["labels"][0]
126
+
127
+
128
+ # NEW: NLP-based sectioning using zero-shot classification
129
+ def section_by_zero_shot(text):
130
+ sections = {"Facts": "", "Arguments": "", "Judgment": "", "Others": ""}
131
+ sentences = sent_tokenize(text)
132
+ chunk = ""
133
+
134
+ for i, sent in enumerate(sentences):
135
+ chunk += sent + " "
136
+ if (i + 1) % 3 == 0 or i == len(sentences) - 1:
137
+ label = classify_chunk(chunk.strip())
138
+ print(f"πŸ”Ž Chunk: {chunk[:60]}...\nπŸ”– Predicted Label: {label}")
139
+ # πŸ‘‡ Normalize label (title case and fallback)
140
+ label = label.capitalize()
141
+ if label not in sections:
142
+ label = "Others"
143
+ sections[label] += chunk + "\n"
144
+ chunk = ""
145
+
146
+ return sections
147
+
148
+ #######################################################################################################################
149
+
150
+
151
+
152
+ # EXTRACTING TEXT FROM UPLOADED FILES
153
+
154
+ # Function to extract text from uploaded file
155
+ def extract_text(file):
156
+ if file.name.endswith(".pdf"):
157
+ reader = PyPDF2.PdfReader(file)
158
+ full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
159
+ elif file.name.endswith(".docx"):
160
+ full_text = docx2txt.process(file)
161
+ elif file.name.endswith(".txt"):
162
+ full_text = file.read().decode("utf-8")
163
+ else:
164
+ return "Unsupported file type."
165
+
166
+ return full_text # Full text is needed for summarization
167
+
168
+
169
+ #######################################################################################################################
170
+
171
+ # EXTRACTIVE AND ABSTRACTIVE SUMMARIZATION
172
+
173
+
174
+ @st.cache_resource
175
+ def load_legalbert():
176
+ return SentenceTransformer("nlpaueb/legal-bert-base-uncased")
177
+
178
+
179
+ legalbert_model = load_legalbert()
180
+
181
+ @st.cache_resource
182
+ def load_led():
183
+ tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
184
+ model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
185
+ return tokenizer, model
186
+
187
+ tokenizer_led, model_led = load_led()
188
+
189
+
190
+ def legalbert_extractive_summary(text, top_ratio=0.2):
191
+ sentences = sent_tokenize(text)
192
+ top_k = max(3, int(len(sentences) * top_ratio))
193
+ if len(sentences) <= top_k:
194
+ return text
195
+ sentence_embeddings = legalbert_model.encode(sentences, convert_to_tensor=True)
196
+ doc_embedding = torch.mean(sentence_embeddings, dim=0)
197
+ cosine_scores = util.pytorch_cos_sim(doc_embedding, sentence_embeddings)[0]
198
+ top_results = torch.topk(cosine_scores, k=top_k)
199
+ selected_sentences = [sentences[i] for i in sorted(top_results.indices.tolist())]
200
+ return " ".join(selected_sentences)
201
+
202
+ # Add LED Abstractive Summarization
203
+
204
+
205
+ def led_abstractive_summary(text, max_length=512, min_length=100):
206
+ inputs = tokenizer_led(
207
+ text, return_tensors="pt", padding="max_length",
208
+ truncation=True, max_length=4096
209
+ )
210
+ global_attention_mask = torch.zeros_like(inputs["input_ids"])
211
+ global_attention_mask[:, 0] = 1
212
+
213
+ outputs = model_led.generate(
214
+ inputs["input_ids"],
215
+ attention_mask=inputs["attention_mask"],
216
+ global_attention_mask=global_attention_mask,
217
+ max_length=max_length,
218
+ min_length=min_length,
219
+ num_beams=4, # Use beam search
220
+ repetition_penalty=2.0, # Penalize repetition
221
+ length_penalty=1.0,
222
+ early_stopping=True,
223
+ no_repeat_ngram_size=4 # Prevent repeated phrases
224
+ )
225
+
226
+ return tokenizer_led.decode(outputs[0], skip_special_tokens=True)
227
+
228
+
229
+
230
+ def led_abstractive_summary_chunked(text, max_tokens=3000):
231
+ sentences = sent_tokenize(text)
232
+ current_chunk, chunks, summaries = "", [], []
233
+ for sent in sentences:
234
+ if len(tokenizer_led(current_chunk + sent)["input_ids"]) > max_tokens:
235
+ chunks.append(current_chunk)
236
+ current_chunk = sent
237
+ else:
238
+ current_chunk += " " + sent
239
+ if current_chunk:
240
+ chunks.append(current_chunk)
241
+ for chunk in chunks:
242
+ inputs = tokenizer_led(chunk, return_tensors="pt", padding="max_length", truncation=True, max_length=4096)
243
+ global_attention_mask = torch.zeros_like(inputs["input_ids"])
244
+ global_attention_mask[:, 0] = 1
245
+ output = model_led.generate(
246
+ inputs["input_ids"],
247
+ attention_mask=inputs["attention_mask"],
248
+ global_attention_mask=global_attention_mask,
249
+ max_length=512,
250
+ min_length=100,
251
+ num_beams=4,
252
+ repetition_penalty=2.0,
253
+ length_penalty=1.0,
254
+ early_stopping=True,
255
+ no_repeat_ngram_size=4,
256
+ )
257
+ summaries.append(tokenizer_led.decode(output[0], skip_special_tokens=True))
258
+ return " ".join(summaries)
259
+
260
+
261
+
262
+ def hybrid_summary_hierarchical(text, top_ratio=0.8):
263
+ cleaned_text = clean_text(text)
264
+ sections = section_by_zero_shot(cleaned_text)
265
+
266
+ structured_summary = {} # <-- hierarchical summary here
267
+
268
+ for name, content in sections.items():
269
+ if content.strip():
270
+ # Extractive summary
271
+ extractive = legalbert_extractive_summary(content, top_ratio)
272
+
273
+ # Abstractive summary
274
+ abstractive = led_abstractive_summary_chunked(extractive)
275
+
276
+ # Store in dictionary (hierarchical structure)
277
+ structured_summary[name] = {
278
+ "extractive": extractive,
279
+ "abstractive": abstractive
280
+ }
281
+
282
+ return structured_summary
283
+
284
+
285
+ def chunk_text_custom(text, n=1000, overlap=200):
286
+ chunks = []
287
+ for i in range(0, len(text), n - overlap):
288
+ chunks.append(text[i:i + n])
289
+ return chunks
290
+
291
+
292
+
293
+ def get_embedding(text, model="BAAI/bge-en-icl"):
294
+ """
295
+ From your notebook:
296
+ Creates an embedding for the given text chunk using the BGE-ICL model.
297
+ """
298
+ resp = client.embeddings.create(model=model, input=text)
299
+ return np.array(resp.data[0].embedding)
300
+
301
+
302
+
303
+ def semantic_search(query, text_chunks, chunk_embeddings, k=5):
304
+ """
305
+ Compute cosine similarity between the query embedding and each chunk embedding,
306
+ then pick the top-k chunks.
307
+ """
308
+ q_emb = get_embedding(query)
309
+ # simple cosine:
310
+ def cosine(a, b): return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
311
+ scores = [cosine(q_emb, emb) for emb in chunk_embeddings]
312
+ top_idxs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
313
+ return [text_chunks[i] for i in top_idxs]
314
+
315
+
316
+ def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
317
+ return client.chat.completions.create(
318
+ model=model,
319
+ temperature=0,
320
+ messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}]
321
+ ).choices[0].message.content
322
+
323
+
324
+ def generate_questions(text_chunk, num_questions=5,
325
+ model="meta-llama/Llama-3.2-3B-Instruct"):
326
+ system_prompt = (
327
+ "You are an expert at generating relevant questions from text. "
328
+ "Create concise questions that can be answered using only the provided text."
329
+ )
330
+ user_prompt = f"""
331
+ Based on the following text, generate {num_questions} different questions
332
+ that can be answered using only this text:
333
+
334
+ {text_chunk}
335
+
336
+ Format your response as a numbered list of questions only.
337
+ """
338
+ resp = client.chat.completions.create(
339
+ model=model,
340
+ temperature=0.7,
341
+ messages=[
342
+ {"role":"system","content":system_prompt},
343
+ {"role":"user","content":user_prompt}
344
+ ]
345
+ )
346
+ raw = resp.choices[0].message.content.strip()
347
+ questions = []
348
+ for line in raw.split("\n"):
349
+ q = re.sub(r"^\d+\.\s*", "", line).strip()
350
+ if q.endswith("?"):
351
+ questions.append(q)
352
+ return questions
353
+
354
+ # 2) EMBEDDINGS
355
+ def create_embeddings(text, model="BAAI/bge-en-icl"):
356
+ resp = client.embeddings.create(model=model, input=text)
357
+ return resp.data[0].embedding
358
+
359
+ def cosine_similarity(a,b):
360
+ return float(np.dot(a,b)/(np.linalg.norm(a)*np.linalg.norm(b)))
361
+
362
+ # 3) VECTOR STORE
363
+ class SimpleVectorStore:
364
+ def __init__(self):
365
+ self.items = [] # each item is dict {text, embedding, metadata}
366
+ def add_item(self, text, embedding, metadata):
367
+ self.items.append(dict(text=text, embedding=embedding, metadata=metadata))
368
+ def search(self, query, k=5):
369
+ q_emb = create_embeddings(query)
370
+ scores = [(i, cosine_similarity(q_emb, item["embedding"]))
371
+ for i,item in enumerate(self.items)]
372
+ scores.sort(key=lambda x:x[1], reverse=True)
373
+ return [self.items[i] for i,_ in scores[:k]]
374
+
375
+ # 4) DOCUMENT PROCESSOR
376
+ def process_document(raw_text,
377
+ chunk_size=1000,
378
+ chunk_overlap=200,
379
+ questions_per_chunk=5):
380
+ # chunk the text
381
+ chunks = []
382
+ for i in range(0, len(raw_text), chunk_size - chunk_overlap):
383
+ chunks.append(raw_text[i:i+chunk_size])
384
+ store = SimpleVectorStore()
385
+ for idx,chunk in enumerate(chunks):
386
+ # chunk embedding
387
+ emb = create_embeddings(chunk)
388
+ store.add_item(chunk, emb, {"type":"chunk","index":idx})
389
+ # generate Qs + their embeddings
390
+ qs = generate_questions(chunk, num_questions=questions_per_chunk)
391
+ for q in qs:
392
+ q_emb = create_embeddings(q)
393
+ store.add_item(q, q_emb, {
394
+ "type":"question",
395
+ "chunk_index":idx,
396
+ "original_chunk": chunk
397
+ })
398
+ return chunks, store
399
+
400
+ # 5) CONTEXT BUILDER
401
+ def prepare_context(results):
402
+ seen = set()
403
+ ctx = []
404
+ # first direct chunks
405
+ for r in results:
406
+ m = r["metadata"]
407
+ if m["type"]=="chunk" and m["index"] not in seen:
408
+ seen.add(m["index"])
409
+ ctx.append(f"Chunk {m['index']}:\n{r['text']}")
410
+ # then referenced by questions
411
+ for r in results:
412
+ m = r["metadata"]
413
+ if m["type"]=="question":
414
+ ci = m["chunk_index"]
415
+ if ci not in seen:
416
+ seen.add(ci)
417
+ ctx.append(f"Chunk {ci} (via Q β€œ{r['text']}”):\n{m['original_chunk']}")
418
+ return "\n\n".join(ctx)
419
+
420
+ # 6) ANSWER GENERATOR (overrides your old generate_response)
421
+ def generate_response_from_context(query, context,
422
+ model="meta-llama/Llama-3.2-3B-Instruct"):
423
+ sp = (
424
+ "You are an AI assistant that strictly answers based on the given context. "
425
+ "If the answer cannot be derived directly from the provided context, "
426
+ "respond with: 'I do not have enough information to answer that.'"
427
+ )
428
+ up = f"""
429
+ Context:
430
+ {context}
431
+
432
+ Question: {query}
433
+
434
+ Please answer the question based only on the context above.
435
+ """
436
+ resp = client.chat.completions.create(
437
+ model=model,
438
+ temperature=0,
439
+ messages=[{"role":"system","content":sp},
440
+ {"role":"user","content":up}]
441
+ )
442
+ return resp.choices[0].message.content
443
+
444
+
445
+
446
+
447
+ #######################################################################################################################
448
+
449
+
450
+ # STREAMLIT APP INTERFACE CODE
451
+
452
+ # Initialize or load chat history
453
+ if "messages" not in st.session_state:
454
+ st.session_state.messages = load_chat_history()
455
+
456
+ # Initialize last_uploaded if not set
457
+ if "last_uploaded" not in st.session_state:
458
+ st.session_state.last_uploaded = None
459
+
460
+
461
+
462
+ # Sidebar with a button to delete chat history
463
+ with st.sidebar:
464
+ st.subheader("βš™οΈ Options")
465
+ if st.button("Delete Chat History"):
466
+ st.session_state.messages = []
467
+ st.session_state.last_uploaded = None
468
+ st.session_state.processed = False
469
+ st.session_state.chat_prompt_processed = False
470
+ save_chat_history([])
471
+
472
+
473
+ # Display chat messages with a typing effect
474
+ def display_with_typing_effect(text, speed=0.005):
475
+ placeholder = st.empty()
476
+ displayed_text = ""
477
+ for char in text:
478
+ displayed_text += char
479
+ placeholder.markdown(displayed_text)
480
+ time.sleep(speed)
481
+ return displayed_text
482
+
483
+ # Show existing chat messages
484
+ for message in st.session_state.messages:
485
+ avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
486
+ with st.chat_message(message["role"], avatar=avatar):
487
+ st.markdown(message["content"])
488
+
489
+
490
+ # Standard chat input field
491
+ prompt = st.chat_input("Type a message...")
492
+
493
+
494
+ # Place uploader before the chat so it's always visible
495
+ with st.container():
496
+ st.subheader("πŸ“Ž Upload a Legal Document")
497
+ uploaded_file = st.file_uploader("Upload a file (PDF, DOCX, TXT)", type=["pdf", "docx", "txt"])
498
+ reprocess_btn = st.button("πŸ”„ Reprocess Last Uploaded File")
499
+
500
+
501
+
502
+ # Hashing logic
503
+ def get_file_hash(file):
504
+ file.seek(0)
505
+ content = file.read()
506
+ file.seek(0)
507
+ return hashlib.md5(content).hexdigest()
508
+
509
+ # Function to prepare text for embedding
510
+ # This function combines the extractive and abstractive summaries into a single string for embedding
511
+ def prepare_text_for_embedding(summary_dict):
512
+ combined_chunks = []
513
+
514
+ for section, content in summary_dict.items():
515
+ ext = content.get("extractive", "").strip()
516
+ abs = content.get("abstractive", "").strip()
517
+ if ext:
518
+ combined_chunks.append(f"{section} - Extractive Summary:\n{ext}")
519
+ if abs:
520
+ combined_chunks.append(f"{section} - Abstractive Summary:\n{abs}")
521
+
522
+ return "\n\n".join(combined_chunks)
523
+
524
+
525
+ ##############################################################################################################
526
+
527
+ user_role = st.sidebar.selectbox(
528
+ "🎭 Select Your Role for Custom Summary",
529
+ ["General", "Judge", "Lawyer", "Student"]
530
+ )
531
+
532
+
533
+ def role_based_filter(section, summary, role):
534
+ if role == "General":
535
+ return summary
536
+
537
+ filtered_summary = {
538
+ "extractive": "",
539
+ "abstractive": ""
540
+ }
541
+
542
+ if role == "Judge" and section in ["Judgement", "Facts"]:
543
+ filtered_summary = summary
544
+ elif role == "Lawyer" and section in ["Arguments", "Facts"]:
545
+ filtered_summary = summary
546
+ elif role == "Student" and section in ["Facts"]:
547
+ filtered_summary = summary
548
+
549
+ return filtered_summary
550
+
551
+
552
+
553
+ #########################################################################################################################
554
+
555
+
556
+ if uploaded_file:
557
+ file_hash = get_file_hash(uploaded_file)
558
+ if file_hash != st.session_state.last_uploaded_hash or reprocess_btn:
559
+ st.session_state.processed = False
560
+
561
+ if not st.session_state.processed:
562
+ start_time = time.time()
563
+
564
+ # 1) extract & summarize as before
565
+ raw_text = extract_text(uploaded_file)
566
+ summary_dict = hybrid_summary_hierarchical(raw_text)
567
+ embedding_text = prepare_text_for_embedding(summary_dict)
568
+
569
+ # ─── NEW: document‐augmentation ingestion ───
570
+ chunks, store = process_document(raw_text,
571
+ chunk_size=1000,
572
+ chunk_overlap=200,
573
+ questions_per_chunk=5)
574
+ st.session_state.vector_store = store
575
+ # ────────────────────────────────────────────
576
+
577
+ # 2) generate your β€œrole‐specific prompt” as before
578
+ st.session_state.document_context = embedding_text
579
+
580
+ if user_role == "General":
581
+ role_specific_prompt = (
582
+ "Summarize the legal document focusing on the most relevant aspects "
583
+ "such as facts, arguments, and judgments. Include key legal reasoning "
584
+ "and a timeline of events where necessary."
585
+ )
586
+ else:
587
+ role_specific_prompt = (
588
+ f"As a {user_role}, summarize the legal document focusing on "
589
+ "the most relevant aspects such as facts, arguments, and judgments "
590
+ "tailored for your role. Include key legal reasoning and timeline of events."
591
+ )
592
+
593
+ # ─── REPLACE rag_query_response with doc‐augmentation RAG ───
594
+ results = store.search(role_specific_prompt, k=5)
595
+ context = prepare_context(results)
596
+ rag_summary = generate_response_from_context(role_specific_prompt, context)
597
+ #
598
+
599
+ st.session_state.messages.append({
600
+ "role": "user",
601
+ "content": f"πŸ“€ Uploaded **{uploaded_file.name}**"
602
+ })
603
+ st.session_state.messages.append({
604
+ "role": "assistant",
605
+ "content": rag_summary
606
+ })
607
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
608
+ display_with_typing_effect(rag_summary)
609
+
610
+ processing_time = round((time.time() - start_time) / 60, 2)
611
+ st.info(f"⏱️ Response generated in **{processing_time} minutes**.")
612
+
613
+ st.session_state.generated_summary = rag_summary
614
+ st.session_state.last_uploaded_hash = file_hash
615
+ st.session_state.processed = True
616
+ st.session_state.last_prompt_hash = None
617
+ save_chat_history(st.session_state.messages)
618
+
619
+
620
+
621
+ if prompt:
622
+ words = prompt.split()
623
+ word_count = len(words)
624
+ prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
625
+
626
+ # 1) LONG prompts – echo & ingest like a β€œpaste‐in” document
627
+ if word_count > 30 and prompt_hash != st.session_state.last_prompt_hash:
628
+ st.session_state.last_prompt_hash = prompt_hash
629
+
630
+ raw_text = prompt
631
+ st.session_state.messages.append({
632
+ "role": "user",
633
+ "content": f"πŸ“₯ **Pasted Document Text:**\n\n{limit_text(raw_text,500)}"
634
+ })
635
+ with st.chat_message("user", avatar=USER_AVATAR):
636
+ st.markdown(limit_text(raw_text,500))
637
+
638
+ start_time = time.time()
639
+ # summarization + emb_text as before
640
+ summary_dict = hybrid_summary_hierarchical(raw_text)
641
+ emb_text = prepare_text_for_embedding(summary_dict)
642
+ st.session_state.document_context = emb_text
643
+ st.session_state.processed = True
644
+
645
+ # ─── NEW: ingest via document‐augmentation ───
646
+ chunks, store = process_document(raw_text)
647
+ st.session_state.vector_store = store
648
+
649
+ if user_role == "General":
650
+ role_prompt = (
651
+ "Summarize the document focusing on facts, arguments, judgments, "
652
+ "and include a timeline of events."
653
+ )
654
+ else:
655
+ role_prompt = (
656
+ f"As a {user_role}, summarize the document focusing on facts, "
657
+ "arguments, judgments, plus timeline of events."
658
+ )
659
+
660
+ # ─── doc‐augmentation RAG here too ───
661
+ results = store.search(role_prompt, k=5)
662
+ context = prepare_context(results)
663
+ initial_summary = generate_response_from_context(role_prompt, context)
664
+
665
+ st.session_state.messages.append({
666
+ "role": "assistant",
667
+ "content": initial_summary
668
+ })
669
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
670
+ display_with_typing_effect(initial_summary)
671
+
672
+ st.info(f"⏱️ Summary generated in {round((time.time()-start_time)/60,2)} minutes")
673
+ save_chat_history(st.session_state.messages)
674
+
675
+
676
+ # 2) SHORT prompts – normal RAG against last ingested context
677
+ elif word_count <= 30 and st.session_state.processed:
678
+
679
+ with st.chat_message("user", avatar=USER_AVATAR):
680
+ st.markdown(prompt)
681
+
682
+ # 2) save to history
683
+ st.session_state.messages.append({"role": "user", "content": prompt})
684
+ store = st.session_state.vector_store
685
+
686
+ # ─── instead of rag_query_response, do doc‐augmentation RAG ───
687
+ results = store.search(prompt, k=5)
688
+ context = prepare_context(results)
689
+ answer = generate_response_from_context(prompt, context)
690
+
691
+ # st.session_state.messages.append({"role":"user", "content":prompt})
692
+ st.session_state.messages.append({"role":"assistant","content":answer})
693
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
694
+ display_with_typing_effect(answer)
695
+ save_chat_history(st.session_state.messages)
696
+
697
+
698
+ # 3) not enough input
699
+ else:
700
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
701
+ st.markdown("❗ Paste at least 30 words of your document to ingest it first.")
702
+
703
+
704
+ ################################Evaluation###########################
705
+ ######################################################################################################################
706
+
707
+ # πŸ“š Imports
708
+ import evaluate
709
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
710
+ from sklearn.metrics import f1_score
711
+
712
+ # πŸ“Œ Load Evaluators Once
713
+ @st.cache_resource
714
+ def load_evaluators():
715
+ rouge = evaluate.load("rouge")
716
+ bertscore = evaluate.load("bertscore")
717
+ return rouge, bertscore
718
+
719
+ rouge, bertscore = load_evaluators()
720
+
721
+ # πŸ“Œ Define Evaluation Functions
722
+ def evaluate_summary(generated_summary, ground_truth_summary):
723
+ """Evaluate ROUGE and BERTScore."""
724
+ rouge_result = rouge.compute(predictions=[generated_summary], references=[ground_truth_summary])
725
+ bert_result = bertscore.compute(predictions=[generated_summary], references=[ground_truth_summary], lang="en")
726
+ return rouge_result, bert_result
727
+
728
+ def exact_match(prediction, ground_truth):
729
+ return int(prediction.strip().lower() == ground_truth.strip().lower())
730
+
731
+ def compute_bleu(prediction, ground_truth):
732
+ reference = [ground_truth.strip().split()]
733
+ candidate = prediction.strip().split()
734
+ smoothie = SmoothingFunction().method4
735
+ return sentence_bleu(reference, candidate, smoothing_function=smoothie)
736
+
737
+ def compute_f1(prediction, ground_truth):
738
+ """Compute F1 score based on token overlap, like in QA evaluation."""
739
+ pred_tokens = prediction.strip().lower().split()
740
+ gt_tokens = ground_truth.strip().lower().split()
741
+
742
+ common_tokens = set(pred_tokens) & set(gt_tokens)
743
+ num_common = len(common_tokens)
744
+
745
+ if num_common == 0:
746
+ return 0.0
747
+
748
+ precision = num_common / len(pred_tokens)
749
+ recall = num_common / len(gt_tokens)
750
+ f1 = 2 * (precision * recall) / (precision + recall)
751
+ return f1
752
+
753
+ def evaluate_additional_metrics(prediction, ground_truth):
754
+ em = exact_match(prediction, ground_truth)
755
+ bleu = compute_bleu(prediction, ground_truth)
756
+ f1 = compute_f1(prediction, ground_truth)
757
+ return {
758
+ "Exact Match": em,
759
+ "BLEU Score": bleu,
760
+ "F1 Score": f1
761
+ }
762
+
763
+ # πŸ“₯ Upload and Evaluate
764
+ ground_truth_summary_file = st.file_uploader("πŸ“„ Upload Ground Truth Summary (.txt)", type=["txt"])
765
+
766
+ if ground_truth_summary_file:
767
+ ground_truth_summary = ground_truth_summary_file.read().decode("utf-8").strip()
768
+
769
+ if "generated_summary" in st.session_state and st.session_state.generated_summary:
770
+ prediction = st.session_state.generated_summary
771
+
772
+ # Evaluate ROUGE and BERTScore
773
+ rouge_result, bert_result = evaluate_summary(prediction, ground_truth_summary)
774
+
775
+ # Display ROUGE and BERTScore
776
+ st.subheader("πŸ“Š Evaluation Results")
777
+ st.write("πŸ”Ή ROUGE Scores:")
778
+ st.json(rouge_result)
779
+ st.write("πŸ”Ή BERTScore:")
780
+ st.json(bert_result)
781
+
782
+ # Evaluate and Display Exact Match, BLEU, F1
783
+ additional_metrics = evaluate_additional_metrics(prediction, ground_truth_summary)
784
+ st.subheader("πŸ”Ž Additional Evaluation Metrics")
785
+ st.json(additional_metrics)
786
+
787
+ else:
788
+ st.warning("⚠️ Please generate a summary first by uploading a document.")
789
+
790
+
791
+
792
+
793
+
794
+ ######################################################################################################################
795
+
796
+
797
+ # Run this along with streamlit run app.py to evaluate the model's performance on a test set
798
+ # Otherwise, comment the below code
799
+
800
+ # β‡’ EVALUATION HOOK: after the very first summary, fire off evaluate.main() once
801
+
802
+ # import json
803
+ # import pandas as pd
804
+ # import threading
805
+
806
+
807
+ # def run_eval(doc_context):
808
+
809
+ # with open("test_case1.json", "r", encoding="utf-8") as f:
810
+ # gt_data = json.load(f)
811
+
812
+ # # 2) map document_id β†’ local file
813
+
814
+ # records = []
815
+ # for entry in gt_data:
816
+ # doc_id = entry["document_id"]
817
+ # query = entry["query"]
818
+ # gt_ans = entry["ground_truth_answer"]
819
+
820
+
821
+ # # model_ans = rag_query_response(query, emb_text)
822
+ # model_ans = rag_query_response(query, doc_context)
823
+
824
+ # records.append({
825
+ # "document_id": doc_id,
826
+ # "query": query,
827
+ # "ground_truth_answer": gt_ans,
828
+ # "model_answer": model_ans
829
+ # })
830
+ # print(f"βœ… Done {doc_id} / β€œ{query}”")
831
+
832
+ # # 3) push to DataFrame + CSV
833
+ # df = pd.DataFrame(records)
834
+ # out = "evaluation_results.csv"
835
+ # df.to_csv(out, index=False, encoding="utf-8")
836
+ # print(f"\nπŸ“ Saved {len(df)} rows to {out}")
837
+
838
+
839
+ # # you could log this somewhere
840
+ # def _run_evaluation():
841
+ # try:
842
+ # run_eval()
843
+ # except Exception as e:
844
+ # print("‼️ Evaluation script error:", e)
845
+
846
+ # if st.session_state.processed and not st.session_state.get("evaluation_launched", False):
847
+ # st.session_state.evaluation_launched = True
848
+
849
+ # # inform user
850
+ # st.sidebar.info("πŸ”¬ Starting background evaluation run…")
851
+
852
+ # # *capture* the context
853
+ # doc_ctx = st.session_state.document_context
854
+
855
+ # # spawn the thread, passing doc_ctx in
856
+ # threading.Thread(
857
+ # target=lambda: run_eval(doc_ctx),
858
+ # daemon=True
859
+ # ).start()
860
+
861
+ # st.sidebar.success("βœ… Evaluation launched β€” check evaluation_results.csv when done.")
862
+
863
+ # # check for file existence & show download button
864
+ # eval_path = os.path.abspath("evaluation_results.csv")
865
+ # if os.path.exists(eval_path):
866
+ # st.sidebar.success(f"βœ… Results saved to:\n`{eval_path}`")
867
+ # # load it into a small dataframe (optional)
868
+ # df_eval = pd.read_csv(eval_path)
869
+ # # add a download button
870
+ # st.sidebar.download_button(
871
+ # label="⬇️ Download evaluation_results.csv",
872
+ # data=df_eval.to_csv(index=False).encode("utf-8"),
873
+ # file_name="evaluation_results.csv",
874
+ # mime="text/csv"
875
+ # )
876
+ # else:
877
+ # # if you want, display the cwd so you can inspect it
878
+ # st.sidebar.info(f"Current working dir:\n`{os.getcwd()}`")
879
+
880
+
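
For readers skimming the diff: below is a minimal usage sketch (not part of the commit) of how the document-augmentation pieces added in this version of app.py fit together, assuming process_document, prepare_context, and generate_response_from_context are in scope and the Nebius OpenAI-compatible client is configured with a valid OPENAI_API_KEY. The file path and query are illustrative placeholders, not values from the commit.

# Minimal sketch: document-augmentation RAG flow (illustrative only)
raw_text = open("sample_case.txt", encoding="utf-8").read()  # hypothetical input file

# Chunk the document, embed each chunk, and generate + embed questions for each chunk
chunks, store = process_document(raw_text, chunk_size=1000, chunk_overlap=200, questions_per_chunk=5)

# Retrieve the top matches for a query; chunks and generated questions share one vector store
query = "What did the court decide?"
results = store.search(query, k=5)

# Collect the underlying chunks, deduplicating ones reached directly and via their questions
context = prepare_context(results)

# Answer strictly from the retrieved context
print(generate_response_from_context(query, context))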