# Path: kaggle/working/peft/tests/test_poly.py
# Uploaded by 1112lee; commit 9d6cb8e ("nice-model", verified).
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PolyConfig, TaskType, get_peft_model
class TestPoly(unittest.TestCase):
    """End-to-end smoke test for the Poly PEFT method on a small seq2seq model."""

    def test_poly(self):
        """Train a Poly adapter briefly, then check adapter toggling and save/load.

        Steps:
          1. Wrap ``google/flan-t5-small`` with a Poly adapter (``poly_type="poly"``).
          2. Train for ``num_epochs`` full-batch steps on dummy data and require the
             final loss to be below half of the first recorded loss.
          3. Check that disabling the adapter changes both the logits and the
             generated tokens.
          4. Save the adapter, reload it onto a fresh base model, and check that the
             logits match within tolerance and the generated tokens match exactly.

        Needs the pretrained model/tokenizer to be downloadable or already cached.
        """
        torch.manual_seed(0)
        model_name_or_path = "google/flan-t5-small"
        atol, rtol = 1e-6, 1e-6
        r = 8  # rank of lora in poly
        n_tasks = 3  # number of tasks
        n_skills = 2  # number of skills (loras)
        n_splits = 4  # number of heads
        lr = 1e-2
        num_epochs = 10

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        peft_config = PolyConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            poly_type="poly",
            r=r,
            n_tasks=n_tasks,
            n_skills=n_skills,
            n_splits=n_splits,
        )
        model = get_peft_model(base_model, peft_config)

        # Generate some dummy data: each line of the `os` module docstring is one
        # example, examples are assigned to tasks round-robin, and the target
        # alternates between "A" and "B" (one label per example, whatever the count).
        text = os.__doc__.splitlines()
        assert len(text) > 10
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs["task_ids"] = torch.arange(len(text)) % n_tasks
        label_text = ["A" if i % 2 == 0 else "B" for i in range(len(text))]
        inputs["labels"] = tokenizer(label_text, return_tensors="pt")["input_ids"]

        # Simple full-batch training loop.
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        losses = []
        for _ in range(num_epochs):
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())

        # loss improved by at least 50%
        assert losses[-1] < (0.5 * losses[0])

        # Reference outputs of the trained model (eval mode), with and without the
        # adapter; they are used by both the toggle check and the save/load check.
        torch.manual_seed(0)
        model.eval()
        with torch.no_grad():
            logits_before = model(**inputs).logits
            tokens_before = model.generate(**inputs)
            with model.disable_adapter():
                logits_disabled = model(**inputs).logits
                tokens_disabled = model.generate(**inputs)

        # Disabling the adapter must change the model's behavior.
        assert not torch.allclose(logits_before, logits_disabled, atol=atol, rtol=rtol)
        # Token ids are integers and the two generations may differ in length;
        # torch.allclose would raise on such a shape mismatch, torch.equal returns False.
        assert not torch.equal(tokens_before, tokens_disabled)

        # Saving and loading: reload the adapter onto a fresh copy of the base model.
        # The adapter weights are read while the temporary directory still exists.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
            loaded = PeftModel.from_pretrained(base_model, tmp_dir)
        # Ensure dropout is off so the comparison with the eval-mode reference is fair.
        loaded.eval()

        torch.manual_seed(0)
        with torch.no_grad():
            output_after = loaded(**inputs).logits
            tokens_after = loaded.generate(**inputs)
        assert torch.allclose(logits_before, output_after, atol=atol, rtol=rtol)
        # Generated ids must match exactly (same shape and same values).
        assert torch.equal(tokens_before, tokens_after)