Spaces:
Running
Running
# Copyright 2025 the LlamaFactory team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import pytest | |
import torch | |
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data | |
@pytest.mark.parametrize(
    "attention_mask,golden_seq_lens",
    [
        (
            [
                [1, 1, 2, 2, 2, 0],
                [1, 2, 2, 3, 3, 3],
            ],
            [2, 3, 1, 2, 3],
        ),
        (
            [[1]],
            [1],
        ),
    ],
)
def test_get_seqlens_in_batch(attention_mask, golden_seq_lens):
    """Check ``get_seqlens_in_batch`` against hand-computed sequence lengths.

    In these cases, ``attention_mask`` labels the tokens of the k-th packed
    sequence in a row with ``k`` and padding with ``0``. ``golden_seq_lens`` is
    the expected length of every packed sequence, listed row by row.
    """
    attention_mask_with_indices = torch.tensor(attention_mask)
    seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
    expected = torch.tensor(golden_seq_lens)
    # ``==`` broadcasts, so without this check a result of the wrong length
    # could still compare equal to a single-element expectation.
    assert seqlens_in_batch.shape == expected.shape
    assert torch.all(seqlens_in_batch == expected)
@pytest.mark.parametrize(
    "attention_mask,golden_indices,golden_cu_seqlens,golden_max_seqlen",
    [
        (
            [
                [1, 1, 2, 2, 2, 0],
                [1, 2, 2, 3, 3, 3],
            ],
            [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11],
            [0, 2, 5, 6, 8, 11],
            3,
        ),
        (
            [[1]],
            [0],
            [0, 1],
            1,
        ),
    ],
)
def test_get_unpad_data(attention_mask, golden_indices, golden_cu_seqlens, golden_max_seqlen):
    """Check the ``(indices, cu_seqlens, max_seqlen)`` triple from ``get_unpad_data``.

    In these cases, ``attention_mask`` labels the tokens of the k-th packed
    sequence in a row with ``k`` and padding with ``0``. The expected values are:

    * ``golden_indices``: positions of the non-zero entries of the flattened mask.
    * ``golden_cu_seqlens``: the running total of packed-sequence lengths,
      starting at 0 (compared as int32).
    * ``golden_max_seqlen``: the length of the longest packed sequence.
    """
    attention_mask_with_indices = torch.tensor(attention_mask)
    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
    expected_indices = torch.tensor(golden_indices)
    expected_cu_seqlens = torch.tensor(golden_cu_seqlens, dtype=torch.int32)
    # ``==`` broadcasts, so compare shapes explicitly to avoid a silent pass
    # when the result has the wrong length.
    assert indices.shape == expected_indices.shape
    assert torch.all(indices == expected_indices)
    assert cu_seqlens.shape == expected_cu_seqlens.shape
    assert torch.all(cu_seqlens == expected_cu_seqlens)
    assert max_seqlen_in_batch == golden_max_seqlen