Upload folder using huggingface_hub
- README.md +7 -35
- config.yaml +3 -5
- convert_weights.py +80 -0
- image_pusht_diffusion_policy_cnn.yaml +185 -0
- model.safetensors +1 -1
README.md
CHANGED
@@ -1,39 +1,11 @@
-Diffusion Policy (as per [Diffusion Policy: Visuomotor Policy Learning via Action Diffusion](https://arxiv.org/abs/2303.04137)) trained for the `PushT` environment from [gym-pusht](https://github.com/huggingface/gym-pusht).
-
-## Training Details
-
-TODO commit hash.
-
-Trained with [LeRobot@d747195](https://github.com/huggingface/lerobot/tree/d747195c5733c4f68d4bfbe62632d6fc1b605712).
-
-The model was trained using [LeRobot's training script](https://github.com/huggingface/lerobot/blob/d747195c5733c4f68d4bfbe62632d6fc1b605712/lerobot/scripts/train.py) and with the [pusht](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3) dataset.
-
-Here are the [loss](./train_loss.csv), [evaluation score](./eval_avg_max_reward.csv), and [evaluation success rate](./eval_pc_success.csv) (with 50 rollouts) recorded during training.
-
-
-
-This took about 7 hours to train on an Nvidia RTX 3090.
-
-## Evaluation
-
-The model was evaluated on the `PushT` environment from [gym-pusht](https://github.com/huggingface/gym-pusht) and compared to a similar model trained with the original [Diffusion Policy code](https://github.com/real-stanford/diffusion_policy). There are two evaluation metrics on a per-episode basis:
-
-- Maximum overlap with target (seen as `eval/avg_max_reward` in the charts above). This ranges in [0, 1].
-- Success: whether or not the maximum overlap is at least 95%.
-
-Here are the metrics for 500 episodes' worth of evaluation. For the success rate, we add an extra row with confidence bounds. This assumes a uniform prior over the success probability, computes the beta posterior, then calculates the mean and lower/upper confidence bounds (a 68.2% confidence interval centered on the mean).
-
-| | Ours | Theirs |
-|---|---|---|
-| Average max. overlap ratio | 0.959 | 0.957 |
-| Success rate for 500 episodes (%) | 63.8 | 64.2 |
-| Beta distribution lower/mean/upper (%) | 61.6 / 63.7 / 65.9 | 62.0 / 64.1 / 66.3 |
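For reference, a minimal sketch of the beta-posterior bound computation described in the removed Evaluation section above. It illustrates the described method (uniform prior, beta posterior, 68.2% interval centered on the mean); it is not necessarily the exact script that produced those numbers.

```python
# Sketch of the success-rate bounds described above: a uniform prior plus a
# binomial likelihood gives a Beta(successes + 1, failures + 1) posterior; the
# reported interval holds 68.2% of the posterior mass, centered on the mean.
from scipy.stats import beta

n_episodes = 500
n_successes = 319  # 63.8% success rate over 500 episodes ("Ours")

posterior = beta(n_successes + 1, n_episodes - n_successes + 1)
mean = posterior.mean()
center = posterior.cdf(mean)
lower = posterior.ppf(center - 0.341)
upper = posterior.ppf(center + 0.341)
print(f"{100 * lower:.1f} / {100 * mean:.1f} / {100 * upper:.1f}")  # ~61.6 / 63.7 / 65.9
```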
+This branch contains the model weights obtained from training on the original Diffusion Policy repository.
+
+This is the command that was used for training:
+
+```bash
+python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 logging.name=benchmark
+```
+
+The configuration file `image_pusht_diffusion_policy_cnn.yaml` is included in this branch.
+
+The weights were converted with [`convert_weights.py`](convert_weights.py).
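For reference, a hypothetical sketch of loading the converted weights with LeRobot. The repository id, revision name, import path, observation keys, and the `from_pretrained`/`select_action` API reflect recent LeRobot versions and are assumptions here, not something fixed by this branch; adapt them to the pinned revision as needed.

```python
# Hypothetical loading sketch -- repo id, revision, import path, and observation
# keys are assumptions based on recent LeRobot versions, not pinned by this branch.
import torch
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

policy = DiffusionPolicy.from_pretrained("lerobot/diffusion_pusht", revision="<this-branch>")
policy.reset()  # clear the internal observation/action queues
policy.eval()

# Dummy PushT-shaped observation: one 96x96 RGB frame and the 2D agent position.
obs = {
    "observation.image": torch.zeros(1, 3, 96, 96),
    "observation.state": torch.zeros(1, 2),
}
with torch.inference_mode():
    action = policy.select_action(obs)  # expected shape: (1, 2)
```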
config.yaml
CHANGED
@@ -7,8 +7,8 @@ training:
   online_steps_between_rollouts: 1
   online_sampling_ratio: 0.5
   online_env_seed: ???
-  eval_freq:
-  save_freq:
+  eval_freq: 5000
+  save_freq: 5000
   log_freq: 250
   save_model: true
   batch_size: 64
@@ -45,15 +45,13 @@ training:
   - 1.2
   - 1.3
   - 1.4
-  n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps}
-    + 1
 eval:
   n_episodes: 50
   batch_size: 50
   use_async_envs: false
 wandb:
   enable: true
-  disable_artifact:
+  disable_artifact: false
   project: lerobot
   notes: ''
 fps: 10
convert_weights.py
ADDED
@@ -0,0 +1,80 @@
from itertools import product
from pathlib import Path

import torch
from omegaconf import OmegaConf

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config

PATH_TO_ORIGINAL_WEIGHTS = "/tmp/dp.pt"
PATH_TO_CONFIG = "/home/alexander/Projects/lerobot/lerobot/configs/default.yaml"
PATH_TO_SAVE_NEW_WEIGHTS = "/tmp/dp"

cfg = init_hydra_config(PATH_TO_CONFIG)

policy = make_policy(cfg, dataset_stats=make_dataset(cfg).stats)

state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS)

# Remove keys based on what they start with.

start_removals = ["normalizer.", "obs_encoder.obs_nets.rgb.backbone.nets.0.nets.0"]

for to_remove in start_removals:
    for k in list(state_dict.keys()):
        if k.startswith(to_remove):
            del state_dict[k]


# Replace keys based on what they start with.

start_replacements = [
    ("obs_encoder.obs_nets.image.backbone.nets", "rgb_encoder.backbone"),
    ("obs_encoder.obs_nets.image.pool", "rgb_encoder.pool"),
    ("obs_encoder.obs_nets.image.nets.3", "rgb_encoder.out"),
    *[(f"model.up_modules.{i}.2.conv.", f"model.up_modules.{i}.2.") for i in range(2)],
    *[(f"model.down_modules.{i}.2.conv.", f"model.down_modules.{i}.2.") for i in range(2)],
    *[
        (f"model.mid_modules.{i}.blocks.{k}.", f"model.mid_modules.{i}.conv{k + 1}.")
        for i, k in product(range(3), range(2))
    ],
    *[
        (f"model.down_modules.{i}.{j}.blocks.{k}.", f"model.down_modules.{i}.{j}.conv{k + 1}.")
        for i, j, k in product(range(3), range(2), range(2))
    ],
    *[
        (f"model.up_modules.{i}.{j}.blocks.{k}.", f"model.up_modules.{i}.{j}.conv{k + 1}.")
        for i, j, k in product(range(3), range(2), range(2))
    ],
    ("model.", "unet.")
]

for to_replace, replace_with in start_replacements:
    for k in list(state_dict.keys()):
        if k.startswith(to_replace):
            k_ = replace_with + k.removeprefix(to_replace)
            state_dict[k_] = state_dict[k]
            del state_dict[k]

missing_keys, unexpected_keys = policy.diffusion.load_state_dict(state_dict, strict=False)

unexpected_keys = set(unexpected_keys)
allowed_unexpected_keys = eval(
    "{'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.1.nets.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.1.pos_x', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.1.nets.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.weight', '_dummy_variable', 'mask_generator._dummy_variable', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.1.temperature', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.1.pos_y', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.bias'}"
)
if len(missing_keys) != 0:
    print("MISSING KEYS")
    print(missing_keys)
if unexpected_keys != allowed_unexpected_keys:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

if len(missing_keys) != 0 or unexpected_keys != allowed_unexpected_keys:
    print("Failed due to mismatch in state dicts.")
    exit()

torch.save(policy.state_dict(), "/tmp/policy.pt")
policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)
OmegaConf.save(cfg, Path(PATH_TO_SAVE_NEW_WEIGHTS) / "config.yaml")
image_pusht_diffusion_policy_cnn.yaml
ADDED
@@ -0,0 +1,185 @@
_target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
checkpoint:
  save_last_ckpt: true
  save_last_snapshot: false
  topk:
    format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
    k: 5
    mode: max
    monitor_key: test_mean_score
dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: true
dataset_obs_steps: 2
ema:
  _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
  inv_gamma: 1.0
  max_value: 0.9999
  min_value: 0.0
  power: 0.75
  update_after_step: 0
exp_name: default
horizon: 16
keypoint_visible_rate: 1.0
logging:
  group: null
  id: null
  mode: online
  name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
  project: diffusion_policy_debug
  resume: true
  tags:
  - train_diffusion_unet_hybrid
  - pusht_image
  - default
multi_run:
  run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
  wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
n_action_steps: 8
n_latency_steps: 0
n_obs_steps: 2
name: train_diffusion_unet_hybrid
obs_as_global_cond: true
optimizer:
  _target_: torch.optim.AdamW
  betas:
  - 0.95
  - 0.999
  eps: 1.0e-08
  lr: 0.0001
  weight_decay: 1.0e-06
past_action_visible: false
policy:
  _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
  cond_predict_scale: true
  crop_shape:
  - 84
  - 84
  diffusion_step_embed_dim: 128
  down_dims:
  # - 256
  # - 512
  # - 1024
  - 512
  - 1024
  - 2048
  eval_fixed_crop: true
  horizon: 16
  kernel_size: 5
  n_action_steps: 8
  n_groups: 8
  n_obs_steps: 2
  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    beta_start: 0.0001
    clip_sample: true
    num_train_timesteps: 100
    prediction_type: epsilon
    variance_type: fixed_small
  num_inference_steps: 100
  obs_as_global_cond: true
  obs_encoder_group_norm: true
  shape_meta:
    action:
      shape:
      - 2
    obs:
      agent_pos:
        shape:
        - 2
        type: low_dim
      image:
        shape:
        - 3
        - 96
        - 96
        type: rgb
shape_meta:
  action:
    shape:
    - 2
  obs:
    agent_pos:
      shape:
      - 2
      type: low_dim
    image:
      shape:
      - 3
      - 96
      - 96
      type: rgb
task:
  dataset:
    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
    horizon: 16
    max_train_episodes: null
    pad_after: 7
    pad_before: 1
    seed: 42
    val_ratio: 0
    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
  env_runner:
    _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
    fps: 10
    legacy_test: true
    max_steps: 300
    n_action_steps: 8
    n_envs: null
    n_obs_steps: 2
    n_test: 50
    n_test_vis: 4
    n_train: 6
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    train_start_seed: 0
  image_shape:
  - 3
  - 96
  - 96
  name: pusht_image
  shape_meta:
    action:
      shape:
      - 2
    obs:
      agent_pos:
        shape:
        - 2
        type: low_dim
      image:
        shape:
        - 3
        - 96
        - 96
        type: rgb
task_name: pusht_image
training:
  checkpoint_every: 50
  debug: false
  device: cuda:0
  gradient_accumulate_every: 1
  lr_scheduler: cosine
  lr_warmup_steps: 500
  max_train_steps: null
  max_val_steps: null
  num_epochs: 500
  resume: true
  rollout_every: 50
  sample_every: 5
  seed: 42
  tqdm_interval_sec: 1.0
  use_ema: true
  val_every: 50000000
val_dataloader:
  batch_size: 64
  num_workers: 8
  persistent_workers: false
  pin_memory: true
  shuffle: false
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9150bda22091932686db52309233586c2695be418bda16fa7202e497f56bfab8
 size 1050862612