jsu27 commited on
Commit
188ed7e
·
1 Parent(s): 4c1d330
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -10,12 +10,12 @@ from decomp_diffusion.diffusion.respace import SpacedDiffusion
10
  from decomp_diffusion.gen_image import *
11
  from download import download_model
12
 
 
 
13
  # fix randomness
14
  th.manual_seed(0)
15
  np.random.seed(0)
16
 
17
- import gradio as gr
18
-
19
 
20
  def get_pil_im(im, resolution=64):
21
  im = imresize(im, (resolution, resolution))[:, :, :3]
@@ -112,7 +112,7 @@ model_kwargs.update(dict(
112
  clevr_model = create_diffusion_model(**model_kwargs)
113
  clevr_model.eval()
114
 
115
- device = 'cuda'
116
  clevr_model.to(device)
117
 
118
  print(f'loading from {ckpt_path}')
 
10
  from decomp_diffusion.gen_image import *
11
  from download import download_model
12
 
13
+ import gradio as gr
14
+
15
  # fix randomness
16
  th.manual_seed(0)
17
  np.random.seed(0)
18
 
 
 
19
 
20
  def get_pil_im(im, resolution=64):
21
  im = imresize(im, (resolution, resolution))[:, :, :3]
 
112
  clevr_model = create_diffusion_model(**model_kwargs)
113
  clevr_model.eval()
114
 
115
+ device = 'cuda' if th.cuda.is_available() else 'cpu'
116
  clevr_model.to(device)
117
 
118
  print(f'loading from {ckpt_path}')