# File size: 20,432 Bytes | 9d6cb8e
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from copy import deepcopy
import pytest
import torch
from scipy import stats
from torch import nn
from peft import AdaLoraConfig, LoraConfig, PeftModel, PromptTuningConfig, VeraConfig, get_peft_model
from peft.utils import infer_device
class TestLoraInitialization:
    """Test class to check the initialization of LoRA adapters.

    The distribution checks compare adapter weights against freshly drawn reference samples using a two-sample
    Kolmogorov-Smirnov test. Every statistical test seeds torch's RNG first so that the p-values are reproducible.
    """

    torch_device = infer_device()

    def get_uniform(self, amin, amax, size=(10000,)):
        """Return ``size`` samples drawn from a uniform distribution between ``amin`` and ``amax``."""
        unif = torch.distributions.uniform.Uniform(amin, amax)
        samples = unif.sample(size)
        return samples

    def get_normal(self, mean, std, size=(10000,)):
        """Return ``size`` samples drawn from a normal distribution with the given ``mean`` and ``std``."""
        normal = torch.distributions.normal.Normal(mean, std)
        samples = normal.sample(size)
        return samples

    def get_model(self):
        """Build a small model with one linear, one embedding and one conv2d layer, in eval mode on the test device."""

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                # choose a large weight so that averages are close to expected values
                self.linear = nn.Linear(1000, 1000)
                self.embed = nn.Embedding(1000, 1000)
                self.conv2d = nn.Conv2d(100, 100, 3)

            def forward(self, x):
                # derive integer indices for the embedding and a 4d view for the convolution from the same input
                x_int = (100 * x).int()
                x_4d = x.flatten().reshape(1, 100, 10, 10)
                return self.linear(x), self.embed(x_int), self.conv2d(x_4d)

        return MyModule().eval().to(self.torch_device)

    @pytest.fixture
    def data(self):
        """Random input batch of shape (10, 1000) on the test device."""
        return torch.rand(10, 1000).to(self.torch_device)

    @staticmethod
    def _ks_p_value(weight, reference):
        """Return the p-value of a two-sample KS test between ``weight`` and the ``reference`` samples."""
        observed = weight.detach().flatten().cpu().numpy()
        expected = reference.flatten().cpu().numpy()
        _, p_value = stats.kstest(observed, expected)
        return p_value

    def _assert_uniform_not_normal(self, weight):
        """Assert that ``weight`` looks uniformly distributed and does *not* look normally distributed."""
        # the uniform reference spans the observed value range of the weight
        unif = self.get_uniform(weight.min().item(), weight.max().item())
        assert self._ks_p_value(weight, unif) > 0.5

        # the normal reference uses the observed mean and std of the weight
        normal = self.get_normal(weight.mean().item(), weight.std().item())
        assert self._ks_p_value(weight, normal) < 0.05

    def _assert_normal_not_uniform(self, weight, mean, std):
        """Assert that ``weight`` looks like N(mean, std) and does *not* look uniformly distributed."""
        normal = self.get_normal(mean, std)
        assert self._ks_p_value(weight, normal) > 0.5

        # the uniform reference spans the observed value range of the weight
        unif = self.get_uniform(weight.min().item(), weight.max().item())
        assert self._ks_p_value(weight, unif) < 0.05

    def test_lora_linear_init_default(self):
        """Default init of a linear layer: A is uniform (not normal) and B is zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["linear"])
        model = get_peft_model(model, config)
        weight_A = model.linear.lora_A["default"].weight
        weight_B = model.linear.lora_B["default"].weight

        self._assert_uniform_not_normal(weight_A)

        # check that weight B is zero
        assert (weight_B == 0.0).all()

    def test_lora_linear_init_gaussian(self):
        """Gaussian init of a linear layer: A is N(0, 1/r) (not uniform) and B is zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["linear"], init_lora_weights="gaussian")
        model = get_peft_model(model, config)
        weight_A = model.linear.lora_A["default"].weight
        weight_B = model.linear.lora_B["default"].weight

        self._assert_normal_not_uniform(weight_A, 0.0, 1 / config.r)

        # check that weight B is zero
        assert (weight_B == 0.0).all()

    def test_lora_linear_false(self):
        """With ``init_lora_weights=False``, B of a linear layer must not be zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["linear"], init_lora_weights=False)
        model = get_peft_model(model, config)
        weight_B = model.linear.lora_B["default"].weight

        # with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values
        # as long as they are not zero, in order to avoid identity transformation.
        assert not torch.allclose(weight_B, torch.zeros_like(weight_B))

    def test_lora_embedding_default(self):
        """Default init of an embedding layer: B is N(0, 1) (not uniform) and A is zero."""
        # embedding is initialized as a normal distribution, not kaiming uniform
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["embed"])
        model = get_peft_model(model, config)
        weight_A = model.embed.lora_embedding_A["default"]
        weight_B = model.embed.lora_embedding_B["default"]

        self._assert_normal_not_uniform(weight_B, 0.0, 1.0)

        # check that weight A is zero
        assert (weight_A == 0.0).all()

    def test_lora_embedding_gaussian(self):
        """Gaussian init of an embedding layer gives the same result as the default init."""
        # embedding does not change with init_lora_weights="gaussian" vs True
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["embed"], init_lora_weights="gaussian")
        model = get_peft_model(model, config)
        weight_A = model.embed.lora_embedding_A["default"]
        weight_B = model.embed.lora_embedding_B["default"]

        self._assert_normal_not_uniform(weight_B, 0.0, 1.0)

        # check that weight A is zero
        assert (weight_A == 0.0).all()

    def test_lora_embedding_false(self):
        """With ``init_lora_weights=False``, A of an embedding layer must not be zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["embed"], init_lora_weights=False)
        model = get_peft_model(model, config)
        # For embeddings it is A (not B) that is zeroed by the regular init, see the default/gaussian embedding tests
        # of this class, so A is the weight that has to be inspected here. B is non-zero in every mode and would make
        # this check pass trivially.
        weight_A = model.embed.lora_embedding_A["default"]

        # with init_lora_weights=False, weight A should *not* be zero. We don't care so much about the actual values
        # as long as they are not zero, in order to avoid identity transformation.
        assert not torch.allclose(weight_A, torch.zeros_like(weight_A))

    def test_lora_conv2d_default(self):
        """Default init of a conv2d layer: A is uniform (not normal) and B is zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["conv2d"])
        model = get_peft_model(model, config)
        weight_A = model.conv2d.lora_A["default"].weight
        weight_B = model.conv2d.lora_B["default"].weight

        self._assert_uniform_not_normal(weight_A)

        # check that weight B is zero
        assert (weight_B == 0.0).all()

    def test_lora_conv2d_init_gaussian(self):
        """Gaussian init of a conv2d layer: A is N(0, 1/r) (not uniform) and B is zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["conv2d"], init_lora_weights="gaussian")
        model = get_peft_model(model, config)
        weight_A = model.conv2d.lora_A["default"].weight
        weight_B = model.conv2d.lora_B["default"].weight

        self._assert_normal_not_uniform(weight_A, 0.0, 1 / config.r)

        # check that weight B is zero
        assert (weight_B == 0.0).all()

    def test_lora_conv2d_false(self):
        """With ``init_lora_weights=False``, B of a conv2d layer must not be zero."""
        torch.manual_seed(0)

        model = self.get_model()
        config = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)
        model = get_peft_model(model, config)
        weight_B = model.conv2d.lora_B["default"].weight

        # with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values
        # as long as they are not zero, in order to avoid identity transformation.
        assert not torch.allclose(weight_B, torch.zeros_like(weight_B))

    def test_lora_scaling_default(self):
        """Without rslora, the scaling factor of every layer type is ``lora_alpha / r``."""
        torch.manual_seed(0)

        model = self.get_model()

        # check scaling factor use_rslora=False
        config = LoraConfig(target_modules=["linear", "embed", "conv2d"], lora_alpha=3, r=16, use_rslora=False)
        model = get_peft_model(model, config)

        expected_scaling = config.lora_alpha / config.r

        assert model.linear.scaling["default"] == expected_scaling
        assert model.embed.scaling["default"] == expected_scaling
        assert model.conv2d.scaling["default"] == expected_scaling

    def test_lora_pissa_linear_init_default(self, data):
        """PiSSA init (exact and with 16 iterations) leaves the output of the wrapped model unchanged."""
        model = self.get_model()
        output = model(data)[0]

        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"])
        peft_model = get_peft_model(deepcopy(model), config)
        assert torch.allclose(output, peft_model(data)[0], atol=1e-06)

        config = LoraConfig(init_lora_weights="pissa_niter_16", target_modules=["linear"])
        peft_model = get_peft_model(deepcopy(model), config)
        assert torch.allclose(output, peft_model(data)[0], atol=1e-06)

    def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path):
        """A PiSSA adapter gives the same output after a plain save/load and after conversion to LoRA."""
        model = self.get_model()
        output_base = model(data)[0]

        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8)
        peft_model = get_peft_model(deepcopy(model), config)
        # save the initial model
        peft_model.peft_config["default"].init_lora_weights = True
        peft_model.save_pretrained(tmp_path / "init-model")
        peft_model.peft_config["default"].init_lora_weights = "pissa"

        # modify the weights, or else the adapter performs an identity transformation
        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
        output_pissa = peft_model(data)[0]

        # sanity check
        tol = 1e-06
        assert not torch.allclose(output_base, output_pissa, atol=tol, rtol=tol)

        # save the model normally
        peft_model.save_pretrained(tmp_path / "pissa-model")
        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model")
        output_loaded = model_loaded(data)[0]

        assert torch.allclose(output_pissa, output_loaded, atol=tol, rtol=tol)
        # sanity check: ranks should still be 8 as initially
        assert model_loaded.peft_config["default"].r == 8
        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
        # sanity check: the base model weights were indeed changed
        assert not torch.allclose(
            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
        )

        # save the model with conversion
        peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model")
        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
        output_converted = model_converted(data)[0]

        assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol)
        # rank should be double of what it was initially
        assert model_converted.peft_config["default"].r == 16
        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
        # base model weights should be the same as the initial model
        assert torch.allclose(
            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
        )

    def test_lora_rslora_scaling(self):
        """With rslora, the scaling factor of every layer type is ``lora_alpha / sqrt(r)``."""
        torch.manual_seed(0)

        model = self.get_model()

        # check scaling factor use_rslora=True
        config = LoraConfig(target_modules=["linear", "embed", "conv2d"], lora_alpha=3, r=16, use_rslora=True)
        model = get_peft_model(model, config)

        expected_scaling = config.lora_alpha / (config.r**0.5)

        assert model.linear.scaling["default"] == expected_scaling
        assert model.embed.scaling["default"] == expected_scaling
        assert model.conv2d.scaling["default"] == expected_scaling

    def test_lora_default_scaling_pattern(self):
        """Without rslora, per-layer rank and alpha overrides are used for the scaling factor."""
        torch.manual_seed(0)

        model = self.get_model()

        # check scaling factor use_rslora=False with rank and alpha pattern
        config = LoraConfig(
            target_modules=["linear", "embed", "conv2d"],
            rank_pattern={"embed": 9, "conv2d": 16},
            alpha_pattern={"linear": 11, "conv2d": 13},
            lora_alpha=17,
            r=25,
            use_rslora=False,
        )
        model = get_peft_model(model, config)

        # layers without an override fall back to the global lora_alpha / r
        expected_scaling = {
            "linear": config.alpha_pattern["linear"] / config.r,
            "embed": config.lora_alpha / config.rank_pattern["embed"],
            "conv2d": config.alpha_pattern["conv2d"] / config.rank_pattern["conv2d"],
        }

        assert model.linear.scaling["default"] == expected_scaling["linear"]
        assert model.embed.scaling["default"] == expected_scaling["embed"]
        assert model.conv2d.scaling["default"] == expected_scaling["conv2d"]

    def test_lora_rslora_scaling_pattern(self):
        """With rslora, per-layer rank and alpha overrides are used and the rank enters as its square root."""
        torch.manual_seed(0)

        model = self.get_model()

        # check scaling factor use_rslora=True with rank and alpha pattern
        config = LoraConfig(
            target_modules=["linear", "embed", "conv2d"],
            rank_pattern={"embed": 9, "conv2d": 16},
            alpha_pattern={"linear": 11, "conv2d": 13},
            lora_alpha=17,
            r=25,
            use_rslora=True,
        )
        model = get_peft_model(model, config)

        # layers without an override fall back to the global lora_alpha / r
        expected_scaling = {
            "linear": config.alpha_pattern["linear"] / (config.r**0.5),
            "embed": config.lora_alpha / (config.rank_pattern["embed"] ** 0.5),
            "conv2d": config.alpha_pattern["conv2d"] / (config.rank_pattern["conv2d"] ** 0.5),
        }

        assert model.linear.scaling["default"] == expected_scaling["linear"]
        assert model.embed.scaling["default"] == expected_scaling["embed"]
        assert model.conv2d.scaling["default"] == expected_scaling["conv2d"]

    def test_lora_use_dora_linear(self, data):
        """A freshly initialized DoRA adapter on a linear layer does not change the model output."""
        # check that dora is a no-op when initialized
        torch.manual_seed(0)

        model = self.get_model()
        output_base, _, _ = model(data)

        config = LoraConfig(target_modules=["linear"], use_dora=True)
        model = get_peft_model(model, config)

        with model.disable_adapter():
            output_disabled, _, _ = model(data)
        output_dora, _, _ = model(data)

        assert torch.allclose(output_base, output_disabled)
        assert torch.allclose(output_base, output_dora)

    def test_lora_use_dora_linear_init_false(self, data):
        """With ``init_lora_weights=False``, an enabled DoRA adapter changes the model output."""
        # with init_lora_weights=False, dora should not be a no-op
        torch.manual_seed(0)

        model = self.get_model()
        output_base, _, _ = model(data)

        config = LoraConfig(target_modules=["linear"], use_dora=True, init_lora_weights=False)
        model = get_peft_model(model, config)

        with model.disable_adapter():
            output_disabled, _, _ = model(data)
        output_dora, _, _ = model(data)

        assert torch.allclose(output_base, output_disabled)
        assert not torch.allclose(output_base, output_dora)

    def test_lora_use_dora_with_megatron_core_raises(self):
        """Combining ``use_dora=True`` with a megatron config raises a ``ValueError``."""
        megatron_config = {"does-not": "matter-here"}
        with pytest.raises(ValueError, match="DoRA does not support megatron_core"):
            LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config)
class TestAdaLoraInitialization:
    """Checks on how ``AdaLoraConfig`` handles its arguments at construction time."""

    def test_adalora_target_modules_set(self):
        """``target_modules`` given as a list compares equal to the set of the same names."""
        expected_modules = {"linear", "embed", "conv2d"}
        adalora_config = AdaLoraConfig(target_modules=["linear", "embed", "conv2d"])
        assert adalora_config.target_modules == expected_modules

    def test_adalora_use_dora_raises(self):
        """Requesting DoRA in an AdaLoRA config raises a ``ValueError``."""
        expected_error = "ADALORA does not support DoRA"
        with pytest.raises(ValueError, match=expected_error):
            AdaLoraConfig(use_dora=True)

    def test_adalora_loftq_config_raises(self):
        """Passing a LoftQ config to an AdaLoRA config raises a ``ValueError``."""
        expected_error = "ADALORA does not support LOFTQ"
        with pytest.raises(ValueError, match=expected_error):
            AdaLoraConfig(loftq_config={"loftq": "config"})
class TestPromptTuningInitialization:
    """Config validation tests for prompt tuning.

    NOTE(review): ``test_vera_mixing_save_projection_raises`` exercises VeRA rather than prompt tuning; it lives in
    this class and is the only user of ``get_model`` here.
    """

    torch_device = infer_device()

    def get_model(self):
        """Return a linear + embedding + conv2d model, switched to eval mode and moved to the test device."""

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                # the layers are deliberately large so that averages are close to their expected values
                self.linear = nn.Linear(1000, 1000)
                self.embed = nn.Embedding(1000, 1000)
                self.conv2d = nn.Conv2d(100, 100, 3)

            def forward(self, x):
                # integer indices for the embedding and a 4d view for the convolution, both derived from x
                indices = (100 * x).int()
                images = x.flatten().reshape(1, 100, 10, 10)
                linear_out = self.linear(x)
                embed_out = self.embed(indices)
                conv_out = self.conv2d(images)
                return linear_out, embed_out, conv_out

        module = MyModule()
        module.eval()
        return module.to(self.torch_device)

    def test_use_prompt_tuning_init_text_raises(self):
        """``prompt_tuning_init='TEXT'`` raises when the tokenizer name or the init text is missing."""
        missing_tokenizer = "When prompt_tuning_init='TEXT', tokenizer_name_or_path can't be None"
        with pytest.raises(ValueError, match=missing_tokenizer):
            PromptTuningConfig(prompt_tuning_init="TEXT", prompt_tuning_init_text="prompt tuning init text")

        missing_text = "When prompt_tuning_init='TEXT', prompt_tuning_init_text can't be None"
        with pytest.raises(ValueError, match=missing_text):
            PromptTuningConfig(prompt_tuning_init="TEXT", tokenizer_name_or_path="t5-base")

    def test_vera_mixing_save_projection_raises(self):
        """Adding a VeRA adapter whose ``save_projection`` differs from the existing adapter raises."""
        # it is unclear what the right thing to do would be if some adapters save the projection weights and some
        # don't, so an error is the expected outcome
        shared_kwargs = {"target_modules": "linear", "init_weights": False}

        saving_config = VeraConfig(save_projection=True, **shared_kwargs)
        peft_model = get_peft_model(self.get_model(), saving_config)

        non_saving_config = VeraConfig(save_projection=False, **shared_kwargs)
        expected_error = re.escape(
            "VeRA projection weights must be saved for all adapters or none, but got multiple different values: "
            "[False, True]"
        )
        with pytest.raises(ValueError, match=expected_error):
            peft_model.add_adapter("other", non_saving_config)
|