|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re

import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
|
|
|
|
|
class ModelWithModuleDict(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ModuleDict``.

    ``other_layer`` is a plain ``nn.Linear(10, 10)``; ``module`` maps the key
    ``"foo"`` to a second ``nn.Linear(10, 10)``.
    """

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        layers = {"foo": nn.Linear(10, 10)}
        self.module = nn.ModuleDict(layers)

    def forward(self):
        """Apply ``module["foo"]`` to a freshly sampled random ``(1, 10)`` tensor."""
        layer = self.module["foo"]
        inputs = torch.rand(1, 10)
        return layer(inputs)
|
|
|
|
|
class ModelWithModuleList(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ModuleList``.

    ``other_layer`` is a plain ``nn.Linear(10, 10)``; ``module`` holds exactly
    one more ``nn.Linear(10, 10)`` at index 0.
    """

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        self.module = nn.ModuleList(
            [nn.Linear(10, 10)],
        )

    def forward(self):
        """Apply the single contained layer to a random ``(1, 10)`` tensor."""
        first_layer = self.module[0]
        return first_layer(torch.rand(1, 10))
|
|
|
|
|
class ModelWithParameterDict(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ParameterDict``.

    ``other_layer`` is a plain ``nn.Linear(10, 10)``; ``module`` maps the key
    ``"foo"`` to a randomly initialised ``(10, 10)`` parameter.
    """

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        weight = nn.Parameter(torch.rand(10, 10))
        self.module = nn.ParameterDict({"foo": weight})

    def forward(self):
        """Return the stored ``"foo"`` parameter itself; no computation happens."""
        stored = self.module["foo"]
        return stored
|
|
|
|
|
class ModelWithParameterList(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ParameterList``.

    ``other_layer`` is a plain ``nn.Linear(10, 10)``; ``module`` holds a single
    randomly initialised ``(10, 10)`` parameter at index 0.
    """

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        params = [nn.Parameter(torch.rand(10, 10))]
        self.module = nn.ParameterList(params)

    def forward(self):
        """Return the first (and only) stored parameter unchanged."""
        first = self.module[0]
        return first
|
|
|
|
|
@pytest.mark.parametrize(
    "cls", [ModelWithModuleDict, ModelWithModuleList, ModelWithParameterDict, ModelWithParameterList]
)
def test_modules_to_save_targets_module_dict_raises(cls):
    """``modules_to_save`` must reject Module/Parameter Dict/List containers.

    LoRA targets the plain ``other_layer`` while ``modules_to_save`` names the
    container attribute ``module``; ``get_peft_model`` is expected to raise a
    ``TypeError`` whose message contains the text below.
    """
    instance = cls()
    # Sanity check: the toy model runs a forward pass before PEFT wraps it.
    instance()

    expected_error = "modules_to_save cannot be applied to modules of type"
    config = LoraConfig(
        modules_to_save=["module"],
        target_modules=["other_layer"],
    )
    with pytest.raises(TypeError, match=expected_error):
        get_peft_model(model=instance, peft_config=config)
|
|
|
|
|
def test_get_peft_model_revision_warning(tmp_path):
    """``get_peft_model`` warns when its ``revision`` overrides the config's revision.

    The LoRA config is created with ``revision="v2.0.0"``; calling
    ``get_peft_model`` with ``revision="main"`` is expected to emit a
    ``UserWarning`` announcing the overwrite.

    NOTE(review): the ``tmp_path`` fixture is requested but not used here.
    """
    base_model_id = "peft-internal-testing/tiny-random-BertModel"
    base_revision = "v2.0.0"
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, revision=base_revision).eval()
    lora_config = LoraConfig(revision=base_revision)

    overwrite_revision = "main"
    overwrite_warning = (
        f"peft config has already set base model revision to {base_revision}, "
        f"overwriting with revision {overwrite_revision}"
    )
    # pytest.warns treats `match` as a regular expression (re.search). Escape it so
    # the dots in the version string are matched literally instead of as wildcards.
    with pytest.warns(UserWarning, match=re.escape(overwrite_warning)):
        _ = get_peft_model(base_model, lora_config, revision=overwrite_revision)
|
|