|
import os |
|
import tempfile |
|
import unittest |
|
|
|
import torch |
|
|
|
from diffusers.loaders.lora_base import LoraBaseMixin |
|
|
|
|
|
class UtilityMethodDeprecationTests(unittest.TestCase): |
|
def test_fetch_state_dict_cls_method_raises_warning(self): |
|
state_dict = torch.nn.Linear(3, 3).state_dict() |
|
with self.assertWarns(FutureWarning) as warning: |
|
_ = LoraBaseMixin._fetch_state_dict( |
|
state_dict, |
|
weight_name=None, |
|
use_safetensors=False, |
|
local_files_only=True, |
|
cache_dir=None, |
|
force_download=False, |
|
proxies=None, |
|
token=None, |
|
revision=None, |
|
subfolder=None, |
|
user_agent=None, |
|
allow_pickle=None, |
|
) |
|
warning_message = str(warning.warnings[0].message) |
|
assert "Using the `_fetch_state_dict()` method from" in warning_message |
|
|
|
def test_best_guess_weight_name_cls_method_raises_warning(self): |
|
with tempfile.TemporaryDirectory() as tmpdir: |
|
state_dict = torch.nn.Linear(3, 3).state_dict() |
|
torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin")) |
|
|
|
with self.assertWarns(FutureWarning) as warning: |
|
_ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir) |
|
warning_message = str(warning.warnings[0].message) |
|
assert "Using the `_best_guess_weight_name()` method from" in warning_message |
|
|