import copy |
|
import inspect |
|
import os |
|
import warnings |
|
from contextlib import nullcontext |
|
from typing import Optional, Tuple |
|
|
|
import accelerate |
|
import torch |
|
from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
|
from accelerate.utils import is_npu_available, is_xpu_available |
|
from huggingface_hub import file_exists |
|
from huggingface_hub.utils import EntryNotFoundError, HFValidationError |
|
from packaging import version |
|
from safetensors.torch import storage_ptr, storage_size |
|
|
|
from ..import_utils import is_auto_gptq_available, is_torch_tpu_available |
|
from .constants import ( |
|
CONFIG_NAME, |
|
EMBEDDING_LAYER_NAMES, |
|
INCLUDE_LINEAR_LAYERS_SHORTHAND, |
|
SAFETENSORS_WEIGHTS_NAME, |
|
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, |
|
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, |
|
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, |
|
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, |
|
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, |
|
WEIGHTS_NAME, |
|
bloom_model_postprocess_past_key_value, |
|
starcoder_model_postprocess_past_key_value, |
|
) |
|
|
|
|
|
mlu_available = False |
|
# only import `is_mlu_available` on accelerate versions that provide it (>= 0.29.0)
if version.parse(accelerate.__version__) >= version.parse("0.29.0"):
|
from accelerate.utils import is_mlu_available |
|
|
|
mlu_available = is_mlu_available() |
|
|
|
|
|
__all__ = [ |
|
"CONFIG_NAME", |
|
"EMBEDDING_LAYER_NAMES", |
|
"SAFETENSORS_WEIGHTS_NAME", |
|
"TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", |
|
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", |
|
"WEIGHTS_NAME", |
|
"INCLUDE_LINEAR_LAYERS_SHORTHAND", |
|
"bloom_model_postprocess_past_key_value", |
|
"starcoder_model_postprocess_past_key_value", |
|
] |
|
|
|
|
|
|
|
def infer_device() -> str: |
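    """Return the best available torch device as a string: "cuda", "mps", "mlu", "xpu", "npu", or "cpu"."""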
|
if torch.cuda.is_available(): |
|
return "cuda" |
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
|
return "mps" |
|
elif mlu_available: |
|
return "mlu" |
|
elif is_xpu_available(): |
|
return "xpu" |
|
elif is_npu_available(): |
|
return "npu" |
|
return "cpu" |
|
|
|
|
|
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): |
|
r""" |
|
    Note this method only works for `transformers` models.

    This method wraps the entire protocol for preparing a model before running training. This includes:
        1- casting the layernorm layers to fp32
        2- making the output embedding layer require grads
        3- upcasting the lm head to fp32

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
        use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
            Keyword arguments to pass to the gradient checkpointing function. Please refer to the documentation of
            `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that
            method. Note this is only available in transformers versions > 4.34.1.
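
    Example (a minimal sketch; the checkpoint name and the 4-bit quantization config are just illustrative):

    ```py
    >>> from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    >>> model = AutoModelForCausalLM.from_pretrained(
    ...     "facebook/opt-350m", quantization_config=BitsAndBytesConfig(load_in_4bit=True)
    ... )
    >>> model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    ```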
|
""" |
|
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) |
|
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" |
|
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" |
|
is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" |
|
is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False) |
|
|
|
if gradient_checkpointing_kwargs is None: |
|
gradient_checkpointing_kwargs = {} |
|
|
|
    # freeze all of the base model's parameters
    for name, param in model.named_parameters():
        param.requires_grad = False
|
|
|
    if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized and not is_hqq_quantized:
        # cast all non-quantized fp16/bf16 parameters to fp32 for training stability
        for param in model.parameters():
            if (
                (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
            ) and param.__class__.__name__ != "Params4bit":
                param.data = param.data.to(torch.float32)
|
|
|
    if (
        loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized or is_hqq_quantized
    ) and use_gradient_checkpointing:
        # when using `use_reentrant=False` with gradient checkpointing, the input-grad hack below is not needed
        if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:
                # fallback for older transformers versions that lack `enable_input_require_grads`

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
|
|
|
|
        # to support older transformers versions, check if the model supports gradient_checkpointing_kwargs
        _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
            inspect.signature(model.gradient_checkpointing_enable).parameters
        )
|
|
|
if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: |
|
            warnings.warn(
                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
                " If you want to use that feature, please upgrade to the latest version of transformers.",
                FutureWarning,
            )
|
|
|
        gc_enable_kwargs = (
            {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
        )

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable(**gc_enable_kwargs)
|
return model |
|
|
|
|
|
|
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): |
|
""" |
|
Shift input ids one token to the right. |
|
|
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids |
|
pad_token_id (`int`): The id of the `padding` token. |
|
decoder_start_token_id (`int`): The id of the `start` token. |
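
    For illustration, with `decoder_start_token_id=2` the first position becomes 2 and the remaining positions hold
    the input shifted right by one:

    ```py
    >>> input_ids = torch.tensor([[5, 6, 7]])
    >>> shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=2)
    tensor([[2, 5, 6]])
    ```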
|
""" |
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape) |
|
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() |
|
shifted_input_ids[:, 0] = decoder_start_token_id |
|
|
|
if pad_token_id is None: |
|
raise ValueError("self.model.config.pad_token_id has to be defined.") |
|
|
|
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) |
|
|
|
return shifted_input_ids |
|
|
|
|
|
class ModulesToSaveWrapper(torch.nn.Module): |
|
def __init__(self, module_to_save, adapter_name): |
|
super().__init__() |
|
self.original_module = module_to_save |
|
self.modules_to_save = torch.nn.ModuleDict({}) |
|
self._active_adapter = adapter_name |
|
self._disable_adapters = False |
|
self.update(adapter_name) |
|
self.check_module() |
|
|
|
    def check_module(self):
        """Perform some sanity checks on the module to ensure that it works"""
        # ModuleDict, ModuleList, ParameterDict and ParameterList cannot be wrapped: they have no forward method of
        # their own, so calling the wrapper (and thus the deep copy) would not behave like the original container.
        forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList)
|
if isinstance(self.original_module, forbidden_classes): |
|
cls_name = self.original_module.__class__.__name__ |
|
raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}") |
|
|
|
    @property
    def disable_adapters(self) -> bool:
        # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method
        return self._disable_adapters
|
|
|
    @property
    def active_adapter(self) -> str:
        # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method
        return self._active_adapter
|
|
|
@property |
|
def weight(self): |
|
if self.active_adapter not in self.modules_to_save: |
|
return self.original_module.weight |
|
return self.modules_to_save[self.active_adapter].weight |
|
|
|
def update(self, adapter_name): |
|
context_manager = nullcontext() |
|
for _, param in self.original_module.named_parameters(): |
|
            num_params = param.numel()
            # if using DeepSpeed ZeRO-3, the weights are sharded and appear empty (numel() == 0); gather them first
            if num_params == 0 and hasattr(param, "ds_numel"):
|
import deepspeed |
|
|
|
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0) |
|
break |
|
with context_manager: |
|
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) |
|
|
|
        # if the module carries an accelerate hook (e.g. for offloading), re-create an equivalent hook on the copy
        if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
|
old_hook = self.modules_to_save[adapter_name]._hf_hook |
|
new_hook = self._create_new_hook(old_hook) |
|
remove_hook_from_module(self.modules_to_save[adapter_name]) |
|
add_hook_to_module(self.modules_to_save[adapter_name], new_hook) |
|
|
|
self.original_module.requires_grad_(False) |
|
if adapter_name == self.active_adapter: |
|
self.modules_to_save[adapter_name].requires_grad_(True) |
|
|
|
def _create_new_hook(self, old_hook): |
|
r""" |
|
        Creates a new hook based on the old hook. Use it only if you know what you are doing!
|
""" |
|
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) |
|
old_hook_attr = old_hook.__dict__ |
|
filtered_old_hook_attr = {} |
|
old_hook_init_signature = inspect.signature(old_hook_cls.__init__) |
|
for k in old_hook_attr.keys(): |
|
if k in old_hook_init_signature.parameters: |
|
filtered_old_hook_attr[k] = old_hook_attr[k] |
|
new_hook = old_hook_cls(**filtered_old_hook_attr) |
|
return new_hook |
|
|
|
def forward(self, *args, **kwargs): |
|
if self.disable_adapters or (self.active_adapter not in self.modules_to_save): |
|
return self.original_module(*args, **kwargs) |
|
return self.modules_to_save[self.active_adapter](*args, **kwargs) |
|
|
|
def enable_adapters(self, enabled: bool): |
|
"""Toggle the enabling and disabling of adapters |
|
|
|
Takes care of setting the requires_grad flag for the adapter weights. |
|
|
|
Args: |
|
enabled (bool): True to enable adapters, False to disable adapters |
|
""" |
|
        if self._disable_adapters is not enabled:
            # already in the requested state, nothing to do
            return
|
|
|
if enabled: |
|
self.original_module.requires_grad_(False) |
|
self.modules_to_save[self.active_adapter].requires_grad_(True) |
|
self._disable_adapters = False |
|
else: |
|
self.original_module.requires_grad_(True) |
|
self.modules_to_save.requires_grad_(False) |
|
self._disable_adapters = True |
|
|
|
def set_adapter(self, adapter_name: str): |
|
"""Set the active adapter |
|
|
|
Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is |
|
not desired, use the following code. |
|
|
|
```py |
|
>>> for name, param in model_peft.named_parameters(): |
|
... if ...: # some check on name (ex. if 'lora' in name) |
|
... param.requires_grad = False |
|
``` |
|
|
|
Args: |
|
adapter_name (str): The name of the adapter to set as active |
|
""" |
|
if adapter_name not in self.modules_to_save: |
|
raise ValueError(f"Adapter {adapter_name} not found in {self.modules_to_save.keys()}") |
|
|
|
self.modules_to_save[self.active_adapter].requires_grad_(False) |
|
self.modules_to_save[adapter_name].requires_grad_(True) |
|
self._active_adapter = adapter_name |
|
|
|
|
|
def _get_submodules(model, key): |
|
parent = model.get_submodule(".".join(key.split(".")[:-1])) |
|
target_name = key.split(".")[-1] |
|
target = model.get_submodule(key) |
|
return parent, target, target_name |
|
|
|
|
|
def _freeze_adapter(model, adapter_name): |
|
for n, p in model.named_parameters(): |
|
if adapter_name in n: |
|
p.requires_grad = False |
|
|
|
|
|
def _set_trainable(model, adapter_name): |
|
key_list = [key for key, _ in model.named_modules()] |
|
for key in key_list: |
|
target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) |
|
if target_module_found: |
|
parent, target, target_name = _get_submodules(model, key) |
|
if isinstance(target, ModulesToSaveWrapper): |
|
target.update(adapter_name) |
|
target.set_adapter(target.active_adapter) |
|
else: |
|
new_module = ModulesToSaveWrapper(target, adapter_name) |
|
new_module.set_adapter(adapter_name) |
|
setattr(parent, target_name, new_module) |
|
|
|
|
|
def _set_adapter(model, adapter_name): |
|
def check_adapter_name(adapter_name): |
|
if isinstance(adapter_name, str): |
|
return adapter_name |
|
|
|
|
|
        # at this point, adapter_name is a list of str
        if len(adapter_name) > 1:
|
raise ValueError("Only one adapter can be set at a time for modules_to_save") |
|
elif len(adapter_name) == 0: |
|
raise ValueError("Please specify at least one adapter to set") |
|
adapter_name = adapter_name[0] |
|
return adapter_name |
|
|
|
for module in model.modules(): |
|
if isinstance(module, ModulesToSaveWrapper): |
|
|
|
            # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
            adapter_name = check_adapter_name(adapter_name)
|
|
|
|
|
|
|
            # if the adapter is found in this module, set it as the active adapter, otherwise disable its adapters
            if adapter_name in module.modules_to_save:
|
module.set_adapter(adapter_name) |
|
else: |
|
module.enable_adapters(False) |
|
|
|
|
|
def _prepare_prompt_learning_config(peft_config, model_config): |
|
if peft_config.num_layers is None: |
|
if "num_hidden_layers" in model_config: |
|
num_layers = model_config["num_hidden_layers"] |
|
elif "num_layers" in model_config: |
|
num_layers = model_config["num_layers"] |
|
elif "n_layer" in model_config: |
|
num_layers = model_config["n_layer"] |
|
else: |
|
raise ValueError("Please specify `num_layers` in `peft_config`") |
|
peft_config.num_layers = num_layers |
|
|
|
if peft_config.token_dim is None: |
|
if "hidden_size" in model_config: |
|
token_dim = model_config["hidden_size"] |
|
elif "n_embd" in model_config: |
|
token_dim = model_config["n_embd"] |
|
elif "d_model" in model_config: |
|
token_dim = model_config["d_model"] |
|
else: |
|
raise ValueError("Please specify `token_dim` in `peft_config`") |
|
peft_config.token_dim = token_dim |
|
|
|
if peft_config.num_attention_heads is None: |
|
if "num_attention_heads" in model_config: |
|
num_attention_heads = model_config["num_attention_heads"] |
|
elif "n_head" in model_config: |
|
num_attention_heads = model_config["n_head"] |
|
elif "num_heads" in model_config: |
|
num_attention_heads = model_config["num_heads"] |
|
elif "encoder_attention_heads" in model_config: |
|
num_attention_heads = model_config["encoder_attention_heads"] |
|
else: |
|
raise ValueError("Please specify `num_attention_heads` in `peft_config`") |
|
peft_config.num_attention_heads = num_attention_heads |
|
|
|
if getattr(peft_config, "encoder_hidden_size", None) is None: |
|
setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) |
|
|
|
return peft_config |
|
|
|
|
|
def fsdp_auto_wrap_policy(model): |
|
import functools |
|
import os |
|
|
|
from accelerate import FullyShardedDataParallelPlugin |
|
|
|
if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"): |
|
get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name |
|
else: |
|
from accelerate.utils.dataclasses import get_module_class_from_name |
|
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
|
from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder |
|
|
|
default_transformer_cls_names_to_wrap = ( |
|
",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else "" |
|
) |
|
transformer_cls_names_to_wrap = os.environ.get( |
|
"FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap |
|
).split(",") |
|
transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding} |
|
for layer_class in transformer_cls_names_to_wrap: |
|
transformer_cls = get_module_class_from_name(model, layer_class) |
|
if transformer_cls is None: |
|
raise Exception("Could not find the transformer layer class to wrap in the model.") |
|
else: |
|
transformer_cls_to_wrap.add(transformer_cls) |
|
|
|
def lambda_policy_fn(module): |
|
if ( |
|
len(list(module.named_children())) == 0 |
|
and getattr(module, "weight", None) is not None |
|
and module.weight.requires_grad |
|
): |
|
return True |
|
return False |
|
|
|
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) |
|
transformer_wrap_policy = functools.partial( |
|
transformer_auto_wrap_policy, |
|
transformer_layer_cls=transformer_cls_to_wrap, |
|
) |
|
|
|
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) |
|
return auto_wrap_policy |
|
|
|
|
|
def transpose(weight, fan_in_fan_out): |
|
if not fan_in_fan_out: |
|
return weight |
|
|
|
if isinstance(weight, torch.nn.Parameter): |
|
return torch.nn.Parameter(weight.T) |
|
return weight.T |
|
|
|
|
|
def _is_valid_match(key: str, target_key: str): |
|
""" |
|
    Helper function to match the module name `key` against `target_key`. Returns True if `key` is exactly
    `target_key` or if `target_key` is a suffix of `key` at a module boundary (i.e. preceded by a ".").
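
    For illustration (the module names are hypothetical):

    ```py
    >>> _is_valid_match("model.decoder.layers.0.self_attn.q_proj", "q_proj")
    True
    >>> _is_valid_match("model.decoder.layers.0.self_attn.wq_proj", "q_proj")
    False
    ```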
|
""" |
|
if key.endswith(target_key): |
|
if len(key) > len(target_key): |
|
return key.endswith("." + target_key) |
|
return True |
|
return False |
|
|
|
|
|
def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int:
    """Get the batch size based on either input_ids or inputs_embeds.

    Raises a ValueError if both are None.
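
    For illustration, with a `(4, 16)` `input_ids` tensor the batch size is 4:

    ```py
    >>> _get_batch_size(torch.ones(4, 16, dtype=torch.long), None)
    4
    ```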
|
|
|
""" |
|
if (input_ids is None) and (inputs_embeds is None): |
|
raise ValueError("You have to provide either input_ids or inputs_embeds") |
|
|
|
if input_ids is not None: |
|
batch_size = input_ids.shape[0] |
|
else: |
|
batch_size = inputs_embeds.shape[0] |
|
return batch_size |
|
|
|
|
|
def get_quantization_config(model: torch.nn.Module, method: str): |
|
""" |
|
Get the quantization config of the related quantization method |
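
    For example (assuming a hypothetical `model` that was loaded with GPTQ quantization):

    ```py
    >>> gptq_config = get_quantization_config(model, method="gptq")  # None if the model is not GPTQ-quantized
    ```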
|
""" |
|
if ( |
|
hasattr(model, "config") |
|
and hasattr(model.config, "quantization_config") |
|
and (getattr(model, "quantization_method", None) == method) |
|
): |
|
return model.config.quantization_config |
|
return None |
|
|
|
|
|
def get_auto_gptq_quant_linear(gptq_quantization_config): |
|
""" |
|
    Get the right AutoGPTQQuantLinear class based on the given GPTQ quantization config
|
""" |
|
if gptq_quantization_config is not None and is_auto_gptq_available(): |
|
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear |
|
|
|
desc_act = gptq_quantization_config.desc_act |
|
group_size = gptq_quantization_config.group_size |
|
bits = gptq_quantization_config.bits |
|
if hasattr(gptq_quantization_config, "use_exllama"): |
|
use_exllama = gptq_quantization_config.use_exllama |
|
else: |
|
use_exllama = not gptq_quantization_config.disable_exllama |
|
if hasattr(gptq_quantization_config, "exllama_config"): |
|
exllama_version = gptq_quantization_config.exllama_config["version"] |
|
else: |
|
exllama_version = 1 |
|
AutoGPTQQuantLinear = dynamically_import_QuantLinear( |
|
use_triton=False, |
|
desc_act=desc_act, |
|
group_size=group_size, |
|
bits=bits, |
|
disable_exllama=not (use_exllama and exllama_version == 1), |
|
disable_exllamav2=not (use_exllama and exllama_version == 2), |
|
) |
|
return AutoGPTQQuantLinear |
|
return None |
|
|
|
|
|
def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: |
|
""" |
|
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For |
|
    example, "meta" tensors all share the same storage, and thus their identifiers will all be equal. This identifier is
|
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with |
|
non-overlapping lifetimes may have the same id. |
|
|
|
    This method is an exact copy of
    https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but is
    duplicated here to avoid import issues with old versions of transformers.
|
""" |
|
    if tensor.device.type == "xla" and is_torch_tpu_available():
        # NOTE: xla tensors don't have storage, so use another unique id to distinguish them. Since this is an XLA
        # tensor, it must have been created through torch_xla, which makes the import below safe.
        import torch_xla
|
|
|
unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) |
|
else: |
|
unique_id = storage_ptr(tensor) |
|
|
|
return tensor.device, unique_id, storage_size(tensor) |
|
|
|
|
|
def cast_mixed_precision_params(model, dtype): |
|
""" |
|
Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or |
|
`torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full |
|
precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for |
|
non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using |
|
automatic mixed-precision training. |
|
|
|
Args: |
|
model (`torch.nn.Module`): |
|
The model to cast the non-trainable parameters of. |
|
dtype (`torch.dtype`): |
|
The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or |
|
`torch.bfloat16` as per the mixed-precision training you are performing. |
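
    Example (a minimal sketch; `peft_model` stands in for any model whose trainable parameters are the adapter
    weights):

    ```py
    >>> cast_mixed_precision_params(peft_model, dtype=torch.float16)
    ```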
|
""" |
|
for p in model.parameters(): |
|
if not p.requires_grad: |
|
p.data = p.to(dtype) |
|
else: |
|
p.data = p.to(torch.float32) |
|
|
|
|
|
def str_to_bool(value: str) -> int: |
|
""" |
|
Converts a string representation of truth to `True` (1) or `False` (0). |
|
|
|
    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
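
    For illustration:

    ```py
    >>> str_to_bool("yes")
    1
    >>> str_to_bool("OFF")
    0
    ```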
|
""" |
|
|
|
value = value.lower() |
|
if value in ("y", "yes", "t", "true", "on", "1"): |
|
return 1 |
|
elif value in ("n", "no", "f", "false", "off", "0"): |
|
return 0 |
|
else: |
|
raise ValueError(f"invalid truth value {value}") |
|
|
|
|
|
def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]:
    """Check if a file exists on the HF Hub. If the check was not successful, returns None instead of raising an error.

    Respects offline mode if set.
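
    Example (the repo id below is purely hypothetical):

    ```py
    >>> exists = check_file_exists_on_hf_hub("some-user/some-repo", "adapter_config.json")
    >>> # `exists` is True or False, or None if the check could not be performed (e.g. offline mode)
    ```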
|
|
|
""" |
|
exists: Optional[bool] = None |
|
    if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")):
        # user is in offline mode, so the Hub cannot be queried; return None
        return exists
|
|
|
try: |
|
exists = file_exists(repo_id, filename, **kwargs) |
|
    except (HFValidationError, EntryNotFoundError):
        # not a valid repo id or the file does not exist; `exists` stays None
        pass
|
except Exception as e: |
|
warnings.warn( |
|
f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup" |
|
f" for the file {filename} in {repo_id}." |
|
) |
|
|
|
return exists |
|
|