ntt123 commited on
Commit
02d623b
·
1 Parent(s): 27914bf
Files changed (1) hide show
  1. sample.py +1 -1
sample.py CHANGED
@@ -174,7 +174,7 @@ def sample(
174
  embedding_vocab_size=model_config["embedding_vocab_size"],
175
  learn_sigma=model_config["learn_sigma"],
176
  in_channels=data_config["data_dim"],
177
- ).to(device).bfloat16
178
 
179
  state_dict = find_model(ckpt_path)
180
  model.load_state_dict(state_dict)
 
174
  embedding_vocab_size=model_config["embedding_vocab_size"],
175
  learn_sigma=model_config["learn_sigma"],
176
  in_channels=data_config["data_dim"],
177
+ ).to(device)
178
 
179
  state_dict = find_model(ckpt_path)
180
  model.load_state_dict(state_dict)