from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.helpers import check_if_peft_model


class TestCheckIsPeftModel:
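    """Tests for check_if_peft_model, which reports whether a Hub id or local path points to a PEFT model."""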
    def test_valid_hub_model(self):
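        # a known PEFT adapter repo on the Hub should be detected as a PEFT model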
        result = check_if_peft_model("peft-internal-testing/gpt2-lora-random")
        assert result is True

    def test_invalid_hub_model(self):
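        # a plain transformers model on the Hub is not a PEFT model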
        result = check_if_peft_model("gpt2")
        assert result is False

    def test_nonexisting_hub_model(self):
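        # a repo that does not exist should return False rather than raise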
        result = check_if_peft_model("peft-internal-testing/non-existing-model")
        assert result is False

    def test_local_model_valid(self, tmp_path):
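        # a PEFT adapter saved locally should be detected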
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config)
        model.save_pretrained(tmp_path / "peft-gpt2-valid")
        result = check_if_peft_model(tmp_path / "peft-gpt2-valid")
        assert result is True

    def test_local_model_invalid(self, tmp_path):
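        # a plain transformers checkpoint saved locally is not a PEFT model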
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        model.save_pretrained(tmp_path / "peft-gpt2-invalid")
        result = check_if_peft_model(tmp_path / "peft-gpt2-invalid")
        assert result is False

    def test_local_model_broken_config(self, tmp_path):
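        # an adapter_config.json with unexpected content should not raise, just return False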
        with open(tmp_path / "adapter_config.json", "w") as f:
            f.write('{"foo": "bar"}')

        result = check_if_peft_model(tmp_path)
        assert result is False

    def test_local_model_non_default_name(self, tmp_path):
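        # saving with a non-default adapter name puts the config in a subdirectory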
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config, adapter_name="other")
        model.save_pretrained(tmp_path / "peft-gpt2-other")

        # the root directory contains no adapter_config.json, so it is not detected
        result = check_if_peft_model(tmp_path / "peft-gpt2-other")
        assert result is False

        # the "other" subdirectory contains the adapter config
        result = check_if_peft_model(tmp_path / "peft-gpt2-other" / "other")
        assert result is True