import os
import tempfile
import unittest

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from peft import PeftModel, PolyConfig, TaskType, get_peft_model


class TestPoly(unittest.TestCase):
    def test_poly(self):
        torch.manual_seed(0)
        model_name_or_path = "google/flan-t5-small"
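
        # hyperparameters for the Poly adapter and a short toy training run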
        atol, rtol = 1e-6, 1e-6
        r = 8  # rank of each LoRA expert
        n_tasks = 3  # number of tasks the router distinguishes
        n_skills = 2  # number of LoRA experts ("skills") to combine
        n_splits = 4  # number of splits within each LoRA
        lr = 1e-2
        num_epochs = 10
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

        peft_config = PolyConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            poly_type="poly",
            r=r,
            n_tasks=n_tasks,
            n_skills=n_skills,
            n_splits=n_splits,
        )

        model = get_peft_model(base_model, peft_config)
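
        # generate some dummy data; Poly routes each sample through a
        # task-specific mixture of skills, so every input needs a task id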
        text = os.__doc__.splitlines()
        assert len(text) > 10
        inputs = tokenizer(text, return_tensors="pt", padding=True)
        inputs["task_ids"] = torch.arange(len(text)) % n_tasks
        inputs["labels"] = tokenizer((["A", "B"] * 100)[: len(text)], return_tensors="pt")["input_ids"]
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        losses = []
        for _ in range(num_epochs):
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
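
        # training should have at least halved the loss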
        assert losses[-1] < (0.5 * losses[0])
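
        # with the adapter disabled, the model falls back to the base
        # weights, so its outputs should differ from the adapted outputs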
        torch.manual_seed(0)
        model.eval()
        logits_before = model(**inputs).logits
        tokens_before = model.generate(**inputs)

        with model.disable_adapter():
            logits_disabled = model(**inputs).logits
            tokens_disabled = model.generate(**inputs)

        assert not torch.allclose(logits_before, logits_disabled, atol=atol, rtol=rtol)
        assert not torch.allclose(tokens_before, tokens_disabled, atol=atol, rtol=rtol)
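
        # save the adapter and load it onto a fresh copy of the base model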
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
            loaded = PeftModel.from_pretrained(base_model, tmp_dir)
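
        # the reloaded model must reproduce the original adapted outputs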
        torch.manual_seed(0)
        logits_after = loaded(**inputs).logits
        tokens_after = loaded.generate(**inputs)
        assert torch.allclose(logits_before, logits_after, atol=atol, rtol=rtol)
        assert torch.allclose(tokens_before, tokens_after, atol=atol, rtol=rtol)