# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model
from peft.helpers import check_if_peft_model


class TestCheckIsPeftModel:
    """Tests for ``peft.helpers.check_if_peft_model``.

    Covers Hub repo ids (a PEFT adapter repo, a plain base-model repo, and a
    repo that does not exist) as well as local directories (a saved PEFT model,
    a saved base model, a directory with an unusable ``adapter_config.json``,
    and an adapter saved under a non-default adapter name).
    """

    def test_valid_hub_model(self):
        """A Hub repo containing a PEFT adapter is reported as a PEFT model."""
        result = check_if_peft_model("peft-internal-testing/gpt2-lora-random")
        assert result is True

    def test_invalid_hub_model(self):
        """A plain base-model Hub repo is not reported as a PEFT model."""
        result = check_if_peft_model("gpt2")
        assert result is False

    def test_nonexisting_hub_model(self):
        """A missing Hub repo yields False rather than raising."""
        result = check_if_peft_model("peft-internal-testing/non-existing-model")
        assert result is False

    def test_local_model_valid(self, tmp_path):
        """A directory written by ``PeftModel.save_pretrained`` is detected."""
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config)
        model.save_pretrained(tmp_path / "peft-gpt2-valid")
        result = check_if_peft_model(tmp_path / "peft-gpt2-valid")
        assert result is True

    def test_local_model_invalid(self, tmp_path):
        """A directory holding only a saved base model is not detected."""
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        model.save_pretrained(tmp_path / "peft-gpt2-invalid")
        result = check_if_peft_model(tmp_path / "peft-gpt2-invalid")
        assert result is False

    def test_local_model_broken_config(self, tmp_path):
        """An ``adapter_config.json`` lacking PEFT fields yields False."""
        # Valid JSON, but none of the keys a PEFT config needs.
        (tmp_path / "adapter_config.json").write_text('{"foo": "bar"}', encoding="utf-8")
        result = check_if_peft_model(tmp_path)
        assert result is False

    def test_local_model_non_default_name(self, tmp_path):
        """An adapter saved under a non-default name is found only in its subdirectory."""
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config, adapter_name="other")
        model.save_pretrained(tmp_path / "peft-gpt2-other")

        # No default adapter here: the test expects the root directory alone
        # not to be recognized.
        result = check_if_peft_model(tmp_path / "peft-gpt2-other")
        assert result is False

        # With the adapter name appended, the saved adapter is found.
        result = check_if_peft_model(tmp_path / "peft-gpt2-other" / "other")
        assert result is True