Nivas007 committed
Commit b9285e0 (verified) · 1 parent: 925ce54

Upload mt5_finetuned_summary.py

Files changed (1):
mt5_finetuned_summary.py +279 -0
mt5_finetuned_summary.py ADDED
@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-

import spacy
from pathlib import Path
import sys
import warnings

# --- Prerequisites ---
# Ensure these are installed in your .venv:
# pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy
# (Make sure the spaCy version matches the version used to train your NER model.)
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    from peft import PeftModel  # PEFT class used to load the LoRA adapters
except ImportError as e:
    print(f"✘ Error: Missing required library: {e}")
    print("Please install all dependencies: pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy")
    sys.exit(1)

# --- Configuration ---
# 1. Path to your trained spaCy NER model directory
NER_MODEL_PATH = Path("./training_400/model-best")  # <-- ADJUST if different

# 2. Hugging Face model name for the BASE summarization model
BASE_SUMMARIZATION_MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"

# 3. Path to your saved PEFT/LoRA adapter directory (the output of fine-tuning)
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary")  # <-- ADJUST if different

# 4. Device: "cuda" for GPU or "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 5. Summarization parameters
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC = 0.30      # Target minimum summary length as a % of input tokens
MAX_LEN_PERC = 0.70      # Target maximum summary length as a % of input tokens
ABS_MIN_TOKEN_LEN = 30   # Absolute minimum summary length in tokens
ABS_MAX_TOKEN_LEN = 512  # Absolute maximum summary length in tokens (safe cap)
# --- End Configuration ---

# --- Suppress Warnings ---
warnings.filterwarnings("ignore", message="CUDA path could not be detected*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")

# --- Global Variables for Loaded Models ---
ner_model_global = None
summ_tokenizer_global = None
summ_model_global = None  # Holds the PEFT-wrapped summarization model
models_loaded = False

# --- Model Loading Functions ---
def load_ner_model(path):
    """Loads the spaCy NER model and ensures a sentence boundary detector is present."""
    global ner_model_global
    if not path.exists():
        print(f"✘ FATAL: NER model directory not found at {path.resolve()}")
        return False
    try:
        ner_model_global = spacy.load(path)
        print(f"✔ Successfully loaded NER model from: {path.resolve()}")
        # Ensure a sentence boundary detector is present
        component_to_add_before = None
        if "tok2vec" in ner_model_global.pipe_names:
            component_to_add_before = "tok2vec"
        elif "ner" in ner_model_global.pipe_names:
            component_to_add_before = "ner"
        if not ner_model_global.has_pipe("sentencizer") and not ner_model_global.has_pipe("parser"):
            try:
                if component_to_add_before:
                    ner_model_global.add_pipe("sentencizer", before=component_to_add_before)
                else:
                    ner_model_global.add_pipe("sentencizer", first=True)
                print("INFO: Added 'sentencizer' to the loaded NER pipeline.")
            except Exception as e_pipe:
                print(f"✘ WARNING: Could not add 'sentencizer': {e_pipe}")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading NER model from {path.resolve()}: {e}")
        return False

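# A quick sanity check (illustrative only, with a made-up two-sentence Tamil
# string) that the sentencizer actually segments text once load_ner_model()
# has succeeded:
#
#     doc = ner_model_global("இது முதல் வாக்கியம். இது இரண்டாவது வாக்கியம்.")
#     print([sent.text for sent in doc.sents])  # expect two sentences
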
def load_finetuned_summarizer(base_model_name, adapter_dir_path):
    """Loads the base HF tokenizer/model and applies the PEFT adapters."""
    global summ_tokenizer_global, summ_model_global
    if not adapter_dir_path.exists():
        print(f"✘ FATAL: PEFT adapter directory not found at {adapter_dir_path.resolve()}")
        return False
    try:
        print(f"\nLoading base summarization tokenizer: {base_model_name}...")
        summ_tokenizer_global = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base summarization model: {base_model_name}...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        print(f"Loading PEFT adapter from: {adapter_dir_path}...")
        # Build the fine-tuned model by applying the adapters to the base model
        summ_model_global = PeftModel.from_pretrained(base_model, adapter_dir_path)

        # Optional: merge the adapter weights into the base model. This can make
        # inference slightly faster but increases memory usage, and the adapter
        # can no longer be unloaded easily. Skip this if you plan to switch
        # adapters or continue training later.
        # print("INFO: Merging PEFT adapters into base model...")
        # summ_model_global = summ_model_global.merge_and_unload()
        # print("INFO: Adapters merged.")

        summ_model_global.to(DEVICE)
        # Note: this prints the BASE model's configured generation limit
        print(f"INFO: Model's configured max generation length: {summ_model_global.config.max_length}")
        print(f"✔ Successfully loaded fine-tuned PEFT model '{adapter_dir_path.name}' on base '{base_model_name}' on {DEVICE}.")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading fine-tuned summarization model: {e}")
        import traceback
        traceback.print_exc()
        return False

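# To confirm the adapter attached correctly, one option (assuming a standard
# peft version) is to print the parameter breakdown right after loading:
#
#     summ_model_global.print_trainable_parameters()
#     # expect a small "trainable" count next to the base model's full total
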
# --- Summarization ---
def summarize_text(tokenizer, model, text, num_beams=SUMM_NUM_BEAMS,
                   min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC):
    """Generates an abstractive summary whose length is derived from the input token count."""
    if not text or text.isspace():
        return "Input text is empty."
    print("\nGenerating summary (using percentage-based lengths)...")
    try:
        # 1. Count the input tokens (no truncation, so the count covers the full text)
        input_ids_tensor = tokenizer(text, return_tensors="pt", truncation=False, padding=False).input_ids
        input_token_count = input_ids_tensor.shape[1]
        if input_token_count == 0:
            return "Input text tokenized to zero tokens."
        print(f"INFO: Input text has approx. {len(text.split())} words and {input_token_count} tokens.")

        # 2. Calculate target token lengths as percentages of the input length
        min_len_tokens = int(input_token_count * min_length_perc)
        max_len_tokens = int(input_token_count * max_length_perc)

        # 3. Apply the absolute limits and ensure min <= max
        min_len_tokens = max(ABS_MIN_TOKEN_LEN, min_len_tokens)
        max_len_tokens = max(min_len_tokens + 10, max_len_tokens)
        max_len_tokens = min(ABS_MAX_TOKEN_LEN, max_len_tokens)
        min_len_tokens = min(min_len_tokens, max_len_tokens)
        print(f"INFO: Target summary token length: min={min_len_tokens}, max={max_len_tokens}.")

        # 4. Tokenize for the model input (truncated to the encoder window)
        inputs = tokenizer(text, max_length=1024, return_tensors="pt",
                           padding="max_length", truncation=True).to(DEVICE)

        # 5. Generate the summary using the calculated min/max token lengths.
        # Pass input_ids explicitly as a keyword argument, and include the
        # attention mask so the padded positions are ignored.
        print("INFO: Starting model.generate()...")
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=num_beams,
            max_length=max_len_tokens,
            min_length=min_len_tokens,
            early_stopping=True,
        )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                                   clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except TypeError as te:
        # Usually caused by a wrong keyword or positional argument to generate()
        print(f"✘ TypeError during summary generation: {te}")
        import traceback
        traceback.print_exc()
        return "[TypeError during summary generation - check arguments]"
    except Exception as e:
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return "[Error generating summary]"

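# Worked example of the length clamping above (illustrative numbers):
# a 400-token input gives min = int(400 * 0.30) = 120 and
# max = int(400 * 0.70) = 280 tokens; a 60-token input gives
# min = max(30, 18) = 30 and max = min(512, max(40, 42)) = 42 tokens.
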
def extract_entities(ner_nlp, text):
    """Extracts named entities from the text using the custom spaCy NER model."""
    if not text or text.isspace():
        return []
    print("\nExtracting entities from the original text using the custom NER model...")
    try:
        doc = ner_nlp(text)
        # Deduplicate on (text, label) pairs
        entities = list({(ent.text.strip(), ent.label_) for ent in doc.ents if ent.text.strip()})
        print(f"✔ Extracted {len(entities)} unique entities.")
        return entities
    except Exception as e:
        print(f"✘ Error during entity extraction: {e}")
        return []

def create_prompted_input(text, entities):
    """Creates a new input string with the unique entity texts prepended."""
    if not entities:
        print("INFO: No entities found by NER; using the original text for the prompted summary.")
        return text
    unique_entity_texts = sorted({ent[0] for ent in entities if ent[0]})
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "  # Tamil for "key terms:"
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"\nINFO: Created prompted input with {len(unique_entity_texts)} unique entities (start shown): {prompted_text[:250]}...")
    return prompted_text

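# Shape of the prompted input (entities here are made up for illustration):
# entities ["தமிழ்நாடு", "சென்னை"] and original text T become
# "சென்னை, தமிழ்நாடு. முக்கிய சொற்கள்: T"
# i.e. sorted unique entity texts, comma-separated, then the separator, then T.
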
# --- Main execution ---
def main():
    if not models_loaded:
        print("✘ Models failed to load during startup. Cannot proceed.")
        sys.exit(1)

    print("\n" + "=" * 50)
    print("Please paste the Tamil text paragraph you want to summarize below.")
    print("Press Enter when finished.")
    print("(You might need to configure your terminal for multi-line paste if it's long.)")
    print("-" * 50)
    input_paragraph = input("Input Text:\n")

    if not input_paragraph or input_paragraph.isspace():
        print("\n✘ Error: No input text provided. Exiting.")
        sys.exit(1)
    text_to_process = input_paragraph.strip()

    print("\n" + "=" * 50)
    print("Processing Input Text (snippet):")
    print(text_to_process[:300] + "...")
    print("=" * 50)

    # --- Output 1: Standard summary (note: this also uses the fine-tuned model) ---
    print("\n--- Output 1: Standard Abstractive Summary (Fine-tuned Model) ---")
    standard_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, text_to_process,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nStandard Summary:")
    print(standard_summary)
    print("-" * 50)

    # --- Output 2: NER-influenced summary (fine-tuned model) ---
    print("\n--- Output 2: NER-Influenced Abstractive Summary (Fine-tuned Model) ---")
    # a) Extract entities
    extracted_entities = extract_entities(ner_model_global, text_to_process)
    print("\nKey Entities Extracted by NER:")
    if extracted_entities:
        for text_ent, label in extracted_entities:
            print(f" - '{text_ent}' ({label})")
    else:
        print(" No entities found by the NER model.")

    # b) Create the prompted input
    prompted_input_text = create_prompted_input(text_to_process, extracted_entities)

    # c) Generate a summary from the prompted input
    ner_influenced_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, prompted_input_text,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nNER-Influenced Summary (generated with the entities as a prefix):")
    print(ner_influenced_summary)
    print("\nNOTE: Compare this summary with the standard summary (Output 1).")
    print("Fine-tuning might make both summaries better reflect your data's style.")
    print("Prepending entities is still an experimental way of influencing their inclusion.")
    print("=" * 50)

if __name__ == "__main__":
    # --- Load the models once when the script starts ---
    print("Application starting up... Loading models...")
    # Load NER first, then the summarizer
    ner_loaded_ok = load_ner_model(NER_MODEL_PATH)
    if ner_loaded_ok:
        # Load the summarizer only if the NER model loaded
        summ_loaded_ok = load_finetuned_summarizer(BASE_SUMMARIZATION_MODEL_NAME, ADAPTER_PATH)
        models_loaded = summ_loaded_ok  # Overall success depends on the summarizer
    else:
        models_loaded = False  # NER failed; cannot proceed

    if models_loaded:
        print("\n--- All models loaded successfully! Ready for input. ---")
        main()  # Run the main interaction
    else:
        print("\n✘✘✘ CRITICAL ERROR: Model loading failed. Exiting. Check the logs above. ✘✘✘")
        sys.exit(1)
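
A minimal non-interactive usage sketch (assuming the configured model paths
exist and the script is importable as a module; the sample string is made up):

    import mt5_finetuned_summary as m

    if m.load_ner_model(m.NER_MODEL_PATH) and m.load_finetuned_summarizer(
            m.BASE_SUMMARIZATION_MODEL_NAME, m.ADAPTER_PATH):
        sample = "இங்கே ஒரு தமிழ் பத்தி."  # any short Tamil paragraph
        print(m.summarize_text(m.summ_tokenizer_global, m.summ_model_global, sample))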