# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import pickle
import tempfile
import unittest
import warnings

import pytest
from parameterized import parameterized

from peft import (
    AdaLoraConfig,
    AdaptionPromptConfig,
    BOFTConfig,
    IA3Config,
    LoHaConfig,
    LoraConfig,
    MultitaskPromptTuningConfig,
    OFTConfig,
    PeftConfig,
    PeftType,
    PolyConfig,
    PrefixTuningConfig,
    PromptEncoder,
    PromptEncoderConfig,
    PromptTuningConfig,
    VeraConfig,
)


# (model_id, revision) pairs of tiny checkpoints on the Hub used by the loading tests
PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")]

ALL_CONFIG_CLASSES = (
    AdaptionPromptConfig,
    AdaLoraConfig,
    IA3Config,
    LoHaConfig,
    LoraConfig,
    MultitaskPromptTuningConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    OFTConfig,
    PolyConfig,
    BOFTConfig,
    VeraConfig,
)


class PeftConfigTester(unittest.TestCase):
    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_methods(self, config_class):
        r"""
        Test if all configs have the expected methods. Here we test
        - to_dict
        - save_pretrained
        - from_pretrained
        - from_json_file
        """
        # test if all configs have the expected methods
        config = config_class()
        assert hasattr(config, "to_dict")
        assert hasattr(config, "save_pretrained")
        assert hasattr(config, "from_pretrained")
        assert hasattr(config, "from_json_file")

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_task_type(self, config_class):
        # passing an arbitrary task_type string should not raise
        config_class(task_type="test")

    def test_from_peft_type(self):
        r"""
        Test if the config is correctly loaded using:
        - from_peft_type
        """
        from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

        for peft_type in PeftType:
            expected_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
            config = PeftConfig.from_peft_type(peft_type=peft_type)
            assert type(config) is expected_cls

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_from_pretrained(self, config_class):
        r"""
        Test if the config is correctly loaded using:
        - from_pretrained
        """
        for model_name, revision in PEFT_MODELS_TO_TEST:
            # Test that the config can be loaded from the delta checkpoint
            config_class.from_pretrained(model_name, revision=revision)

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_save_pretrained(self, config_class):
        r"""
        Test if the config is correctly saved and loaded using
        - save_pretrained
        """
        config = config_class()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_from_json_file(self, config_class):
        config = config_class()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json"))
            assert config.to_dict() == config_from_json

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_to_dict(self, config_class):
        r"""
        Test if the config can be correctly converted to a dict using:
        - to_dict
        """
        config = config_class()
        assert isinstance(config.to_dict(), dict)

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_from_pretrained_cache_dir(self, config_class):
        r"""
        Test if the config is correctly loaded with extra kwargs
        """
        with tempfile.TemporaryDirectory() as tmp_dirname:
            for model_name, revision in PEFT_MODELS_TO_TEST:
                # Test that the config can be loaded from the delta checkpoint
                config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname)

    def test_from_pretrained_cache_dir_remote(self):
        r"""
        Test if the config is correctly loaded with a checkpoint from the hub
        """
        with tempfile.TemporaryDirectory() as tmp_dirname:
            PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
            assert "models--ybelkada--test-st-lora" in os.listdir(tmp_dirname)

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_set_attributes(self, config_class):
        # manually set attributes and check if they are correctly written
        config = config_class(peft_type="test")

        # save pretrained
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_config_copy(self, config_class):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class()
        copied = copy.copy(config)
        assert config.to_dict() == copied.to_dict()

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_config_deepcopy(self, config_class):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class()
        copied = copy.deepcopy(config)
        assert config.to_dict() == copied.to_dict()

    @parameterized.expand(ALL_CONFIG_CLASSES)
    def test_config_pickle_roundtrip(self, config_class):
        # see https://github.com/huggingface/peft/issues/424
        config = config_class()
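        # serialize and deserialize in one step to exercise the full pickle round-trip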
        copied = pickle.loads(pickle.dumps(config))
        assert config.to_dict() == copied.to_dict()

    def test_prompt_encoder_warning_num_layers(self):
        # This test checks that a warning is raised if a prompt encoder config is created with an argument that is
        # ignored. However, no warning should be raised if the default value is used.
        kwargs = {
            "num_virtual_tokens": 20,
            "num_transformer_submodules": 1,
            "token_dim": 768,
            "encoder_hidden_size": 768,
        }

        # there should be no warning with just the default value for encoder_num_layers
        config = PromptEncoderConfig(**kwargs)
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # escalate warnings to errors so an unexpected warning fails the test
            PromptEncoder(config)

        # when changing encoder_num_layers, there should be a warning for MLP since that value is not used
        config = PromptEncoderConfig(encoder_num_layers=123, **kwargs)
        with pytest.warns(UserWarning) as record:
            PromptEncoder(config)
        expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
        assert str(record.list[0].message) == expected_msg

    @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig, BOFTConfig])
    def test_save_pretrained_with_target_modules(self, config_class):
        # See #1041, #1045
        config = config_class(target_modules=["a", "list"])
        with tempfile.TemporaryDirectory() as tmp_dirname:
            config.save_pretrained(tmp_dirname)

            config_from_pretrained = config_class.from_pretrained(tmp_dirname)
            assert config.to_dict() == config_from_pretrained.to_dict()
            # explicit test that target_modules should be converted to set
            assert isinstance(config_from_pretrained.target_modules, set)

    def test_regex_with_layer_indexing_lora(self):
        # This test checks that an error is raised if `target_modules` is a regex expression and `layers_to_transform` or
        # `layers_pattern` are not None

        # when target_modules is a regex string, layer selection must be encoded in the regex itself,
        # so combining it with layers_to_transform or layers_pattern is invalid
        invalid_config1 = {"target_modules": ".*foo", "layers_to_transform": [0]}
        invalid_config2 = {"target_modules": ".*foo", "layers_pattern": ["bar"]}

        valid_config = {"target_modules": ["foo"], "layers_pattern": ["bar"], "layers_to_transform": [0]}

        with pytest.raises(ValueError, match="`layers_to_transform` cannot be used when `target_modules` is a str."):
            LoraConfig(**invalid_config1)

        with pytest.raises(ValueError, match="`layers_pattern` cannot be used when `target_modules` is a str."):
            LoraConfig(**invalid_config2)

        # should run without errors
        LoraConfig(**valid_config)

    def test_ia3_is_feedforward_subset_invalid_config(self):
        # This test checks that the IA3 config raises a ValueError if the feedforward_modules argument
        # is not a subset of the target_modules argument

        # an example invalid config
        invalid_config = {"target_modules": ["k", "v"], "feedforward_modules": ["q"]}

        with pytest.raises(ValueError, match="^`feedforward_modules` should be a subset of `target_modules`$"):
            IA3Config(**invalid_config)

    def test_ia3_is_feedforward_subset_valid_config(self):
        # This test checks that the IA3 config is created without errors with valid arguments.
        # feedforward_modules should be a subset of target_modules if both are lists

        # an example valid config with regex expressions.
        valid_config_regex_exp = {
            "target_modules": ".*.(SelfAttention|EncDecAttention|DenseReluDense).*(q|v|wo)$",
            "feedforward_modules": ".*.DenseReluDense.wo$",
        }
        # an example valid config with module lists.
        valid_config_list = {"target_modules": ["k", "v", "wo"], "feedforward_modules": ["wo"]}

        # should run without errors
        IA3Config(**valid_config_regex_exp)
        IA3Config(**valid_config_list)
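

if __name__ == "__main__":
    # optional convenience: allow running this test module directly with `python <file>` in addition to pytest
    unittest.main()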