from collections import defaultdict | |
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | |
from datasets import Dataset | |
from transformers.trainer_callback import TrainerCallback | |
from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel | |
from setfit.training_args import TrainingArguments | |
from .. import logging | |
from ..trainer import ColumnMappingMixin, Trainer | |
# Imported for type checking only, to avoid a hard runtime dependency on optuna.
if TYPE_CHECKING:
    import optuna

# Module-level logger following the package's logging convention.
logger = logging.get_logger(__name__)
class AbsaTrainer(ColumnMappingMixin):
    """Trainer to train a SetFit ABSA model.

    Args:
        model (`AbsaModel`):
            The AbsaModel model to train.
        args (`TrainingArguments`, *optional*):
            The training arguments to use. If `polarity_args` is not defined, then `args` is used for both
            the aspect and the polarity model.
        polarity_args (`TrainingArguments`, *optional*):
            The training arguments to use for the polarity model. If not defined, `args` is used for both
            the aspect and the polarity model.
        train_dataset (`Dataset`):
            The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
        eval_dataset (`Dataset`, *optional*):
            The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
            The metric to use for evaluation. If a string is provided, we treat it as the metric
            name and load it with default settings.
            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
        metric_kwargs (`Dict[str, Any]`, *optional*):
            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
        callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        column_mapping (`Dict[str, str]`, *optional*):
            A mapping from the column names in the dataset to the column names expected by the model.
            The expected format is a dictionary with the following format:
            `{"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}`.
    """

    _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}

    def __init__(
        self,
        model: AbsaModel,
        args: Optional[TrainingArguments] = None,
        polarity_args: Optional[TrainingArguments] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        metric_kwargs: Optional[Dict[str, Any]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        column_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        self.model = model
        self.aspect_extractor = model.aspect_extractor

        if train_dataset is not None and column_mapping:
            train_dataset = self._apply_column_mapping(train_dataset, column_mapping)
        # preprocess_dataset tolerates None and then returns (None, None).
        aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset(
            model.aspect_model, model.polarity_model, train_dataset
        )
        if eval_dataset is not None and column_mapping:
            eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping)
        aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
            model.aspect_model, model.polarity_model, eval_dataset
        )

        # One sub-trainer per sub-model; the polarity trainer falls back to `args`
        # when no dedicated `polarity_args` was provided.
        self.aspect_trainer = Trainer(
            model.aspect_model,
            args=args,
            train_dataset=aspect_train_dataset,
            eval_dataset=aspect_eval_dataset,
            metric=metric,
            metric_kwargs=metric_kwargs,
            callbacks=callbacks,
        )
        # Prefix the shared loss log keys so aspect and polarity logs do not collide.
        self.aspect_trainer._set_logs_mapper(
            {
                "eval_embedding_loss": "eval_aspect_embedding_loss",
                "embedding_loss": "aspect_embedding_loss",
            }
        )
        self.polarity_trainer = Trainer(
            model.polarity_model,
            args=polarity_args or args,
            train_dataset=polarity_train_dataset,
            eval_dataset=polarity_eval_dataset,
            metric=metric,
            metric_kwargs=metric_kwargs,
            callbacks=callbacks,
        )
        self.polarity_trainer._set_logs_mapper(
            {
                "eval_embedding_loss": "eval_polarity_embedding_loss",
                "embedding_loss": "polarity_embedding_loss",
            }
        )

    def preprocess_dataset(
        self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Optional[Dataset]
    ) -> Tuple[Optional[Dataset], Optional[Dataset]]:
        """Convert an ABSA dataset into an (aspect, polarity) pair of SetFit text datasets.

        Args:
            aspect_model (`AspectModel`): Used to prepend aspect spans to the aspect texts.
            polarity_model (`PolarityModel`): Used to prepend aspect spans to the polarity texts.
            dataset (`Dataset`, *optional*): A dataset with "text", "span", "label" and "ordinal" columns.

        Returns:
            `Tuple[Optional[Dataset], Optional[Dataset]]`: The aspect and polarity datasets, each with
            "text" and "label" columns, or `(None, None)` if `dataset` is `None`.
        """
        if dataset is None:
            return dataset, dataset

        # Group the annotations by their "text" column so each sentence is processed once.
        grouped_data = defaultdict(list)
        for sample in dataset:
            text = sample.pop("text")
            grouped_data[text].append(sample)

        def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]:
            """Return character (start, end) of the `ordinal`-th occurrence of `target` in `text`.

            Raises `ValueError` (from `str.index`) when `target` occurs fewer than `ordinal + 1` times.
            """
            find_from = 0
            for _ in range(ordinal + 1):
                start_idx = text.index(target, find_from)
                find_from = start_idx + 1
            return start_idx, start_idx + len(target)

        def overlaps(aspect: slice, aspects: List[slice]) -> bool:
            """Whether `aspect` shares at least one index (bounds inclusive) with any span in `aspects`."""
            # Closed-interval intersection test; equivalent to the range-set intersection
            # but without materializing the index sets.
            return any(
                aspect.start <= test_aspect.stop and test_aspect.start <= aspect.stop
                for test_aspect in aspects
            )

        docs, aspects_list = self.aspect_extractor(grouped_data.keys())

        aspect_aspect_list = []
        aspect_labels = []
        polarity_aspect_list = []
        polarity_labels = []
        for doc, aspects, text in zip(docs, aspects_list, grouped_data):
            # Collect all of the gold aspects
            gold_aspects = []
            gold_polarity_labels = []
            for annotation in grouped_data[text]:
                try:
                    start, end = index_ordinal(text, annotation["span"], annotation["ordinal"])
                except ValueError:
                    logger.info(
                        f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. "
                        "Skipping this sample."
                    )
                    continue

                # char_span returns None when (start, end) does not align with token boundaries.
                gold_aspect_span = doc.char_span(start, end)
                if gold_aspect_span is None:
                    continue
                gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end))
                gold_polarity_labels.append(annotation["label"])

            # The Aspect model uses all gold aspects as "True", and all non-overlapping predicted
            # aspects as "False"
            aspect_labels.extend([True] * len(gold_aspects))
            aspect_aspect_list.append(gold_aspects[:])
            for aspect in aspects:
                if not overlaps(aspect, gold_aspects):
                    aspect_labels.append(False)
                    aspect_aspect_list[-1].append(aspect)

            # The Polarity model uses only the gold aspects and labels
            polarity_labels.extend(gold_polarity_labels)
            polarity_aspect_list.append(gold_aspects)

        aspect_texts = list(aspect_model.prepend_aspects(docs, aspect_aspect_list))
        polarity_texts = list(polarity_model.prepend_aspects(docs, polarity_aspect_list))
        return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict(
            {"text": polarity_texts, "label": polarity_labels}
        )

    def train(
        self,
        args: Optional[TrainingArguments] = None,
        polarity_args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        Main training entry point.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the aspect training arguments for this training call.
            polarity_args (`TrainingArguments`, *optional*):
                Temporarily change the polarity training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        self.train_aspect(args=args, trial=trial, **kwargs)
        self.train_polarity(args=polarity_args, trial=trial, **kwargs)

    def train_aspect(
        self,
        args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        Train the aspect model only.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the aspect training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        self.aspect_trainer.train(args=args, trial=trial, **kwargs)

    def train_polarity(
        self,
        args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        Train the polarity model only.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the polarity training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        self.polarity_trainer.train(args=args, trial=trial, **kwargs)

    def add_callback(self, callback: Union[type, TrainerCallback]) -> None:
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will instantiate a member of that class.
        """
        self.aspect_trainer.add_callback(callback)
        self.polarity_trainer.add_callback(callback)

    def pop_callback(self, callback: Union[type, TrainerCallback]) -> Tuple[TrainerCallback, TrainerCallback]:
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            `Tuple[`[`~transformers.TrainerCallback`], [`~transformers.TrainerCallback`]`]`: The callbacks removed from the
            aspect and polarity trainers, if found.
        """
        return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback)

    def remove_callback(self, callback: Union[type, TrainerCallback]) -> None:
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will remove the first member of that class found in the list of callbacks.
        """
        self.aspect_trainer.remove_callback(callback)
        self.polarity_trainer.remove_callback(callback)

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Upload model checkpoint to the Hub using `huggingface_hub`.

        See the full list of parameters for your `huggingface_hub` version in the\
        [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).

        Args:
            repo_id (`str`):
                The full repository ID to push the aspect model to, e.g. `"tomaarsen/setfit-aspect"`.
            polarity_repo_id (`str`, *optional*):
                The full repository ID to push the polarity model to, e.g. `"tomaarsen/setfit-polarity"`.
            config (`dict`, *optional*):
                Configuration object to be saved alongside the model weights.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            api_endpoint (`str`, *optional*):
                The API endpoint to use when pushing the model to the hub.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files.
                If not set, will use the token set when logging in with
                `transformers-cli login` (stored in `~/.huggingface`).
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to
                the default branch as specified in your repository, which
                defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit.
                Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
        """
        return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs)

    def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]:
        """
        Computes the metrics for a given classifier.

        Args:
            dataset (`Dataset`, *optional*):
                The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via
                the `eval_dataset` argument at `Trainer` initialization.

        Returns:
            `Dict[str, Dict[str, float]]`: The evaluation metrics.
        """
        aspect_eval_dataset = polarity_eval_dataset = None
        # `is not None` rather than truthiness: an empty Dataset is falsy and would
        # otherwise silently fall back to the init-time eval dataset.
        if dataset is not None:
            aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset(
                self.model.aspect_model, self.model.polarity_model, dataset
            )
        return {
            "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset),
            "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset),
        }