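"""Unit tests for checkpoint and model helpers in llm_studio.src.utils.modeling_utils."""
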
import os
import tempfile
from unittest.mock import MagicMock, patch

import pytest
import torch

from llm_studio.src.utils.modeling_utils import (
    check_disk_space,
    load_checkpoint,
    save_checkpoint,
    unwrap_model,
)


def test_unwrap_model():
    # Create a dummy model and wrap it in DataParallel
    model = torch.nn.Linear(10, 10)
    wrapped_model = torch.nn.DataParallel(model)
    assert wrapped_model is not model
    assert isinstance(wrapped_model, torch.nn.DataParallel)

    # Unwrapping should return the original module
    unwrapped = unwrap_model(wrapped_model)
    assert unwrapped is model
    assert not isinstance(unwrapped, torch.nn.DataParallel)


# The test function takes arguments, so it needs a parametrize decorator;
# the original one was missing. The cases below are illustrative: ample
# free space should pass, zero free space should raise, for two dtypes.
@pytest.mark.parametrize(
    "free_space, dtype, should_raise",
    [
        (10**9, torch.float32, False),
        (0, torch.float32, True),
        (10**9, torch.float16, False),
        (0, torch.float16, True),
    ],
)
def test_check_disk_space(free_space, dtype, should_raise):
    # Mock a model with a single 1000x1000 parameter tensor
    model = MagicMock()
    model.parameters.return_value = [torch.ones(1000, 1000, dtype=dtype)]

    # shutil.disk_usage returns (total, used, free)
    with patch("shutil.disk_usage", return_value=(0, 0, free_space)):
        if should_raise:
            with pytest.raises(ValueError):
                check_disk_space(model, "/dummy/path")
        else:
            check_disk_space(model, "/dummy/path")  # should not raise


class DummyModel(torch.nn.Module):
    """Minimal model with a backbone and an optional classification head."""

    def __init__(self, use_classification_head=True):
        super().__init__()
        self.backbone = torch.nn.Linear(10, 10)
        if use_classification_head:
            self.classification_head = torch.nn.Linear(10, 10)


# Restores the missing parametrization; the test body covers both branches.
@pytest.mark.parametrize("use_classification_head", [True, False])
def test_save_checkpoint(use_classification_head):
    model = DummyModel(use_classification_head)
    cfg = MagicMock()
    cfg.environment.use_deepspeed = False
    cfg.environment._local_rank = 0
    cfg.training.lora = False

    with tempfile.TemporaryDirectory() as tmpdir:
        save_checkpoint(model, tmpdir, cfg)
        assert os.path.exists(os.path.join(tmpdir, "checkpoint.pth"))
        if use_classification_head:
            assert os.path.exists(
                os.path.join(tmpdir, "classification_head.pth")
            )
        else:
            assert not os.path.exists(
                os.path.join(tmpdir, "classification_head.pth")
            )


def test_load_checkpoint():
    model = DummyModel()
    cfg = MagicMock()
    cfg.architecture.pretrained_weights = "dummy_weights.pth"
    cfg.environment.use_deepspeed = False
    cfg.environment._local_rank = 0

    # Mock torch.load so no weights file is needed on disk
    dummy_state_dict = {"model": model.state_dict()}
    with patch("torch.load", return_value=dummy_state_dict):
        load_checkpoint(cfg, model, strict=True)


def test_load_checkpoint_mismatch():
    model = DummyModel(use_classification_head=True)
    model_no_classification_head = DummyModel(use_classification_head=False)
    cfg = MagicMock()
    cfg.architecture.pretrained_weights = "dummy_weights.pth"
    cfg.environment.use_deepspeed = False
    cfg.environment._local_rank = 0

    # Mock torch.load with a state dict that lacks the classification head
    dummy_state_dict = {"model": model_no_classification_head.state_dict()}
    with patch("torch.load", return_value=dummy_state_dict):
        # Strict loading must fail on the missing keys ...
        with pytest.raises(RuntimeError):
            load_checkpoint(cfg, model, strict=True)
        # ... while non-strict loading tolerates the mismatch
        load_checkpoint(cfg, model, strict=False)