Robotics · Transformers · Safetensors · act-policy · model_hub_mixin · pytorch_model_hub_mixin
alexandersoare committed (verified) · commit a823c0e · 1 parent: d8b8d58

Upload folder using huggingface_hub
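
The commit message matches the default used by `upload_folder` in `huggingface_hub`. For context, a minimal sketch of how a commit like this is produced; the repo id below is hypothetical:

```python
from huggingface_hub import HfApi

# Upload the converted policy folder in a single commit.
# The repo id is a placeholder -- substitute the actual target repo.
HfApi().upload_folder(
    folder_path="/tmp/act",  # where convert_weights.py saves the policy
    repo_id="alexandersoare/act_aloha_sim_transfer_cube_human",
    commit_message="Upload folder using huggingface_hub",
)
```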

Files changed (2)
  1. README.md +1 -1
  2. convert_weights.py +161 -0
README.md CHANGED
````diff
@@ -11,4 +11,4 @@ python3 imitate_episodes.py \
   --seed 0 \
   ```
 
-Also, the code was patched to train with all the data rather than reserving a validation split.
+Also, the code was patched to train with all the data rather than reserving a validation split. The weights are converted with [`convert_weights.py`](convert_weights.py).
````
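
For reference, a checkpoint converted and uploaded this way can be loaded back through the Hub mixin. A minimal sketch, not from the commit: the `ACTPolicy` import path, the repo id, the `select_action` call, and the input shapes are all assumptions here, inferred from the normalization buffers set up in `convert_weights.py`:

```python
import torch

from lerobot.common.policies.act.modeling_act import ACTPolicy  # assumed import path

# Hypothetical repo id -- substitute the repo this commit was pushed to.
policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
policy.eval()

# Dummy observations shaped like the ALOHA sim inputs this checkpoint appears
# to expect (assumed: one 3x480x640 top-camera frame and a 14-dim joint state,
# matching the "observation_images_top" and 14-element state buffers below).
batch = {
    "observation.images.top": torch.zeros(1, 3, 480, 640),
    "observation.state": torch.zeros(1, 14),
}
with torch.inference_mode():
    action = policy.select_action(batch)
```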
convert_weights.py ADDED
```python
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config

PATH_TO_ORIGINAL_WEIGHTS = "/home/alexander/Projects/act/outputs/sim_transfer_cube_human_vae/policy_last.ckpt"
PATH_TO_CONFIG = (
    "outputs/train/act_aloha_sim_transfer_cube_human_final_video/.hydra/config.yaml"
)
PATH_TO_SAVE_NEW_WEIGHTS = "/tmp/act"


cfg = init_hydra_config(PATH_TO_CONFIG)

# Build the LeRobot policy, attaching the dataset statistics used for
# input/output normalization.
policy = make_policy(hydra_cfg=cfg, dataset_stats=make_dataset(cfg).stats)

# Load the original ACT checkpoint (a raw state dict).
state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS)

# Remove keys based on what they start with.

start_removals = [
    # There is a bug that means the pretrained model doesn't even use the final decoder layers.
    *[f"model.transformer.decoder.layers.{i}" for i in range(1, 7)],
    "model.is_pad_head.",
]

for to_remove in start_removals:
    for k in list(state_dict.keys()):
        if k.startswith(to_remove):
            del state_dict[k]


# Replace keys based on what they start with.
# NOTE: the rules are applied in list order, and the order matters: the first rule
# renames "model.query_embed.weight" to "model.pos_embed.weight", which the third
# rule then maps on to "model.decoder_pos_embed.weight".

start_replacements = [
    ("model.query_embed.weight", "model.pos_embed.weight"),
    ("model.pos_table", "model.vae_encoder_pos_enc"),
    ("model.pos_embed.weight", "model.decoder_pos_embed.weight"),
    ("model.encoder.", "model.vae_encoder."),
    ("model.encoder_action_proj.", "model.vae_encoder_action_input_proj."),
    ("model.encoder_joint_proj.", "model.vae_encoder_robot_state_input_proj."),
    ("model.latent_proj.", "model.vae_encoder_latent_output_proj."),
    ("model.input_proj.", "model.encoder_img_feat_input_proj."),
    ("model.input_proj_robot_state", "model.encoder_robot_state_input_proj"),
    ("model.latent_out_proj.", "model.encoder_latent_input_proj."),
    ("model.transformer.encoder.", "model.encoder."),
    ("model.transformer.decoder.", "model.decoder."),
    ("model.backbones.0.0.body.", "model.backbone."),
    ("model.additional_pos_embed.weight", "model.encoder_robot_and_latent_pos_embed.weight"),
    ("model.cls_embed.weight", "model.vae_encoder_cls_embed.weight"),
]

for to_replace, replace_with in start_replacements:
    for k in list(state_dict.keys()):
        if k.startswith(to_replace):
            k_ = replace_with + k.removeprefix(to_replace)
            state_dict[k_] = state_dict[k]
            del state_dict[k]


# Image normalization uses the standard ImageNet mean/std.
state_dict["normalize_inputs.buffer_observation_images_top.mean"] = torch.tensor(
    [[[0.4850]], [[0.4560]], [[0.4060]]]
)
state_dict["normalize_inputs.buffer_observation_images_top.std"] = torch.tensor(
    [[[0.2290]], [[0.2240]], [[0.2250]]]
)
# Mean/std for the 14-dim robot state.
state_dict["normalize_inputs.buffer_observation_state.mean"] = torch.tensor(
    [
        -0.0074,
        -0.6319,
        1.0357,
        -0.0503,
        -0.4620,
        -0.0747,
        0.4747,
        -0.0362,
        -0.3320,
        0.9039,
        -0.2206,
        -0.3101,
        -0.2348,
        0.6842,
    ]
)
state_dict["normalize_inputs.buffer_observation_state.std"] = torch.tensor(
    [
        0.0122,
        0.2975,
        0.1673,
        0.0473,
        0.1486,
        0.0879,
        0.3175,
        0.1050,
        0.2793,
        0.1809,
        0.2660,
        0.3047,
        0.5299,
        0.2550,
    ]
)
# Mean/std for the 14-dim actions; the same statistics are used to normalize
# the training targets and to unnormalize the model's outputs.
state_dict["unnormalize_outputs.buffer_action.mean"] = torch.tensor(
    [
        -0.0076,
        -0.6282,
        1.0313,
        -0.0466,
        -0.4721,
        -0.0745,
        0.3739,
        -0.0372,
        -0.3261,
        0.8997,
        -0.2137,
        -0.3184,
        -0.2336,
        0.5519,
    ]
)
state_dict["normalize_targets.buffer_action.mean"] = state_dict["unnormalize_outputs.buffer_action.mean"]
state_dict["unnormalize_outputs.buffer_action.std"] = torch.tensor(
    [
        0.0125,
        0.2957,
        0.1670,
        0.0458,
        0.1483,
        0.0876,
        0.3067,
        0.1060,
        0.2757,
        0.1806,
        0.2630,
        0.3071,
        0.5305,
        0.3838,
    ]
)
state_dict["normalize_targets.buffer_action.std"] = state_dict["unnormalize_outputs.buffer_action.std"]

missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)

if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if len(unexpected_keys) != 0:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

if len(missing_keys) != 0 or len(unexpected_keys) != 0:
    print("Failed due to mismatch in state dicts.")
    exit()

policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)
OmegaConf.save(cfg, Path(PATH_TO_SAVE_NEW_WEIGHTS) / "config.yaml")
```
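
Not part of the uploaded script, but a quick way to sanity-check the result: `save_pretrained` (via `PyTorchModelHubMixin`) is assumed here to serialize the weights to `model.safetensors`, as the repo's Safetensors tag suggests, so the saved file can be read back and compared tensor-for-tensor against the in-memory policy. A sketch, assuming it runs right after the script above:

```python
# Sanity-check sketch (reuses `policy`, `Path`, and PATH_TO_SAVE_NEW_WEIGHTS
# from the script above; assumes save_pretrained wrote model.safetensors).
import torch
from safetensors.torch import load_file

saved = load_file(str(Path(PATH_TO_SAVE_NEW_WEIGHTS) / "model.safetensors"))
for name, tensor in policy.state_dict().items():
    assert torch.equal(tensor.cpu(), saved[name]), f"mismatch in {name}"
print(f"Round-trip OK: {len(saved)} tensors match")
```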