# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F


def enable_full_determinism(seed: int):
    """
    Helper function for reproducible behavior during distributed training. See
    https://pytorch.org/docs/stable/notes/randomness.html for details on PyTorch determinism.
    """
    # set seed first
    set_seed(seed)

    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_seed(seed: int):
    """
    Helper function for reproducible behavior: sets the seed in `random`, `numpy` and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available
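

# A minimal usage sketch: call one of the helpers above once at startup, before building models
# or dataloaders. `enable_full_determinism` is stricter than `set_seed` (it also forces
# deterministic CUDA/cuDNN kernels) and can slow training down.
#
#     set_seed(42)
#     # or, for bitwise-reproducible runs:
#     # enable_full_determinism(42)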


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMA:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average (upper bound on the value
                returned by `get_decay`).
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            model_cls (Optional[Any]): The model class, required by `from_pretrained` / `save_pretrained`.
            model_config (Optional[Dict[str, Any]]): The model config, required by `save_pretrained`.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMA":
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)
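
    # A minimal round-trip sketch (assumes `model_cls` is a diffusers-style model class exposing
    # `load_config` / `from_config` / `register_to_config` / `from_pretrained`, e.g. a ModelMixin
    # subclass; the directory name is a hypothetical placeholder):
    #
    #     ema = EMA(model.parameters(), model_cls=type(model), model_config=model.config)
    #     ...train and call ema.step(model.parameters())...
    #     ema.save_pretrained("ema_checkpoint")
    #     ema = EMA.from_pretrained("ema_checkpoint", model_cls=type(model))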

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value
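
    # A quick numeric check of the warmup schedule above (defaults inv_gamma=1, power=2/3),
    # matching @crowsonkb's notes in `__init__`:
    #
    #     step = 31_623    -> 1 - (1 + 31_623) ** (-2 / 3)    ~= 0.999
    #     step = 1_000_000 -> 1 - (1 + 1_000_000) ** (-2 / 3) ~= 0.9999
    #
    # Without warmup, `(1 + step) / (10 + step)` ramps quickly toward 1 and is then clamped to `decay`.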

    def step(self, parameters: Iterable[torch.nn.Parameter]):
        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()
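
    # A minimal training-loop sketch (the `model`, `optimizer` and `dataloader` names are
    # hypothetical placeholders):
    #
    #     ema = EMA(model.parameters())
    #     for batch in dataloader:
    #         loss = model(**batch).loss
    #         loss.backward()
    #         optimizer.step()
    #         optimizer.zero_grad()
    #         ema.step(model.parameters())  # update shadow weights after each optimizer step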

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like the `device` argument to `torch.Tensor.to`
            dtype: like the `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before calling `copy_to()`.
        After validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None
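
    # A minimal evaluation sketch using `store` / `copy_to` / `restore` (the `model` and
    # `evaluate` names are hypothetical placeholders):
    #
    #     ema.store(model.parameters())    # stash the raw training weights
    #     ema.copy_to(model.parameters())  # swap in the EMA weights
    #     evaluate(model)                  # or model.save_pretrained(...)
    #     ema.restore(model.parameters())  # put the training weights back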

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load
        the ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")


# calculates entropy over each pixel distribution
def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    # only calculate entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    num_masked_pixels = masked_tokens.sum(-1)

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    entropy_per_pixel = -((probs * log_probs).sum(-1))

    # the predictions for non-masked pixels aren't used, so set their entropies to zero
    entropy_per_pixel[~masked_tokens] = 0

    entropy_per_image_numerator = entropy_per_pixel.sum(-1)
    entropy_per_image = entropy_per_image_numerator / num_masked_pixels

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)

    return entropy_by_masked_bucket


# calculates entropy over the averaged distribution of pixels for the whole image
def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
    # only calculate entropy over image tokens that were masked in the original image
    masked_tokens = input_ids == mask_id
    num_masked_pixels = masked_tokens.sum(-1, keepdim=True)

    pixel_probs = F.softmax(logits, dim=-1)
    pixel_probs[~masked_tokens] = 0
    image_probs_numerator = pixel_probs.sum(-2)
    image_probs = image_probs_numerator / num_masked_pixels

    image_log_probs = image_probs.log()

    entropy_per_image = -((image_probs * image_log_probs).sum(-1))

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)

    return entropy_by_masked_bucket


def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
    cross_entropy_per_token = F.cross_entropy(
        logits.view(-1, output_size),
        labels.view(-1),
        ignore_index=-100,
        label_smoothing=label_smoothing,
        reduction="none",
    ).view(input_ids.shape[0], -1)

    # reduce the per-token loss to a single value per image, averaged over the masked tokens,
    # so that it lines up with the per-image bucket indices computed below
    masked_tokens = input_ids == mask_id
    cross_entropy_per_image = (cross_entropy_per_token * masked_tokens).sum(-1) / masked_tokens.sum(-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)

    return cross_entropy_by_percent_masked_bucket


def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
    probs = F.softmax(logits, dim=-1)

    total_buckets = 10
    masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)

    data = []

    for bucket_idx in range(total_buckets):
        # batch indices of the images whose masked percent falls in this bucket
        indices_for_bucket = (masked_buckets == bucket_idx).nonzero(as_tuple=True)[0]

        # It's ok if none were noised in the range of this bucket. This
        # function will be called for a later training step where it's likely
        # there will be an element noised in the range.
        if indices_for_bucket.shape[0] == 0:
            continue

        index_for_bucket = indices_for_bucket[0]

        image_probs = probs[index_for_bucket]

        # find the index of a masked pixel for the image
        input_ids_for_image = input_ids[index_for_bucket]
        masked_pixels_probs = image_probs[input_ids_for_image == mask_id]

        masked_pixel_probs = masked_pixels_probs[0]

        masked_pixel_probs = masked_pixel_probs.cpu().numpy()

        for masked_pixel_prob in masked_pixel_probs:
            data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})

    df = pd.DataFrame(data)

    return df


def average_by_buckets(values, masked_buckets, total_buckets):
    unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)

    numerator = torch.zeros(total_buckets, device=values.device)
    numerator.scatter_add_(0, masked_buckets, values)

    # default value is one because the buckets for which there aren't
    # any values will have a numerator of zero. So we just need to not divide
    # by zero.
    denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
    denominator[unique_buckets] = bucket_counts

    averaged_by_buckets = numerator / denominator

    return averaged_by_buckets
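

# A small worked example of `average_by_buckets` (hypothetical values):
#
#     values         = torch.tensor([2.0, 4.0, 6.0])
#     masked_buckets = torch.tensor([1, 1, 3])
#     average_by_buckets(values, masked_buckets, total_buckets=10)
#     # -> bucket 1 holds (2.0 + 4.0) / 2 = 3.0, bucket 3 holds 6.0, every other bucket is 0.0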


def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
    assert total_buckets == 10

    masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]

    # we do not formally use timesteps to noise images. Instead, we mask a percent
    # of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
    # and we also want to track how the entropy evolves over time w/in a range of mask
    # percents that should have similar entropy. So we bucket the masked percents into a
    # fixed number of buckets.

    # we could generalize this later if needed but for now, let's just assume a fixed
    # number of 10 buckets.

    # How this maps to a bucket index:
    #
    #     (mask_0) * bucket_index_0
    #   + (mask_1) * bucket_index_1
    #   + ...
    #
    # -> Where a mask is true, its term contributes the corresponding bucket index;
    #    where a mask is false, its term contributes 0.
    #
    # Since the masked percents are between 0 and 1, each masked_percent is matched by
    # one and only one of the masks, so the sum is exactly that bucket's index.
    masked_buckets = (
        ((0 < masked_percent) & (masked_percent <= 0.1)) * 0
        + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
        + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
        + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
        + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
        + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
        + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
        + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
        + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
        + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
    )

    return masked_buckets
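

# A small worked example of the bucketing above (hypothetical batch of two images, four tokens
# each, with `mask_id = 999`):
#
#     input_ids = torch.tensor([[999, 5, 7, 9], [999, 999, 999, 4]])
#     input_ids_to_masked_buckets(input_ids, mask_id=999)
#     # masked percents are 0.25 and 0.75 -> bucket indices tensor([2, 7])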