import torch

from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigProblemBase,
)
from llm_studio.src.models.text_causal_language_modeling_model import Model
from llm_studio.src.utils.modeling_utils import TokenStoppingCriteria, activate_neftune


def test_token_stopping_criteria():
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([0, 1, 2, 8]), prompt_input_ids_len=4
    )
    input_ids = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            [4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
            [5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        ]
    ).long()
    # The criteria triggers only once every sample in the batch contains a
    # stop word. With a prompt length of 4, the generated ids of the last
    # sample are [9, 10, 11, 12, 13, 14], which contain none of the stop
    # words, so the criteria does not trigger.
    assert not token_stopping_criteria(input_ids=input_ids, scores=None)

    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([6]), prompt_input_ids_len=0
    )
    # The first sample reads [0, 1, 2, 3, 4, 5] and does not yet contain the
    # stop word 6, so the criteria does not trigger.
    assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
    # One token later, every sample contains the stop word.
    assert token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)

    # Test stopping criteria with compound (multi-token) stop words.
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([[6, 7]]), prompt_input_ids_len=0
    )
    assert not token_stopping_criteria(input_ids=input_ids[:, :6], scores=None)
    assert not token_stopping_criteria(input_ids=input_ids[:, :7], scores=None)
    assert token_stopping_criteria(input_ids=input_ids[:, :8], scores=None)

    # Test stopping criteria with stop word ids longer than the generated text.
    token_stopping_criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]]),
        prompt_input_ids_len=0,
    )
    assert not token_stopping_criteria(input_ids=input_ids, scores=None)
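

# Typical wiring (a hedged sketch, not exercised by the tests above):
# TokenStoppingCriteria follows the transformers StoppingCriteria interface,
# so it would normally be passed to generate() via a StoppingCriteriaList.
# `stop_ids` and `prompt_len` are placeholders, not names from this repo.
#
#     from transformers import StoppingCriteriaList
#
#     model.generate(
#         input_ids,
#         stopping_criteria=StoppingCriteriaList(
#             [
#                 TokenStoppingCriteria(
#                     stop_word_ids=stop_ids, prompt_input_ids_len=prompt_len
#                 )
#             ]
#         ),
#     )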


def test_neftune_is_disabled_in_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg = ConfigProblemBase(llm_backbone="h2oai/llama2-0b-unit-test")
    cfg.architecture.backbone_dtype = "float32"
    model = Model(cfg).eval().to(device)
    input_batch = {
        "input_ids": torch.randint(
            0,
            1000,
            (1, 10),
        ).to(device),
        "attention_mask": torch.ones(1, 10).to(device),
    }
    with torch.no_grad():
        outputs = model.backbone(**input_batch)

    activate_neftune(model, neftune_noise_alpha=10)
    assert model.backbone.get_input_embeddings().neftune_noise_alpha == 10

    # NEFTune only injects noise while the model is in training mode, so the
    # logits of the eval-mode model must be unchanged after activation.
    with torch.no_grad():
        outputs_after_neftune = model.backbone(**input_batch)
    assert torch.allclose(outputs["logits"], outputs_after_neftune["logits"])
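    # A hedged sketch beyond the original checks, assuming the hook follows
    # the reference NEFTune recipe (uniform noise added to the input
    # embeddings only while the module itself is in training mode): switching
    # the embedding layer to train mode should perturb its output.
    embeddings = model.backbone.get_input_embeddings()
    embeddings.train()
    with torch.no_grad():
        noisy_embeds = embeddings(input_batch["input_ids"])
    embeddings.eval()
    with torch.no_grad():
        clean_embeds = embeddings(input_batch["input_ids"])
    assert not torch.allclose(noisy_embeds, clean_embeds)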
    # The state dict must not contain any NEFTune noise entries.
    assert [key for key in model.state_dict() if "neftune" in key] == []