# -*- coding: utf-8 -*-

import spacy
from pathlib import Path
import sys
import warnings

# --- Prerequisites ---
# Ensure these are installed in your .venv:
# pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy
# (Make sure spacy version matches your NER model training version)
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import torch
    from peft import PeftModel, PeftConfig # Import PEFT classes
except ImportError as e:
    print(f"✘ Error: Missing required library: {e}")
    print("Please install all dependencies: pip install spacy transformers torch sentencepiece protobuf==3.20.3 peft accelerate datasets evaluate gradio numpy")
    sys.exit(1)


# --- Configuration ---
# 1. Path to your trained spaCy NER model directory
NER_MODEL_PATH = Path("./training_400/model-best") # <-- ADJUST if different

# 2. Hugging Face model name for the BASE summarization model
BASE_SUMMARIZATION_MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"

# 3. Path to your saved PEFT/LoRA adapter directory (output from fine-tuning)
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary") # <-- ADJUST if different
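#    (A valid adapter directory, saved via PEFT's save_pretrained, typically
#    contains adapter_config.json plus adapter_model.safetensors or adapter_model.bin.)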

# 4. Device: "cuda" for GPU or "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 5. Summarization parameters
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC = 0.30 # Target minimum summary length as % of input tokens
MAX_LEN_PERC = 0.70 # Target maximum summary length as % of input tokens
ABS_MIN_TOKEN_LEN = 30   # Absolute minimum summary length in tokens
ABS_MAX_TOKEN_LEN = 512  # Absolute maximum summary length in tokens
# --- End Configuration ---

# --- Suppress Warnings ---
warnings.filterwarnings("ignore", message="CUDA path could not be detected*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")

# --- Global Variables for Loaded Models ---
ner_model_global = None
summ_tokenizer_global = None
summ_model_global = None # This will hold the PEFT model
models_loaded = False

# --- Model Loading Functions ---
def load_ner_model(path):
    """Loads the spaCy NER model and ensures sentencizer is present."""
    global ner_model_global
    if not path.exists():
        print(f"✘ FATAL: NER Model directory not found at {path.resolve()}")
        return False
    try:
        ner_model_global = spacy.load(path)
        print(f"✔ Successfully loaded NER model from: {path.resolve()}")
        # Ensure a sentence boundary detector is present
        component_to_add_before = None
        if "tok2vec" in ner_model_global.pipe_names:
            component_to_add_before = "tok2vec"
        elif "ner" in ner_model_global.pipe_names:
            component_to_add_before = "ner"
        if not ner_model_global.has_pipe("sentencizer") and not ner_model_global.has_pipe("parser"):
            try:
                if component_to_add_before:
                    ner_model_global.add_pipe("sentencizer", before=component_to_add_before)
                else:
                    ner_model_global.add_pipe("sentencizer", first=True)
                print("INFO: Added 'sentencizer' to loaded NER pipeline.")
            except Exception as e_pipe:
                print(f"✘ WARNING: Could not add 'sentencizer': {e_pipe}.")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading NER model from {path.resolve()}: {e}")
        return False

def load_finetuned_summarizer(base_model_name, adapter_dir_path):
    """Loads the base HF tokenizer/model and applies PEFT adapters."""
    global summ_tokenizer_global, summ_model_global
    if not adapter_dir_path.exists():
        print(f"✘ FATAL: PEFT Adapter directory not found at {adapter_dir_path.resolve()}")
        return False
    try:
        print(f"\nLoading base summarization tokenizer: {base_model_name}...")
        summ_tokenizer_global = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base summarization model: {base_model_name}...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        print(f"Loading PEFT adapter from: {adapter_dir_path}...")
        # Load the fine-tuned PEFT model by applying adapters to the base model
        summ_model_global = PeftModel.from_pretrained(base_model, adapter_dir_path)

        # Optional: Merge weights. This combines the adapter weights into the base model.
        # It can make inference slightly faster but increases memory usage
        # and you can no longer easily unload the adapter. Don't use if you plan
        # to switch adapters or do more training later.
        # print("INFO: Merging PEFT adapters into base model...")
        # summ_model_global = summ_model_global.merge_and_unload()
        # print("INFO: Adapters merged.")

        summ_model_global.to(DEVICE)
        print(f"INFO: Model's configured max generation length: {summ_model_global.config.max_length}") # Print base model's limit
        print(f"✔ Successfully loaded fine-tuned PEFT model '{adapter_dir_path.name}' on base '{base_model_name}' on {DEVICE}.")
        return True
    except Exception as e:
        print(f"✘ FATAL: Error loading fine-tuned summarization model: {e}")
        import traceback
        traceback.print_exc()
        return False

# --- Summarization ---
def summarize_text(tokenizer, model, text, num_beams=SUMM_NUM_BEAMS,
                   min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC):
    """Generates abstractive summary with length based on input token percentage."""
    if not text or text.isspace(): return "Input text is empty."
    print("\nGenerating summary (using percentage lengths)...")
    try:
        # 1. Count input tokens (no truncation). Note: as_target_tokenizer() is
        #    meant for preparing target/label sequences, not for counting input
        #    tokens, so the input is tokenized directly here.
        input_ids_tensor = tokenizer(text, return_tensors="pt", truncation=False, padding=False).input_ids
        input_token_count = input_ids_tensor.shape[1]
        if input_token_count == 0: return "Input text tokenized to zero tokens."
        print(f"INFO: Input text has approx {len(text.split())} words and {input_token_count} tokens.")

        # 2. Calculate target token lengths
        min_len_tokens = int(input_token_count * min_length_perc)
        max_len_tokens = int(input_token_count * max_length_perc)

        # 3. Apply absolute limits and ensure min < max
        min_len_tokens = max(ABS_MIN_TOKEN_LEN, min_len_tokens)
        max_len_tokens = max(min_len_tokens + 10, max_len_tokens)
        max_len_tokens = min(ABS_MAX_TOKEN_LEN, max_len_tokens)
        min_len_tokens = min(min_len_tokens, max_len_tokens)
        print(f"INFO: Target summary token length: min={min_len_tokens}, max={max_len_tokens}.")

        # 4. Tokenize for model input. A single sequence needs no padding, and
        #    dropping padding="max_length" avoids wasted compute on pad tokens.
        inputs = tokenizer(text, max_length=1024, return_tensors="pt", truncation=True).to(DEVICE)

        # 5. Generate the summary using the calculated min/max token lengths.
        #    input_ids is passed as an explicit keyword argument because the
        #    PEFT wrapper's generate() does not accept it positionally.
        print("INFO: Starting model.generate()...")
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=num_beams,
            max_length=max_len_tokens,
            min_length=min_len_tokens,
            early_stopping=True
        )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except TypeError as te:  # Catch this specific error for clearer logging
        print(f"✘ TypeError during summary generation: {te}")
        print("✘ Hint: this usually means generate() received an argument incorrectly (keyword vs. positional).")
        import traceback
        traceback.print_exc()
        return "[TypeError during summary generation - check arguments]"
    except Exception as e:
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return "[Error generating summary]"


def extract_entities(ner_nlp, text):
    """Extracts named entities using the spaCy NER model."""
    if not text or text.isspace(): return []
    print("\nExtracting entities from original text using custom NER model...")
    try:
        doc = ner_nlp(text)
        entities = list({(ent.text.strip(), ent.label_) for ent in doc.ents if ent.text.strip()}) # Unique entities
        print(f"✔ Extracted {len(entities)} unique entities.")
        return entities
    except Exception as e:
        print(f"✘ Error during entity extraction: {e}")
        return []
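
# Illustrative call (hypothetical text and labels; the actual spans and labels
# depend entirely on how the custom NER model was trained):
#   extract_entities(ner_model_global, "சென்னையில் கூட்டம் நடந்தது.")
#   -> [("சென்னையில்", "LOC")]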

def create_prompted_input(text, entities):
    """Creates a new input string with unique entities prepended."""
    if not entities:
        print("INFO: No entities found by NER, using original text for prompted summary.")
        return text
    unique_entity_texts = sorted(list({ent[0] for ent in entities if ent[0]}))
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"\nINFO: Created prompted input with {len(unique_entity_texts)} unique entities (showing start): {prompted_text[:250]}...")
    return prompted_text
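
# Illustrative result (hypothetical entities): for entities
# [("மோடி", "PER"), ("சென்னை", "LOC")] the sorted unique texts yield
#   "சென்னை, மோடி. முக்கிய சொற்கள்: <original text>"
# so the entity list sits at the very start of the encoder input.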

# --- Main execution ---
def main():
    global models_loaded # Access the flag
    if not models_loaded:
         print("✘ Models failed to load during startup. Cannot proceed.")
         sys.exit(1)

    print("\n" + "="*50)
    print("Please paste the Tamil text paragraph you want to summarize below.")
    print("Press Enter when finished.")
    print("(You might need to configure your terminal for multi-line paste if it's long)")
    print("-" * 50)
    input_paragraph = input("Input Text:\n")

    if not input_paragraph or input_paragraph.isspace():
        print("\n✘ Error: No input text provided. Exiting.")
        sys.exit(1)
    text_to_process = input_paragraph.strip()

    print("\n" + "="*50)
    print("Processing Input Text (Snippet):")
    print(text_to_process[:300] + "...")
    print("="*50)

    # --- Generate Output 1: Standard Summary (using FINE-TUNED model) ---
    # Note: Even the "standard" summary now uses the fine-tuned model
    print("\n--- Output 1: Standard Abstractive Summary (Fine-tuned Model) ---")
    standard_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, text_to_process,
        num_beams=SUMM_NUM_BEAMS
    )
    print("\nStandard Summary:")
    print(standard_summary)
    print("-" * 50)

    # --- Generate Output 2: NER-Influenced Summary (using FINE-TUNED model) ---
    print("\n--- Output 2: NER-Influenced Abstractive Summary (Fine-tuned Model) ---")
    # a) Extract entities
    extracted_entities = extract_entities(ner_model_global, text_to_process)
    print("\nKey Entities Extracted by NER:")
    if extracted_entities:
        for text_ent, label in extracted_entities:
            print(f"  - '{text_ent}' ({label})")
    else:
        print("  No entities found by NER model.")

    # b) Create prompted input
    prompted_input_text = create_prompted_input(text_to_process, extracted_entities)

    # c) Generate summary from prompted input
    ner_influenced_summary = summarize_text(
        summ_tokenizer_global, summ_model_global, prompted_input_text,
        num_beams=SUMM_NUM_BEAMS
    )
    print("\nNER-Influenced Summary (Generated using entities as prefix):")
    print(ner_influenced_summary)
    print("\nNOTE: Compare this summary with the standard summary (Output 1).")
    print("Fine-tuning might make both summaries better reflect your data's style.")
    print("Prepending entities is still experimental for influencing inclusion.")
    print("="*50)


if __name__ == "__main__":
    # --- Load models globally when script starts ---
    print("Application starting up... Loading models...")
    # Load the NER model first; the summarizer is loaded only if NER succeeds
    ner_loaded_ok = load_ner_model(NER_MODEL_PATH)
    if ner_loaded_ok:
        # Proceed to load summarizer only if NER loaded
        summ_loaded_ok = load_finetuned_summarizer(BASE_SUMMARIZATION_MODEL_NAME, ADAPTER_PATH)
        models_loaded = summ_loaded_ok # Overall success depends on summarizer loading
    else:
        models_loaded = False # NER failed, cannot proceed

    if models_loaded:
        print("\n--- All models loaded successfully! Ready for input. ---")
        main()  # Run a single summarization pass
    else:
        print("\n✘✘✘ CRITICAL ERROR: Model loading failed. Exiting. Check logs above. ✘✘✘")
        sys.exit(1)