# Source: Tamil_txt_Summarisation_NER / mt5_finetuned_summary.py
# Uploaded by Nivas007 ("Upload mt5_finetuned_summary.py"), commit b9285e0 (verified).
# -*- coding: utf-8 -*-
import spacy
from pathlib import Path
import sys
import warnings
import re
import numpy as np
# --- Prerequisites ---
# Ensure these are installed in your .venv:
# pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy
# (Make sure spacy version matches your NER model training version)
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    from peft import PeftModel, PeftConfig  # PEFT wrapper used to apply the saved adapter
except ImportError as missing_dep:
    # Fail fast with an install hint instead of surfacing a bare traceback.
    print(f"✘ Error: Missing required library: {missing_dep}")
    print("Please install all dependencies: pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy")
    raise SystemExit(1)
# --- Configuration ---
# Directory of the trained spaCy NER pipeline.  <-- ADJUST if different
NER_MODEL_PATH = Path("./training_400/model-best")
# Hugging Face model id of the base summarization model the adapter was trained on.
BASE_SUMMARIZATION_MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
# Directory holding the saved PEFT/LoRA adapter produced by fine-tuning.  <-- ADJUST if different
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Summarization parameters.
# *_PERC values are fractions of the input token count (0.30 means 30 %).
# ABS_* values are hard lower/upper bounds on the summary length in tokens.
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC, MAX_LEN_PERC = 0.30, 0.70
ABS_MIN_TOKEN_LEN, ABS_MAX_TOKEN_LEN = 30, 512
# --- End Configuration ---
# --- Suppress Warnings ---
# `message` is a regular expression matched against the START of the warning
# text, so each pattern is wrapped in `.*` to match the phrase anywhere.
# (The previous CUDA pattern ended in a stray glob-style `*`, which in a regex
# only repeats the final "d" and still required the phrase at position 0.)
warnings.filterwarnings("ignore", message=".*CUDA path could not be detected.*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")
# --- Global Variables for Loaded Models ---
# All three model handles start empty; the loader functions fill them in.
ner_model_global = summ_tokenizer_global = summ_model_global = None
# summ_model_global will hold the PEFT-wrapped summarization model.
# models_loaded flips to True only after both loaders succeed at start-up.
models_loaded = False
# --- Model Loading Functions ---
def load_ner_model(path):
"""Loads the spaCy NER model and ensures sentencizer is present."""
global ner_model_global
if not path.exists():
print(f"✘ FATAL: NER Model directory not found at {path.resolve()}")
return False
try:
ner_model_global = spacy.load(path)
print(f"✔ Successfully loaded NER model from: {path.resolve()}")
# Ensure a sentence boundary detector is present
component_to_add_before = None
if "tok2vec" in ner_model_global.pipe_names: component_to_add_before="tok2vec"
elif "ner" in ner_model_global.pipe_names: component_to_add_before="ner"
if not ner_model_global.has_pipe("sentencizer") and not ner_model_global.has_pipe("parser"):
try:
if component_to_add_before: ner_model_global.add_pipe("sentencizer", before=component_to_add_before)
else: ner_model_global.add_pipe("sentencizer", first=True)
print("INFO: Added 'sentencizer' to loaded NER pipeline.")
except Exception as e_pipe:
print(f"✘ WARNING: Could not add 'sentencizer': {e_pipe}.")
return True
except Exception as e:
print(f"✘ FATAL: Error loading NER model from {path.resolve()}: {e}")
return False
def load_finetuned_summarizer(base_model_name, adapter_dir_path):
    """Load the base tokenizer/model named *base_model_name* and apply the PEFT
    adapter saved in *adapter_dir_path*.

    On success, sets ``summ_tokenizer_global`` and ``summ_model_global``,
    moves the model to ``DEVICE``, and returns True. Returns False if the
    adapter directory does not exist or any loading step raises; the
    traceback is printed in that case.
    """
    global summ_tokenizer_global, summ_model_global
    if not adapter_dir_path.exists():
        print(f"✘ FATAL: PEFT Adapter directory not found at {adapter_dir_path.resolve()}")
        return False
    try:
        print(f"\nLoading base summarization tokenizer: {base_model_name}...")
        summ_tokenizer_global = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base summarization model: {base_model_name}...")
        backbone = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        print(f"Loading PEFT adapter from: {adapter_dir_path}...")
        summ_model_global = PeftModel.from_pretrained(backbone, adapter_dir_path)
        # Adapters are intentionally NOT merged. `merge_and_unload()` would fold
        # them into the base weights (slightly faster inference, more memory),
        # but the adapter could then no longer be swapped, unloaded or trained.

        summ_model_global.to(DEVICE)
        generation_cap = summ_model_global.config.max_length  # base model's own limit
        print(f"INFO: Model's configured max generation length: {generation_cap}")
        print(f"✔ Successfully loaded fine-tuned PEFT model '{adapter_dir_path.name}' on base '{base_model_name}' on {DEVICE}.")
        return True
    except Exception as load_err:
        import traceback
        print(f"✘ FATAL: Error loading fine-tuned summarization model: {load_err}")
        traceback.print_exc()
        return False
# --- MODIFIED summarize_text function ---
def _summary_length_bounds(token_count, min_length_perc, max_length_perc,
                           abs_min, abs_max):
    """Convert an input token count into ``(min_len, max_len)`` generation limits.

    The fractional targets are floored to ints. The minimum is raised to
    *abs_min*. The maximum is kept at least 10 tokens above the minimum and
    then capped at *abs_max*. Finally the minimum is lowered to the maximum
    if the cap pushed it below.
    """
    min_len = max(abs_min, int(token_count * min_length_perc))
    max_len = min(abs_max, max(min_len + 10, int(token_count * max_length_perc)))
    return min(min_len, max_len), max_len


def summarize_text(tokenizer, model, text, num_beams=SUMM_NUM_BEAMS,
                   min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC,
                   max_input_length=1024):
    """Generate an abstractive summary whose length scales with the input.

    Args:
        tokenizer: Hugging Face tokenizer matching *model*.
        model: seq2seq model (here the PEFT-wrapped mT5) exposing ``generate()``.
        text: text to summarize.
        num_beams: beam-search width passed to ``generate()``.
        min_length_perc: minimum summary length as a fraction of the input
            tokens the model actually sees.
        max_length_perc: maximum summary length as the same kind of fraction.
        max_input_length: the encoder input is truncated to this many tokens.

    Returns:
        The decoded summary, or a message string explaining why no summary
        was produced. Errors are printed, not raised.
    """
    if not text or text.isspace():
        return "Input text is empty."
    print("\nGenerating summary (using percentage lengths)...")
    try:
        # 1. Count the full (untruncated) input tokens. This is encoder input,
        #    so the tokenizer is used normally; the deprecated
        #    `as_target_tokenizer()` context is meant for labels only.
        input_token_count = len(tokenizer(text, truncation=False, padding=False)["input_ids"])
        if input_token_count == 0:
            return "Input text tokenized to zero tokens."
        print(f"INFO: Input text has approx {len(text.split())} words and {input_token_count} tokens.")

        # 2. Base the targets on the tokens the model will actually read.
        #    Using the untruncated count forced min == max == ABS_MAX_TOKEN_LEN
        #    for long inputs.
        effective_count = min(input_token_count, max_input_length)
        min_len_tokens, max_len_tokens = _summary_length_bounds(
            effective_count, min_length_perc, max_length_perc,
            ABS_MIN_TOKEN_LEN, ABS_MAX_TOKEN_LEN,
        )
        print(f"INFO: Target summary token length: min={min_len_tokens}, max={max_len_tokens}.")

        # 3. Tokenize for the model. A single sequence needs no padding;
        #    padding to max_length only made the encoder process
        #    `max_input_length` tokens every time.
        inputs = tokenizer(text, max_length=max_input_length, return_tensors="pt",
                           truncation=True).to(DEVICE)

        # 4. Generate. The PEFT seq2seq wrapper's generate() takes keyword
        #    arguments only, so everything is passed by name. The attention
        #    mask is passed explicitly rather than left for generate() to infer.
        print("INFO: Starting model.generate()...")
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            num_beams=num_beams,
            max_length=max_len_tokens,
            min_length=min_len_tokens,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except TypeError as te:  # usually a wrong generate() argument form
        print(f"✘ TypeError during summary generation: {te}")
        print("✘ Error details: This often happens if generate() arguments are incorrect, check keyword vs positional.")
        import traceback
        traceback.print_exc()
        return "[TypeError during summary generation - check arguments]"
    except Exception as e:
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return "[Error generating summary]"
def extract_entities(ner_nlp, text):
    """Extract unique named entities from *text* with a spaCy pipeline.

    Args:
        ner_nlp: callable NLP pipeline; its result must expose ``.ents`` whose
            items have ``.text`` and ``.label_``.
        text: raw input text.

    Returns:
        A list of ``(entity_text, label)`` tuples. Texts are whitespace-stripped,
        empty texts are dropped, and duplicates are removed while keeping
        first-occurrence order. Returns an empty list for blank input or if
        the pipeline raises (the error is printed).
    """
    if not text or text.isspace():
        return []
    print("\nExtracting entities from original text using custom NER model...")
    try:
        doc = ner_nlp(text)
        pairs = ((ent.text.strip(), ent.label_) for ent in doc.ents)
        # dict.fromkeys de-duplicates but keeps first-seen order. A set made
        # the printed order vary between runs because of str hash randomization.
        entities = list(dict.fromkeys(pair for pair in pairs if pair[0]))
        print(f"✔ Extracted {len(entities)} unique entities.")
        return entities
    except Exception as e:
        print(f"✘ Error during entity extraction: {e}")
        return []
def create_prompted_input(text, entities):
    """Prefix *text* with the unique entity strings from *entities*.

    The result has the form ``"<e1>, <e2>, ... . முக்கிய சொற்கள்: <text>"``, with
    the entity strings de-duplicated and sorted. "முக்கிய சொற்கள்" is Tamil for
    "key words".

    Args:
        text: the original input text.
        entities: iterable of ``(entity_text, label)`` pairs; may be empty or None.

    Returns:
        The prompted string, or *text* unchanged when no non-empty entity text
        exists. Prepending nothing would only leave a dangling separator.
    """
    # NOTE(review): the label follows the entity list, so it reads as labelling
    # the text rather than the entities. It is kept as-is in case the adapter
    # was fine-tuned on this exact format; confirm against the training script.
    unique_entity_texts = sorted({ent[0] for ent in entities or () if ent[0]})
    if not unique_entity_texts:
        print("INFO: No entities found by NER, using original text for prompted summary.")
        return text
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"\nINFO: Created prompted input with {len(unique_entity_texts)} unique entities (showing start): {prompted_text[:250]}...")
    return prompted_text
# --- Main execution ---
def _read_multiline_input(prompt):
    """Print *prompt*, then read stdin lines until a blank line or EOF.

    Returns the non-blank lines joined with newlines, or "" if none were entered.
    """
    print(prompt)
    lines = []
    while True:
        try:
            line = input()
        except EOFError:  # end of input, e.g. Ctrl-D or a closed pipe
            break
        if not line.strip():
            break
        lines.append(line)
    return "\n".join(lines)


def main():
    """Interactive entry point.

    Reads Tamil text from stdin and prints two things: a summary of the text
    itself, and a summary of the text prefixed with the NER-extracted entities.
    Exits with status 1 if the models were not loaded or no text was entered.
    """
    if not models_loaded:
        print("✘ Models failed to load during startup. Cannot proceed.")
        sys.exit(1)
    print("\n" + "=" * 50)
    print("Please paste the Tamil text paragraph you want to summarize below.")
    print("Press Enter on an empty line (or send EOF) when finished.")
    print("Multi-line text is accepted; reading stops at the first empty line.")
    print("-" * 50)
    # input() alone returned only the first line, silently dropping the rest
    # of a multi-line paste.
    input_paragraph = _read_multiline_input("Input Text:")
    if not input_paragraph or input_paragraph.isspace():
        print("\n✘ Error: No input text provided. Exiting.")
        sys.exit(1)
    text_to_process = input_paragraph.strip()
    print("\n" + "=" * 50)
    print("Processing Input Text (Snippet):")
    # Only show an ellipsis when the snippet is actually truncated.
    print(text_to_process[:300] + ("..." if len(text_to_process) > 300 else ""))
    print("=" * 50)

    # --- Output 1: standard summary (the fine-tuned model is used here too) ---
    print("\n--- Output 1: Standard Abstractive Summary (Fine-tuned Model) ---")
    standard_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, text_to_process,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nStandard Summary:")
    print(standard_summary)
    print("-" * 50)

    # --- Output 2: NER-influenced summary (fine-tuned model) ---
    print("\n--- Output 2: NER-Influenced Abstractive Summary (Fine-tuned Model) ---")
    # a) Extract entities from the original text.
    extracted_entities = extract_entities(ner_model_global, text_to_process)
    print("\nKey Entities Extracted by NER:")
    if extracted_entities:
        for text_ent, label in extracted_entities:
            print(f" - '{text_ent}' ({label})")
    else:
        print(" No entities found by NER model.")
    # b) Prefix the text with the entity list.
    prompted_input_text = create_prompted_input(text_to_process, extracted_entities)
    # c) Summarize the prefixed text.
    ner_influenced_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, prompted_input_text,
        num_beams=SUMM_NUM_BEAMS,
    )
    print("\nNER-Influenced Summary (Generated using entities as prefix):")
    print(ner_influenced_summary)
    print("\nNOTE: Compare this summary with the standard summary (Output 1).")
    print("Fine-tuning might make both summaries better reflect your data's style.")
    print("Prepending entities is still experimental for influencing inclusion.")
    print("=" * 50)
if __name__ == "__main__":
    # Load both models once at start-up; main() reads the globals they populate.
    print("Application starting up... Loading models...")
    # The summarizer is only attempted if the NER pipeline loaded. `and`
    # short-circuits, so a failed NER load skips it, and the overall result is
    # the summarizer's success flag otherwise.
    models_loaded = load_ner_model(NER_MODEL_PATH) and load_finetuned_summarizer(
        BASE_SUMMARIZATION_MODEL_NAME, ADAPTER_PATH
    )
    if not models_loaded:
        print("\n✘✘✘ CRITICAL ERROR: Model loading failed. Exiting. Check logs above. ✘✘✘")
        sys.exit(1)
    print("\n--- All models loaded successfully! Ready for input. ---")
    main()  # Run the interactive session