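# Tests for PEFT's low-level API: injecting LoRA adapters into an arbitrary
# torch.nn.Module with inject_adapter_in_model and extracting the adapter
# weights with get_peft_model_state_dict.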
import unittest

import torch

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model
from peft.utils import ModulesToSaveWrapper


class DummyModel(torch.nn.Module):
    """Minimal embedding -> linear -> lm_head model used as the adapter injection target."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x


class TestPeft(unittest.TestCase):
    def setUp(self):
        self.model = DummyModel()

        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            target_modules=["linear"],
        )

        # Inject LoRA layers into the bare model; no PeftModel wrapper is created.
        self.model = inject_adapter_in_model(lora_config, self.model)

    def test_inject_adapter_in_model(self):
        dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
        # The model should still run a forward pass after injection.
        _ = self.model(dummy_inputs)

        # The targeted layer should now carry the LoRA submodules.
        for name, module in self.model.named_modules():
            if name == "linear":
                assert hasattr(module, "lora_A")
                assert hasattr(module, "lora_B")

    def test_get_peft_model_state_dict(self):
        peft_state_dict = get_peft_model_state_dict(self.model)

        # Only adapter (LoRA) weights should be returned, not base model weights.
        for key in peft_state_dict.keys():
            assert "lora" in key

    def test_modules_to_save(self):
        self.model = DummyModel()

        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            target_modules=["linear"],
            modules_to_save=["embedding"],
        )

        self.model = inject_adapter_in_model(lora_config, self.model)

        for name, module in self.model.named_modules():
            if name == "linear":
                assert hasattr(module, "lora_A")
                assert hasattr(module, "lora_B")
            # Modules listed in `modules_to_save` are wrapped so their weights
            # are trained and saved alongside the adapter.
            elif name == "embedding":
                assert isinstance(module, ModulesToSaveWrapper)

        state_dict = get_peft_model_state_dict(self.model)

        # The saved module's weights appear in the PEFT state dict together with the LoRA weights.
        assert "embedding.weight" in state_dict.keys()

        # The wrapper still exposes the wrapped module's attributes (e.g. `weight`).
        assert hasattr(self.model.embedding, "weight")
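

# Optional entry point so the module can be run directly; test runners such as
# pytest or `python -m unittest` discover the TestCase without it.
if __name__ == "__main__":
    unittest.main()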