Upload mt5_finetuned_summary.py
mt5_finetuned_summary.py
ADDED
@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-

import spacy
from pathlib import Path
import sys
import warnings
import re
import numpy as np

# --- Prerequisites ---
# Ensure these are installed in your .venv:
# pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy
# (Make sure the spacy version matches the version used to train your NER model)
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    from peft import PeftModel, PeftConfig  # PEFT classes for loading the adapter
except ImportError as e:
    print(f"✘ Error: Missing required library: {e}")
    print("Please install all dependencies: pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy")
    sys.exit(1)


# --- Configuration ---
# 1. Path to your trained spaCy NER model directory
NER_MODEL_PATH = Path("./training_400/model-best")  # <-- ADJUST if different

# 2. Hugging Face model name for the BASE summarization model
BASE_SUMMARIZATION_MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"

# 3. Path to your saved PEFT/LoRA adapter directory (output of fine-tuning)
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary")  # <-- ADJUST if different

# 4. Device: "cuda" for GPU or "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 5. Summarization parameters
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC = 0.30      # Target minimum summary length as % of input tokens
MAX_LEN_PERC = 0.70      # Target maximum summary length as % of input tokens
ABS_MIN_TOKEN_LEN = 30   # Absolute minimum token length
ABS_MAX_TOKEN_LEN = 512  # Absolute maximum token length (safety cap)
# --- End Configuration ---
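
# Expected on-disk layout (an assumption; adjust the paths above to your setup):
#   ./training_400/model-best/        spaCy NER model directory (config.cfg, meta.json, ...)
#   ./mt5_finetuned_tamil_summary/    PEFT/LoRA adapter directory (adapter_config.json, adapter_model.*)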

# --- Suppress Warnings ---
warnings.filterwarnings("ignore", message="CUDA path could not be detected*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")

# --- Global Variables for Loaded Models ---
ner_model_global = None
summ_tokenizer_global = None
summ_model_global = None  # This will hold the PEFT model
models_loaded = False

# --- Model Loading Functions ---
def load_ner_model(path):
    """Loads the spaCy NER model and ensures a sentencizer is present."""
    global ner_model_global
    if not path.exists():
        print(f"✘ FATAL: NER model directory not found at {path.resolve()}")
        return False
    try:
        ner_model_global = spacy.load(path)
        print(f"✔ Successfully loaded NER model from: {path.resolve()}")
        # Ensure a sentence-boundary detector is present
        component_to_add_before = None
        if "tok2vec" in ner_model_global.pipe_names:
            component_to_add_before = "tok2vec"
        elif "ner" in ner_model_global.pipe_names:
            component_to_add_before = "ner"
        if not ner_model_global.has_pipe("sentencizer") and not ner_model_global.has_pipe("parser"):
            try:
                if component_to_add_before:
                    ner_model_global.add_pipe("sentencizer", before=component_to_add_before)
                else:
                    ner_model_global.add_pipe("sentencizer", first=True)
                print("INFO: Added 'sentencizer' to loaded NER pipeline.")
            except Exception as e_pipe:
                print(f"✘ WARNING: Could not add 'sentencizer': {e_pipe}")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading NER model from {path.resolve()}: {e}")
        return False

def load_finetuned_summarizer(base_model_name, adapter_dir_path):
    """Loads the base HF tokenizer/model and applies the PEFT adapters."""
    global summ_tokenizer_global, summ_model_global
    if not adapter_dir_path.exists():
        print(f"✘ FATAL: PEFT adapter directory not found at {adapter_dir_path.resolve()}")
        return False
    try:
        print(f"\nLoading base summarization tokenizer: {base_model_name}...")
        summ_tokenizer_global = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base summarization model: {base_model_name}...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        print(f"Loading PEFT adapter from: {adapter_dir_path}...")
        # Build the fine-tuned model by applying the adapters to the base model
        summ_model_global = PeftModel.from_pretrained(base_model, adapter_dir_path)

        # Optional: merge the adapter weights into the base model. This can make
        # inference slightly faster but increases memory usage, and the adapter
        # can no longer be unloaded easily. Skip it if you plan to switch
        # adapters or continue training later.
        # print("INFO: Merging PEFT adapters into base model...")
        # summ_model_global = summ_model_global.merge_and_unload()
        # print("INFO: Adapters merged.")

        summ_model_global.to(DEVICE)
        print(f"INFO: Model's configured max generation length: {summ_model_global.config.max_length}")
        print(f"✔ Successfully loaded fine-tuned PEFT model '{adapter_dir_path.name}' on base '{base_model_name}' ({DEVICE}).")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading fine-tuned summarization model: {e}")
        import traceback
        traceback.print_exc()
        return False
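
# Optional sanity check (an assumption that the extra output is wanted): PeftModel
# provides print_trainable_parameters(), which can help confirm the adapter attached:
#   summ_model_global.print_trainable_parameters()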

# --- Summarization ---
def summarize_text(tokenizer, model, text, num_beams=SUMM_NUM_BEAMS,
                   min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC):
    """Generates an abstractive summary whose target length is a percentage of the input token count."""
    if not text or text.isspace():
        return "Input text is empty."
    print("\nGenerating summary (using percentage lengths)...")
    try:
        # 1. Count the input tokens (no truncation, so long inputs are measured fully)
        input_ids_tensor = tokenizer(text, return_tensors="pt", truncation=False, padding=False).input_ids
        input_token_count = input_ids_tensor.shape[1]
        if input_token_count == 0:
            return "Input text tokenized to zero tokens."
        print(f"INFO: Input text has approx. {len(text.split())} words and {input_token_count} tokens.")

        # 2. Calculate target token lengths
        min_len_tokens = int(input_token_count * min_length_perc)
        max_len_tokens = int(input_token_count * max_length_perc)

        # 3. Apply the absolute limits and keep min <= max
        min_len_tokens = max(ABS_MIN_TOKEN_LEN, min_len_tokens)
        max_len_tokens = max(min_len_tokens + 10, max_len_tokens)
        max_len_tokens = min(ABS_MAX_TOKEN_LEN, max_len_tokens)
        min_len_tokens = min(min_len_tokens, max_len_tokens)
        print(f"INFO: Target summary token length: min={min_len_tokens}, max={max_len_tokens}.")

        # 4. Tokenize for model input (truncated to the encoder limit)
        inputs = tokenizer(text, max_length=1024, return_tensors="pt", padding="max_length", truncation=True).to(DEVICE)

        # 5. Generate the summary with the calculated min/max token lengths.
        # Pass input_ids as an explicit keyword argument: PeftModel.generate()
        # does not accept it positionally the way the base model's generate() does.
        print("INFO: Starting model.generate()...")
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=num_beams,
            max_length=max_len_tokens,
            min_length=min_len_tokens,
            early_stopping=True,
        )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except TypeError as te:  # Caught separately: usually a wrong positional/keyword argument to generate()
        print(f"✘ TypeError during summary generation: {te}")
        import traceback
        traceback.print_exc()
        return "[TypeError during summary generation - check arguments]"
    except Exception as e:
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return "[Error generating summary]"
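
# Worked example of the length clamping above (hypothetical 400-token input):
#   min = max(30, int(400 * 0.30)) = 120
#   max = min(512, max(120 + 10, int(400 * 0.70))) = 280
# so generation targets a summary of 120-280 tokens.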

def extract_entities(ner_nlp, text):
    """Extracts named entities using the spaCy NER model."""
    if not text or text.isspace():
        return []
    print("\nExtracting entities from original text using custom NER model...")
    try:
        doc = ner_nlp(text)
        # De-duplicate on (text, label) pairs
        entities = list({(ent.text.strip(), ent.label_) for ent in doc.ents if ent.text.strip()})
        print(f"✔ Extracted {len(entities)} unique entities.")
        return entities
    except Exception as e:
        print(f"✘ Error during entity extraction: {e}")
        return []

def create_prompted_input(text, entities):
    """Creates a new input string with the unique entity texts prepended."""
    if not entities:
        print("INFO: No entities found by NER; using original text for the prompted summary.")
        return text
    unique_entity_texts = sorted({ent[0] for ent in entities if ent[0]})
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "  # Tamil: "Key terms:"
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"\nINFO: Created prompted input with {len(unique_entity_texts)} unique entities (showing start): {prompted_text[:250]}...")
    return prompted_text
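
# Illustration with hypothetical entities [("சென்னை", "LOC"), ("மும்பை", "LOC")]:
# the prompted input becomes
#   "சென்னை, மும்பை. முக்கிய சொற்கள்: <original text>"
# (entity texts are de-duplicated, sorted, and comma-joined ahead of the separator).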

# --- Main execution ---
def main():
    global models_loaded
    if not models_loaded:
        print("✘ Models failed to load during startup. Cannot proceed.")
        sys.exit(1)

    print("\n" + "=" * 50)
    print("Please paste the Tamil text paragraph you want to summarize below.")
    print("Press Enter when finished.")
    print("(You may need to configure your terminal for multi-line paste if the text is long.)")
    print("-" * 50)
    input_paragraph = input("Input Text:\n")

    if not input_paragraph or input_paragraph.isspace():
        print("\n✘ Error: No input text provided. Exiting.")
        sys.exit(1)
    text_to_process = input_paragraph.strip()

    print("\n" + "=" * 50)
    print("Processing Input Text (snippet):")
    print(text_to_process[:300] + "...")
    print("=" * 50)

    # --- Output 1: Standard summary ---
    # Note: even the "standard" summary uses the fine-tuned model.
    print("\n--- Output 1: Standard Abstractive Summary (Fine-tuned Model) ---")
    standard_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, text_to_process,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nStandard Summary:")
    print(standard_summary)
    print("-" * 50)

    # --- Output 2: NER-influenced summary ---
    print("\n--- Output 2: NER-Influenced Abstractive Summary (Fine-tuned Model) ---")
    # a) Extract entities
    extracted_entities = extract_entities(ner_model_global, text_to_process)
    print("\nKey Entities Extracted by NER:")
    if extracted_entities:
        for text_ent, label in extracted_entities:
            print(f" - '{text_ent}' ({label})")
    else:
        print(" No entities found by NER model.")

    # b) Create the prompted input
    prompted_input_text = create_prompted_input(text_to_process, extracted_entities)

    # c) Generate a summary from the prompted input
    ner_influenced_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, prompted_input_text,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nNER-Influenced Summary (generated with entities as a prefix):")
    print(ner_influenced_summary)
    print("\nNOTE: Compare this summary with the standard summary (Output 1).")
    print("Fine-tuning may make both summaries better reflect your data's style;")
    print("prepending entities is still an experimental way to influence inclusion.")
    print("=" * 50)

if __name__ == "__main__":
    # --- Load models once at startup ---
    print("Application starting up... Loading models...")
    # Load the NER model first; only load the summarizer if NER succeeded.
    ner_loaded_ok = load_ner_model(NER_MODEL_PATH)
    if ner_loaded_ok:
        summ_loaded_ok = load_finetuned_summarizer(BASE_SUMMARIZATION_MODEL_NAME, ADAPTER_PATH)
        models_loaded = summ_loaded_ok  # Overall success depends on the summarizer
    else:
        models_loaded = False  # NER failed; cannot proceed

    if models_loaded:
        print("\n--- All models loaded successfully! Ready for input. ---")
        main()
    else:
        print("\n✘✘✘ CRITICAL ERROR: Model loading failed. Exiting. Check logs above. ✘✘✘")
        sys.exit(1)
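
# To run (assuming the configured model paths exist):
#   python mt5_finetuned_summary.py
# then paste a Tamil paragraph at the "Input Text:" prompt; the script prints the
# standard summary followed by the NER-influenced one.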