# -*- coding: utf-8 -*-
# --- Prerequisites ---
# Ensure these are in your requirements.txt for Hugging Face Spaces:
# spacy==3.5.0 # or the version used to train the NER model
# streamlit>=1.0.0
# transformers>=4.20.0
# torch>=1.10.0 # Or tensorflow
# sentencepiece>=0.1.90
# protobuf==3.20.3
# peft>=0.5.0 # Parameter-Efficient Fine-Tuning library
# accelerate>=0.26.0
# numpy
# nltk # For ROUGE metric calculation during fine-tuning (needed only if postprocess_text is kept)
# bitsandbytes # If using 8-bit optimizer
import streamlit as st
import spacy
from pathlib import Path
import torch
import warnings
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from peft import PeftModel
    import nltk
    nltk.download('punkt', quiet=True)  # Ensure the punkt tokenizer is available for any NLTK use
    print("✔ Successfully imported core libraries.")
except ImportError as e:
    # Show the error in the Streamlit app itself if imports fail at runtime
    st.error(f"Error importing libraries: {e}. Please check requirements.txt and ensure all packages are installed.")
    st.stop()  # Stop execution if libraries are missing
# --- Configuration ---
# Use paths relative to this app.py script
NER_MODEL_PATH = Path("./training_400") # Assumes the trained spaCy model folder (training_400) is at the repo root
BASE_SUMMARIZATION_MODEL = "csebuetnlp/mT5_multilingual_XLSum"
ADAPTER_PATH = Path("./mt5_finetuned_tamil_summary") # Path to your fine-tuned adapters
# Device Selection
DEVICE = "cpu" # Default to CPU for broader compatibility on free tiers
if torch.cuda.is_available():
    print("INFO: CUDA device detected. Setting DEVICE to 'cuda'.")
    DEVICE = "cuda"
else:
    print("INFO: No CUDA device detected. Using CPU.")
# Summarization parameters
SUMM_NUM_BEAMS = 4
MIN_LEN_PERC = 0.30
MAX_LEN_PERC = 0.70
ABS_MIN_TOKEN_LEN = 30
ABS_MAX_TOKEN_LEN = 512
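# Worked example of the length heuristic in summarize_text_internal below,
# for a hypothetical 400-token input:
#   min = max(ABS_MIN_TOKEN_LEN, int(400 * 0.30)) = max(30, 120) = 120 tokens
#   max = min(ABS_MAX_TOKEN_LEN, max(120 + 10, int(400 * 0.70))) = min(512, 280) = 280 tokens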
# --- End Configuration ---
# --- Suppress Warnings ---
warnings.filterwarnings("ignore", message="CUDA path could not be detected*")
warnings.filterwarnings("ignore", message=".*You are using `torch.load` with `weights_only=False`.*")
warnings.filterwarnings("ignore", message=".*The sentencepiece tokenizer that you are converting.*")
# --- Global Variables & Model Loading Control ---
ner_model_global = None
summ_tokenizer_global = None
summ_model_global = None
models_loaded_status = "Not Loaded"  # Human-readable load status shown while models initialize
# --- Model Loading with Streamlit Caching ---
@st.cache_resource  # Cached once per process and shared across sessions/reruns
def load_ner_model_cached(path):
    """Loads the spaCy NER model."""
    global models_loaded_status
    models_loaded_status = f"Loading NER model from: {path}..."
    st.info(models_loaded_status)
    if not path.exists():
        st.error(f"NER Model directory not found at {path.resolve()}")
        models_loaded_status = "Error: NER Model Not Found"
        return None
    try:
        nlp = spacy.load(path)
        # Add a sentencizer if no sentence-boundary component is present
        # (crucial if sentence splitting is needed downstream)
        if not nlp.has_pipe("sentencizer") and not nlp.has_pipe("parser"):
            if "ner" in nlp.pipe_names:
                nlp.add_pipe("sentencizer", before="ner")
            elif "tok2vec" in nlp.pipe_names:
                nlp.add_pipe("sentencizer", before="tok2vec")
            else:
                nlp.add_pipe("sentencizer", first=True)
            print("INFO: Added 'sentencizer' to NER pipeline.")
        print(f"✔ NER model loaded from: {path}")
        return nlp
    except Exception as e:
        st.error(f"Error loading NER model: {e}")
        models_loaded_status = f"Error Loading NER Model: {e}"
        return None
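# Minimal offline sanity check for the NER model (a sketch; the sample sentence
# and any entities it yields are illustrative, not guaranteed outputs):
# nlp = spacy.load(NER_MODEL_PATH)
# print(nlp.pipe_names)  # e.g. ['tok2vec', 'ner', 'sentencizer']
# print([(ent.text, ent.label_) for ent in nlp("சென்னையில் இன்று கனமழை பெய்தது.").ents])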
@st.cache_resource  # Cached once per process and shared across sessions/reruns
def load_summarizer_cached(base_model_name, adapter_path, device):
    """Loads the Hugging Face base model and applies the PEFT adapter."""
    global models_loaded_status
    models_loaded_status = f"Loading Summarizer (Base: {base_model_name}, Adapter: {adapter_path})..."
    st.info(models_loaded_status)
    try:
        print(f"┣ Loading base tokenizer: {base_model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        print(f"┣ Loading base model: {base_model_name}...")
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        print(f"┣ Loading PEFT adapter from: {adapter_path}...")
        if not adapter_path.exists():
            st.warning(f"PEFT adapter directory not found at {adapter_path.resolve()}. Falling back to the BASE model only.")
            model = base_model  # Fall back to the base model
        else:
            model = PeftModel.from_pretrained(base_model, adapter_path)
            print("✔ Successfully loaded PEFT adapter.")
        print(f"┣ Moving summarization model to {device}...")
        model.to(device)
        model.eval()  # Evaluation mode: disables dropout for inference
        print(f"✔ Summarization model loaded on {device}.")
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading summarization model: {e}")
        print(f"✘ FATAL: Error loading summarization model: {e}")
        import traceback
        traceback.print_exc()
        models_loaded_status = f"Error Loading Summarizer: {e}"
        return None, None
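# Optional speed-up (an untested assumption for this app): if the adapter was
# loaded, PEFT's merge_and_unload() can fold the LoRA weights into the base
# model so inference skips the adapter indirection:
# model = model.merge_and_unload()  # only valid on a PeftModel instance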
# --- Helper Functions ---
def summarize_text_internal(tokenizer, model, text, device, num_beams=SUMM_NUM_BEAMS,
                            min_length_perc=MIN_LEN_PERC, max_length_perc=MAX_LEN_PERC):
    """Internal helper to generate a summary."""
    if not text or text.isspace(): return "[Error: Input text is empty]"
    # Ensure models are loaded before proceeding
    if not tokenizer or not model: return "[Error: Summarization model not ready]"
    print("INFO: Generating summary (percentage lengths)...")
    try:
        # Count the input tokens to derive the target summary length bounds
        input_ids = tokenizer(text, return_tensors="pt", truncation=False, padding=False).input_ids
        input_token_count = input_ids.shape[1]
        if input_token_count == 0: return "[Error: Input tokenized to zero tokens]"
        min_len_tokens = max(ABS_MIN_TOKEN_LEN, int(input_token_count * min_length_perc))
        max_len_tokens = max(min_len_tokens + 10, int(input_token_count * max_length_perc))
        max_len_tokens = min(ABS_MAX_TOKEN_LEN, max_len_tokens)
        min_len_tokens = min(min_len_tokens, max_len_tokens)
        print(f"INFO: Target summary tokens: min={min_len_tokens}, max={max_len_tokens}")
        # Tokenize for the model input; no padding is needed for a single example
        inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
        # Generate
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                num_beams=num_beams,
                max_length=max_len_tokens,
                min_length=min_len_tokens,
                early_stopping=True
            )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print("✔ Summary generation complete.")
        return summary
    except Exception as e:
        st.error(f"Error during summary generation: {e}")
        print(f"✘ Error during summary generation: {e}")
        import traceback
        traceback.print_exc()
        return f"[Error generating summary: {e}]"
def extract_entities_internal(ner_nlp, text):
    """Extracts entities and formats them as a markdown string."""
    if not text or text.isspace(): return [], "- No input text -"
    if ner_nlp is None: return [], "[Error: NER model not loaded]"
    print("INFO: Extracting entities...")
    try:
        doc = ner_nlp(text)
        entities = list({(ent.text.strip(), ent.label_) for ent in doc.ents if ent.text.strip()})
        print(f"✔ Extracted {len(entities)} unique entities.")
        if entities:
            # Format as a markdown list, sorted by entity label
            entity_list_str = "\n".join([f"- **{lbl}:** {txt}" for txt, lbl in sorted(entities, key=lambda x: x[1])])
        else:
            entity_list_str = "(No entities found by NER model)"
        return entities, entity_list_str
    except Exception as e:
        st.error(f"Error during entity extraction: {e}")
        print(f"✘ Error during entity extraction: {e}")
        return [], "[Error extracting entities]"
def create_prompted_input_internal(text, entities):
    """Creates the model input string with unique entity texts prepended."""
    if not entities: return text
    if not isinstance(text, str): return "[Invalid Input Text]"
    unique_entity_texts = sorted(list({ent[0] for ent in entities if ent[0]}))
    entity_string = ", ".join(unique_entity_texts)
    separator = ". முக்கிய சொற்கள்: "  # Tamil for "key words:"; marks the prepended entity list
    prompted_text = f"{entity_string}{separator}{text}"
    print(f"INFO: Created prompted input with {len(unique_entity_texts)} unique entities.")
    return prompted_text
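# Example prompted input for one hypothetical entity [("சென்னை", "LOC")]:
#   "சென்னை. முக்கிய சொற்கள்: <original article text>"
# ("முக்கிய சொற்கள்" is Tamil for "key words"; the prepended entities act as a soft prompt.)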
# --- Streamlit App Layout ---
st.set_page_config(layout="wide", page_title="Tamil NER Summarizer", page_icon="✍️")
st.title("தமிழ் செய்தி சுருக்கம் மற்றும் NER ஒருங்கிணைப்பு")
st.markdown("*(Tamil News Summarization with NER Integration)*")
st.markdown("---")
# --- Load Models ---
# Trigger loading models using the cached functions
# Assign to global variables if loading is successful
ner_model_global = load_ner_model_cached(NER_MODEL_PATH)
summ_tokenizer_global, summ_model_global = load_summarizer_cached(BASE_SUMMARIZATION_MODEL, ADAPTER_PATH, DEVICE)
# Check if models loaded successfully before proceeding
models_ready = ner_model_global is not None and summ_tokenizer_global is not None and summ_model_global is not None
if not models_ready:
    st.error("One or more essential models failed to load. Please check the application logs (terminal/HF Spaces logs) for details. The app cannot function.")
    st.stop()  # Stop the app if models aren't ready
else:
    st.sidebar.success(f"Models loaded successfully on {DEVICE.upper()}!")
    st.sidebar.markdown(f"**NER Model:** `{NER_MODEL_PATH.name}`")
    st.sidebar.markdown(f"**Summarizer:** `{BASE_SUMMARIZATION_MODEL}` + Adapter")
# --- Input Area ---
st.header(" உள்ளீடு / Input")
input_text = st.text_area("உங்கள் தமிழ் உரையை இங்கே ஒட்டவும் (Paste your Tamil text here):", height=300, key="input_text_area")
# --- Processing Trigger ---
if st.button("சுருக்கம் & NER ஐ உருவாக்குக (Generate Summary & NER)", key="generate_button"):
    if input_text and not input_text.isspace():
        text_to_process = input_text.strip()
        st.markdown("---")
        st.header("முடிவுகள் / Results")
        # Use columns for the final output
        col1, col2 = st.columns(2)
        # --- Column 1: NER Entities ---
        with col1:
            st.subheader("முக்கிய சொற்கள் (NER Entities)")
            with st.spinner("Extracting entities..."):
                extracted_entities_raw, entities_display_string = extract_entities_internal(ner_model_global, text_to_process)
            # Display entities as markdown so they are easy to copy
            st.markdown(entities_display_string)
        # --- Column 2: NER-Influenced Summary ---
        with col2:
            st.subheader("NER-உடன் செல்வாக்கு பெற்ற சுருக்கம்")
            st.markdown("*(NER-Influenced Summary)*")
            with st.spinner(f"Generating summary on {DEVICE}... (this may take a while)"):
                # Create the prompted input using the extracted entities
                prompted_input_text = create_prompted_input_internal(text_to_process, extracted_entities_raw)
                # Generate the summary
                ner_influenced_summary = summarize_text_internal(
                    summ_tokenizer_global, summ_model_global, prompted_input_text, DEVICE
                )
            # Display the summary as markdown so it is easy to copy
            st.markdown(ner_influenced_summary)
            st.caption("Summary generated by the fine-tuned model with NER entities prepended to the input.")
        st.success("Processing complete!")
    else:
        # Covers empty as well as whitespace-only input
        st.warning("Please enter some text into the input area.")
# To show a hint before the button is pressed, an else branch could display:
# st.info("Click the button to generate summaries and extract entities.")
st.markdown("---")
st.caption("Developed using Streamlit, spaCy, and Hugging Face Transformers/PEFT.")