File size: 2,567 Bytes
1867713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
081d660
1867713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3

"""Convert Satlas-Pretrain model checkpoints to a format accepted by TorchGeo.

Reference implementation:

* https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py
"""

import glob
import hashlib
import os

import timm
import torch
import torchvision


# Layouts under which the backbone weights may be nested inside a checkpoint.
# Each entry pairs the key prefix to strip with a key (relative to that prefix)
# whose presence identifies the layout. Order matters: the first match wins,
# so the longer 'backbone.backbone.backbone.' prefix is tried before
# 'backbone.backbone.'.
BACKBONE_LAYOUTS = (
    ('backbone.backbone.resnet.', 'conv1.weight'),
    ('backbone.resnet.', 'conv1.weight'),
    ('backbone.backbone.backbone.', 'features.0.0.weight'),
    ('backbone.backbone.', 'features.0.0.weight'),
)


def extract_backbone(state_dict: dict) -> dict:
    """Return only the backbone entries of *state_dict*, with the prefix removed.

    The first layout in ``BACKBONE_LAYOUTS`` whose identifying key is present
    selects the prefix. Entries that do not start with that prefix are dropped.
    Only the leading occurrence of the prefix is removed from each key.

    Args:
        state_dict: Mapping of parameter names to values loaded from a checkpoint.

    Returns:
        A new dict holding the re-keyed backbone entries, or *state_dict* itself
        (unchanged) when no known layout is detected.
    """
    for prefix, probe_key in BACKBONE_LAYOUTS:
        if prefix + probe_key in state_dict:
            return {
                key.removeprefix(prefix): value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
    return state_dict


def build_model(checkpoint: str, state_dict: dict):
    """Create an uninitialized model matching *checkpoint* and its weights.

    The architecture is chosen from substrings of the checkpoint filename:
    ``resnet`` selects a timm model named by the second ``_``-separated token
    of the filename; ``swint`` / ``swinb`` select torchvision's ``swin_v2_t`` /
    ``swin_v2_b`` with the first convolution rebuilt to match the shape of the
    stored ``features.0.0.weight``.

    Args:
        checkpoint: Checkpoint filename.
        state_dict: Backbone weights, keyed as the target model expects.

    Returns:
        The freshly constructed model (weights not yet loaded).

    Raises:
        ValueError: If the filename does not identify a supported architecture.
            Failing here prevents weights from being loaded into a model left
            over from a previous checkpoint.
    """
    if 'resnet' in checkpoint:
        # Number of input channels is the second dim of the first conv weight
        in_chans = state_dict['conv1.weight'].shape[1]
        model_name = checkpoint.split('_')[1]
        return timm.create_model(model_name, in_chans=in_chans)

    if 'swin' in checkpoint:
        if 'swint' in checkpoint:
            factory = torchvision.models.swin_v2_t
        elif 'swinb' in checkpoint:
            factory = torchvision.models.swin_v2_b
        else:
            raise ValueError(f'Unsupported Swin variant in checkpoint name: {checkpoint!r}')

        out_channels, num_channels, kernel_size_0, kernel_size_1 = state_dict['features.0.0.weight'].shape

        model = factory()
        # Replace the patch-embedding conv so its shape matches the checkpoint
        model.features[0][0] = torch.nn.Conv2d(
            num_channels,
            out_channels,
            kernel_size=(kernel_size_0, kernel_size_1),
            stride=(4, 4),
        )
        return model

    raise ValueError(f'Cannot determine model architecture from checkpoint name: {checkpoint!r}')


for checkpoint in glob.iglob('*.pth'):
    # Skip if already converted (converted files are named '<name>-<hash>.pth')
    if '-' in checkpoint:
        continue

    print(checkpoint)

    # Map to CPU
    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=True)

    # Extract backbone
    state_dict = extract_backbone(state_dict)

    # Create model
    model = build_model(checkpoint, state_dict)

    # Load weights
    model.load_state_dict(state_dict)

    # Save model
    tmp_path = f'{checkpoint}.tmp'
    torch.save(model.state_dict(), tmp_path)

    # Compute the checksum
    with open(tmp_path, 'rb') as f:
        checksum = hashlib.file_digest(f, 'sha256').hexdigest()

    # Rename: drop the '.pth' extension and append the first 8 hex digits of the hash
    os.rename(tmp_path, f'{checkpoint[:-4]}-{checksum[:8]}.pth')