import os
import argparse
import torch
import numpy as np
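
# This script "reparameterizes" a YOLO-World checkpoint with a fixed set of
# text embeddings: the contrastive classification heads are converted into
# plain 1x1 convolutions, the RepVL-PAN guide projections are folded into
# per-head conv weights (or a precomputed guide weight), and the text
# encoder weights are dropped, so no text encoder is needed at inference.
#
# Example usage (script name and file paths are illustrative):
#   python reparameterize.py --model yolo_world_l.pth --out-dir out/ \
#       --text-embed custom_text_embeddings.npy --conv-neck
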
def parse_args():
    parser = argparse.ArgumentParser("Reparameterize YOLO-World")
    parser.add_argument('--model', help='model checkpoint to reparameterize')
    parser.add_argument('--out-dir',
                        help='directory to save the output checkpoint')
    parser.add_argument(
        '--text-embed',
        help='text embeddings (.npy) to reparameterize into YOLO-World')
    parser.add_argument('--conv-neck',
                        action='store_true',
                        help='whether to use 1x1 convs in the RepVL-PAN')
    args = parser.parse_args()
    return args


def convert_head(scale, bias, text_embed):
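    """Turn a contrastive classification head into a fixed 1x1 conv.

    The conv weight is the (N, D) text embedding matrix scaled by
    exp(logit_scale) and reshaped to (N, D, 1, 1); the scalar bias is
    broadcast to one value per class.
    """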
    N, D = text_embed.shape
    weight = (text_embed * scale.exp()).view(N, D, 1, 1)
    bias = torch.ones(N) * bias
    return weight, bias


def reparameterize_head(state_dict, embeds):
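    """Bake the text embeddings into the three cls_contrast heads,
    replacing their logit_scale and bias with a conv weight and bias."""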
    cls_layers = [
        'bbox_head.head_module.cls_contrasts.0',
        'bbox_head.head_module.cls_contrasts.1',
        'bbox_head.head_module.cls_contrasts.2'
    ]
    for i in range(3):
        scale = state_dict[cls_layers[i] + '.logit_scale']
        bias = state_dict[cls_layers[i] + '.bias']
        weight, bias = convert_head(scale, bias, embeds)
        state_dict[cls_layers[i] + '.conv.weight'] = weight
        state_dict[cls_layers[i] + '.conv.bias'] = bias
        del state_dict[cls_layers[i] + '.bias']
        del state_dict[cls_layers[i] + '.logit_scale']
    return state_dict


def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
                            num_heads):
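    """Fold the guide_fc projection into per-head 1x1 guide convs.

    The text embeddings are pushed through guide_fc once, and the result is
    split into num_heads chunks that become the weights of `guide_convs.{i}`.
    """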
    if block_name + '.guide_fc.weight' not in input_state_dict:
        return input_state_dict
    guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight']
    guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias']
    guide = (text_embeds @ guide_fc_weight.transpose(0, 1) +
             guide_fc_bias[None, :])
    N, D = guide.shape
    guide = list(guide.split(D // num_heads, dim=1))
    del input_state_dict[block_name + '.guide_fc.weight']
    del input_state_dict[block_name + '.guide_fc.bias']
    for i in range(num_heads):
        key = block_name + f'.guide_convs.{i}.weight'
        input_state_dict[key] = guide[i][:, :, None, None]
    return input_state_dict


def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):
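    """Fold the guide_fc projection into a single precomputed guide weight
    of shape (N, D // num_heads, num_heads)."""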
    guide_fc_weight = input_state_dict[block_name + '.guide_fc.weight']
    guide_fc_bias = input_state_dict[block_name + '.guide_fc.bias']
    guide = embeds @ guide_fc_weight.transpose(0, 1) + guide_fc_bias[None, :]
    N, D = guide.shape
    del input_state_dict[block_name + '.guide_fc.weight']
    del input_state_dict[block_name + '.guide_fc.bias']
    input_state_dict[block_name + '.guide_weight'] = guide.view(
        N, D // num_heads, num_heads)
    return input_state_dict


def reparameterize_neck(state_dict, embeds, type='conv'):
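    """Reparameterize the four RepVL-PAN attn_block guides, either as split
    1x1 convs ('conv') or as a single precomputed guide weight ('linear')."""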
    neck_blocks = [
        'neck.top_down_layers.0.attn_block',
        'neck.top_down_layers.1.attn_block',
        'neck.bottom_up_layers.0.attn_block',
        'neck.bottom_up_layers.1.attn_block'
    ]
    if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
        return state_dict
    for block in neck_blocks:
        num_heads = state_dict[block + '.bias'].shape[0]
        if type == 'conv':
            convert_neck_split_conv(state_dict, block, embeds, num_heads)
        else:
            convert_neck_weight(state_dict, block, embeds, num_heads)
    return state_dict


def main():
    args = parse_args()
    # load checkpoint
    model = torch.load(args.model, map_location='cpu')
    state_dict = model['state_dict']
    # load text embeddings
    embeddings = torch.from_numpy(np.load(args.text_embed))
    # remove the text encoder weights
    keys = list(state_dict.keys())
    keys = [x for x in keys if "text_model" not in x]
    state_dict_wo_text = {x: state_dict[x] for x in keys}
    print("removing text encoder")
    # reparameterize the detection head
    state_dict_wo_text = reparameterize_head(state_dict_wo_text, embeddings)
    print("reparameterizing head")
    # reparameterize the neck (RepVL-PAN)
    if args.conv_neck:
        neck_type = "conv"
    else:
        neck_type = "linear"
    state_dict_wo_text = reparameterize_neck(state_dict_wo_text, embeddings,
                                             neck_type)
    print("reparameterizing neck")
    # save the reparameterized checkpoint
    model['state_dict'] = state_dict_wo_text
    model_name = os.path.basename(args.model)
    model_name = model_name.replace('.pth', f'_rep_{neck_type}.pth')
    torch.save(model, os.path.join(args.out_dir, model_name))


if __name__ == "__main__":
    main()