Upload folder using huggingface_hub
Browse files- README.md +1 -1
- convert_weights.py +161 -0
README.md
CHANGED
@@ -11,4 +11,4 @@ python3 imitate_episodes.py \
   --seed 0 \
 ```
 
-Also, the code was patched to train with all the data rather than reserving a validation split.
+Also, the code was patched to train with all the data rather than reserving a validation split. The weights are converted with [`convert_weights.py`](convert_weights.py).
convert_weights.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Convert an original ACT (ALOHA) checkpoint into the lerobot policy format.

Steps:
1. Build the lerobot policy from the training config so we have a target
   state dict to load into.
2. Load the original checkpoint, drop keys the lerobot implementation does
   not use, and rename the remaining keys to the lerobot naming scheme.
3. Bake the dataset normalization statistics into the policy's buffers.
4. Save the converted policy (weights + config) with ``save_pretrained``.

The script exits with a nonzero status if the converted state dict does not
line up exactly with the lerobot policy's expected keys.
"""

import sys
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config

PATH_TO_ORIGINAL_WEIGHTS = "/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt"
PATH_TO_CONFIG = (
    "outputs/train/act_aloha_sim_transfer_cube_human_final_video/.hydra/config.yaml"
)
PATH_TO_SAVE_NEW_WEIGHTS = "/tmp/act"


cfg = init_hydra_config(PATH_TO_CONFIG)

policy = make_policy(hydra_cfg=cfg, dataset_stats=make_dataset(cfg).stats)

# map_location="cpu" so the conversion also works on a machine that lacks the
# device the checkpoint was saved from.
state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS, map_location="cpu")

# Remove keys based on what they start with.

start_removals = [
    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
    "model.is_pad_head.",
]

for to_remove in start_removals:
    for k in list(state_dict.keys()):
        if k.startswith(to_remove):
            del state_dict[k]


# Replace keys based on what they start with.
# NOTE(review): the pairs are applied in order, so a key produced by an earlier
# rename can be renamed again by a later entry (e.g. "model.query_embed.weight"
# -> "model.pos_embed.weight" -> "model.decoder_pos_embed.weight"). This
# chaining appears intentional — confirm against the lerobot ACT key names.

start_replacements = [
    ("model.query_embed.weight", "model.pos_embed.weight"),
    ("model.pos_table", "model.vae_encoder_pos_enc"),
    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
    ("model.encoder.", "model.vae_encoder."),
    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
    # This pair was listed twice in the original script; the duplicate (a
    # harmless but confusing no-op on the second pass) has been removed.
    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
    ("model.transformer.encoder.", "model.encoder."),
    ("model.transformer.decoder.", "model.decoder."),
    ("model.backbones.0.0.body.", "model.backbone."),
    ("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
]

for to_replace, replace_with in start_replacements:
    for k in list(state_dict.keys()):
        if k.startswith(to_replace):
            k_ = replace_with + k.removeprefix(to_replace)
            state_dict[k_] = state_dict.pop(k)


# Bake the normalization statistics into the policy's buffers. The image
# mean/std are per-channel values shaped (3, 1, 1); the state/action stats are
# 14-dim vectors (one entry per joint).
state_dict["normalize_inputs.buffer_observation_images_top.mean"] = torch.tensor(
    [[[0.4850]], [[0.4560]], [[0.4060]]]
)
state_dict["normalize_inputs.buffer_observation_images_top.std"] = torch.tensor(
    [[[0.2290]], [[0.2240]], [[0.2250]]]
)
state_dict["normalize_inputs.buffer_observation_state.mean"] = torch.tensor(
    [
        -0.0074,
        -0.6319,
        1.0357,
        -0.0503,
        -0.4620,
        -0.0747,
        0.4747,
        -0.0362,
        -0.3320,
        0.9039,
        -0.2206,
        -0.3101,
        -0.2348,
        0.6842,
    ]
)
state_dict["normalize_inputs.buffer_observation_state.std"] = torch.tensor(
    [
        0.0122,
        0.2975,
        0.1673,
        0.0473,
        0.1486,
        0.0879,
        0.3175,
        0.1050,
        0.2793,
        0.1809,
        0.2660,
        0.3047,
        0.5299,
        0.2550,
    ]
)
state_dict["unnormalize_outputs.buffer_action.mean"] = torch.tensor(
    [
        -0.0076,
        -0.6282,
        1.0313,
        -0.0466,
        -0.4721,
        -0.0745,
        0.3739,
        -0.0372,
        -0.3261,
        0.8997,
        -0.2137,
        -0.3184,
        -0.2336,
        0.5519,
    ]
)
# The action targets are normalized with the same stats used to unnormalize
# the outputs.
state_dict["normalize_targets.buffer_action.mean"] = state_dict["unnormalize_outputs.buffer_action.mean"]
state_dict["unnormalize_outputs.buffer_action.std"] = torch.tensor(
    [
        0.0125,
        0.2957,
        0.1670,
        0.0458,
        0.1483,
        0.0876,
        0.3067,
        0.1060,
        0.2757,
        0.1806,
        0.2630,
        0.3071,
        0.5305,
        0.3838,
    ]
)
state_dict["normalize_targets.buffer_action.std"] = state_dict["unnormalize_outputs.buffer_action.std"]

# strict=False so we can report both sides of any mismatch before bailing out.
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)

if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if len(unexpected_keys) != 0:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

if len(missing_keys) != 0 or len(unexpected_keys) != 0:
    print("Failed due to mismatch in state dicts.")
    # The original called bare exit(), which terminates with status 0 and so
    # reports success to the calling shell despite the failure.
    sys.exit(1)

policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)
OmegaConf.save(cfg, Path(PATH_TO_SAVE_NEW_WEIGHTS) / "config.yaml")
|