# -*- coding: utf-8 -*-
import spacy
from pathlib import Path
import sys
import warnings
import re
import numpy as np
# --- Prerequisites ---
# Ensure these are installed in your .venv:
# pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy
# (Make sure the spaCy version matches the one your NER model was trained with)
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    from peft import PeftModel, PeftConfig  # PEFT classes for loading LoRA adapters
except ImportError as e:
    print(f"✘ Error: Missing required library: {e}")
    print("Please install all dependencies: pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy")
    sys.exit(1)
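# Optional sanity check for the version note above (a sketch; trained spaCy
# pipelines record their training version in meta.json):
#   print(spacy.__version__)  # compare with "spacy_version" in model-best/meta.json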
# --- Configuration ---
# 1. Path to your trained spaCy NER model directory
NER_MODEL_PATH = Path("./training_400/model-best") # <-- ADJUST if different
# 2. Hugging Face model name for the BASE summarization model
BASE_SUMMARIZATION_MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
# 3. Path to your saved PEFT/LoRA adapter directory (output from fine-tuning)
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary") # <-- ADJUST if different
# 4. Device: "cuda" for GPU or "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 5. Summarization parameters
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC = 0.30      # Target minimum summary length as % of input tokens
MAX_LEN_PERC = 0.70      # Target maximum summary length as % of input tokens
ABS_MIN_TOKEN_LEN = 30   # Absolute minimum summary length in tokens
ABS_MAX_TOKEN_LEN = 512  # Absolute maximum summary length in tokens (safety cap)
# --- End Configuration ---
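# Worked example of the length math above, as applied in summarize_text() below
# (illustrative numbers, not from a real run): for a 400-token input,
# min = max(30, int(400 * 0.30)) = 120 and
# max = min(512, max(120 + 10, int(400 * 0.70))) = 280,
# so generate() is asked for a summary of 120-280 tokens.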
# --- Suppress Warnings ---
warnings.filterwarnings("ignore", message="CUDA path could not be detected*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")
# --- Global Variables for Loaded Models ---
ner_model_global = None
summ_tokenizer_global = None
summ_model_global = None # This will hold the PEFT model
models_loaded = False
# --- Model Loading Functions ---
def load_ner_model(path):
    """Loads the spaCy NER model and ensures a sentence boundary detector is present."""
    global ner_model_global
    if not path.exists():
        print(f"✘ FATAL: NER Model directory not found at {path.resolve()}")
        return False
    try:
        ner_model_global = spacy.load(path)
        print(f"✔ Successfully loaded NER model from: {path.resolve()}")
        # Ensure a sentence boundary detector is present
        component_to_add_before = None
        if "tok2vec" in ner_model_global.pipe_names:
            component_to_add_before = "tok2vec"
        elif "ner" in ner_model_global.pipe_names:
            component_to_add_before = "ner"
        if not ner_model_global.has_pipe("sentencizer") and not ner_model_global.has_pipe("parser"):
            try:
                if component_to_add_before:
                    ner_model_global.add_pipe("sentencizer", before=component_to_add_before)
                else:
                    ner_model_global.add_pipe("sentencizer", first=True)
                print("INFO: Added 'sentencizer' to loaded NER pipeline.")
            except Exception as e_pipe:
                print(f"✘ WARNING: Could not add 'sentencizer': {e_pipe}.")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading NER model from {path.resolve()}: {e}")
        return False
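# A minimal smoke test for the loaded pipeline (hypothetical usage; the Tamil
# sample is illustrative only):
#   doc = ner_model_global("சென்னை தமிழ்நாட்டின் தலைநகரம். இது ஒரு பெரிய நகரம்.")
#   print([(ent.text, ent.label_) for ent in doc.ents])  # entities, if any
#   print(len(list(doc.sents)))  # should be 2 once a sentence boundary detector is present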
def load_finetuned_summarizer(base_model_name, adapter_dir_path):
    """Loads the base HF tokenizer/model and applies the PEFT adapters."""
    global summ_tokenizer_global, summ_model_global
    if not adapter_dir_path.exists():
        print(f"✘ FATAL: PEFT Adapter directory not found at {adapter_dir_path.resolve()}")
        return False
    try:
        print(f"\nLoading base summarization tokenizer: {base_model_name}...")
        summ_tokenizer_global = AutoTokenizer.from_pretrained(base_model_name)
        print(f"Loading base summarization model: {base_model_name}...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        print(f"Loading PEFT adapter from: {adapter_dir_path}...")
        # Apply the fine-tuned adapters on top of the base model
        summ_model_global = PeftModel.from_pretrained(base_model, adapter_dir_path)
        # Optional: merge the adapter weights into the base model. This can make
        # inference slightly faster but increases memory usage, and the adapter can
        # no longer be easily unloaded. Skip this if you plan to switch adapters or
        # do more training later.
        # print("INFO: Merging PEFT adapters into base model...")
        # summ_model_global = summ_model_global.merge_and_unload()
        # print("INFO: Adapters merged.")
        summ_model_global.to(DEVICE)
        print(f"INFO: Model's configured max generation length: {summ_model_global.config.max_length}")  # Base model's limit
        print(f"✔ Successfully loaded fine-tuned PEFT model '{adapter_dir_path.name}' on base '{base_model_name}' on {DEVICE}.")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading fine-tuned summarization model: {e}")
        import traceback
        traceback.print_exc()
        return False
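# To confirm the adapter actually attached (a quick check; both helpers are part
# of the PEFT API):
#   print(summ_model_global.peft_config)            # adapter name -> LoraConfig
#   summ_model_global.print_trainable_parameters()  # adapter params vs. total params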
# --- Summarization with percentage-based target lengths ---
def summarize_text(tokenizer, model, text, num_beams=SUMM_NUM_BEAMS,
                   min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC):
    """Generates an abstractive summary whose length is based on the input token count."""
    if not text or text.isspace():
        return "Input text is empty."
    print("\nGenerating summary (using percentage lengths)...")
    try:
        # 1. Count the input tokens (no truncation, so the count covers the full text).
        # The source text is model *input*, so it is tokenized directly;
        # `as_target_tokenizer()` is only needed when tokenizing labels.
        input_ids_tensor = tokenizer(text, return_tensors="pt", truncation=False, padding=False).input_ids
        input_token_count = input_ids_tensor.shape[1]
        if input_token_count == 0:
            return "Input text tokenized to zero tokens."
        print(f"INFO: Input text has approx {len(text.split())} words and {input_token_count} tokens.")
        # 2. Calculate target token lengths as percentages of the input length
        min_len_tokens = int(input_token_count * min_length_perc)
        max_len_tokens = int(input_token_count * max_length_perc)
        # 3. Apply the absolute limits and keep min below max
        min_len_tokens = max(ABS_MIN_TOKEN_LEN, min_len_tokens)
        max_len_tokens = max(min_len_tokens + 10, max_len_tokens)
        max_len_tokens = min(ABS_MAX_TOKEN_LEN, max_len_tokens)
        min_len_tokens = min(min_len_tokens, max_len_tokens)
        print(f"INFO: Target summary token length: min={min_len_tokens}, max={max_len_tokens}.")
        # 4. Tokenize for model input (truncated/padded to 1024 tokens)
        inputs = tokenizer(text, max_length=1024, return_tensors="pt", padding="max_length", truncation=True).to(DEVICE)
        # 5. Generate the summary using the calculated min/max token lengths.
        # Pass the tensors as explicit keyword arguments: passing them
        # positionally can raise a TypeError with PeftModel.generate().
        print("INFO: Starting model.generate()...")
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],  # needed because padding="max_length" pads the input
            num_beams=num_beams,
            max_length=max_len_tokens,
            min_length=min_len_tokens,
            early_stopping=True
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except TypeError as te:  # Catch this specific error for clearer logging
        print(f"✘ TypeError during summary generation: {te}")
        print("✘ This usually means generate() received bad arguments; check keyword vs. positional usage.")
        import traceback
        traceback.print_exc()
        return "[TypeError during summary generation - check arguments]"
    except Exception as e:
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return "[Error generating summary]"
def extract_entities(ner_nlp, text):
    """Extracts named entities from the text using the custom spaCy NER model."""
    if not text or text.isspace():
        return []
    print("\nExtracting entities from original text using custom NER model...")
    try:
        doc = ner_nlp(text)
        # Deduplicate on (text, label) pairs
        entities = list({(ent.text.strip(), ent.label_) for ent in doc.ents if ent.text.strip()})
        print(f"✔ Extracted {len(entities)} unique entities.")
        return entities
    except Exception as e:
        print(f"✘ Error during entity extraction: {e}")
        return []
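# Illustrative call (hypothetical output; the label set depends on the custom model):
#   extract_entities(ner_model_global, "சென்னை ஒரு பெரிய நகரம்.")
#   -> [("சென்னை", "GPE")]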
def create_prompted_input(text, entities):
    """Creates a new input string with the unique entity texts prepended."""
    if not entities:
        print("INFO: No entities found by NER, using original text for prompted summary.")
        return text
    unique_entity_texts = sorted({ent[0] for ent in entities if ent[0]})
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "  # Tamil for "key words"
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"\nINFO: Created prompted input with {len(unique_entity_texts)} unique entities (showing start): {prompted_text[:250]}...")
    return prompted_text
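# Shape of the prompted input (hypothetical entities, for illustration):
#   create_prompted_input(text, [("சென்னை", "GPE"), ("தமிழ்நாடு", "GPE")])
#   -> "சென்னை, தமிழ்நாடு. முக்கிய சொற்கள்: <original text>"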
# --- Main execution ---
def main():
    global models_loaded  # Access the startup flag
    if not models_loaded:
        print("✘ Models failed to load during startup. Cannot proceed.")
        sys.exit(1)
    print("\n" + "=" * 50)
    print("Please paste the Tamil text paragraph you want to summarize below.")
    print("Press Enter when finished.")
    print("(You might need to configure your terminal for multi-line paste if the text is long.)")
    print("-" * 50)
    input_paragraph = input("Input Text:\n")
    if not input_paragraph or input_paragraph.isspace():
        print("\n✘ Error: No input text provided. Exiting.")
        sys.exit(1)
    text_to_process = input_paragraph.strip()
    print("\n" + "=" * 50)
    print("Processing Input Text (Snippet):")
    print(text_to_process[:300] + "...")
    print("=" * 50)

    # --- Output 1: Standard summary (note: this also uses the fine-tuned model) ---
    print("\n--- Output 1: Standard Abstractive Summary (Fine-tuned Model) ---")
    standard_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, text_to_process,
        num_beams=SUMM_NUM_BEAMS
    )
    print("\nStandard Summary:")
    print(standard_summary)
    print("-" * 50)

    # --- Output 2: NER-influenced summary (fine-tuned model + entity prefix) ---
    print("\n--- Output 2: NER-Influenced Abstractive Summary (Fine-tuned Model) ---")
    # a) Extract entities
    extracted_entities = extract_entities(ner_model_global, text_to_process)
    print("\nKey Entities Extracted by NER:")
    if extracted_entities:
        for text_ent, label in extracted_entities:
            print(f" - '{text_ent}' ({label})")
    else:
        print(" No entities found by NER model.")
    # b) Create the prompted input
    prompted_input_text = create_prompted_input(text_to_process, extracted_entities)
    # c) Generate a summary from the prompted input
    ner_influenced_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, prompted_input_text,
        num_beams=SUMM_NUM_BEAMS
    )
    print("\nNER-Influenced Summary (Generated using entities as prefix):")
    print(ner_influenced_summary)
    print("\nNOTE: Compare this summary with the standard summary (Output 1).")
    print("Fine-tuning might make both summaries better reflect your data's style.")
    print("Prepending entities is still experimental for influencing inclusion.")
    print("=" * 50)
if __name__ == "__main__":
    # --- Load models globally at startup ---
    print("Application starting up... Loading models...")
    # Load the NER model first, then the summarizer
    ner_loaded_ok = load_ner_model(NER_MODEL_PATH)
    if ner_loaded_ok:
        # Load the summarizer only if the NER model loaded
        summ_loaded_ok = load_finetuned_summarizer(BASE_SUMMARIZATION_MODEL_NAME, ADAPTER_PATH)
        models_loaded = summ_loaded_ok  # Overall success depends on the summarizer loading
    else:
        models_loaded = False  # NER failed, cannot proceed
    if models_loaded:
        print("\n--- All models loaded successfully! Ready for input. ---")
        main()  # Run the main interaction loop
    else:
        print("\n✘✘✘ CRITICAL ERROR: Model loading failed. Exiting. Check logs above. ✘✘✘")
        sys.exit(1)