JunhaoZhuang commited on
Commit
83dd084
·
verified ·
1 Parent(s): 785ea27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -176,7 +176,7 @@ transform = transforms.Compose([
176
  transforms.ToTensor(),
177
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
178
  ])
179
- weight_dtype = torch.float32
180
 
181
  # line model
182
  line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
@@ -201,7 +201,7 @@ global MultiResNetModel
201
  global cur_style
202
 
203
  cur_style = 'line + shadow'
204
- weight_dtype = torch.float32
205
 
206
  block_out_channels = [128, 128, 256, 512, 512]
207
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
@@ -313,7 +313,7 @@ print('loaded pipeline')
313
 
314
  @spaces.GPU
315
  def change_ckpt(style):
316
- weight_dtype = torch.float32
317
 
318
  if style == 'line':
319
  MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')
 
176
  transforms.ToTensor(),
177
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
178
  ])
179
+ weight_dtype = torch.float16
180
 
181
  # line model
182
  line_model_path = os.path.join(model_global_path, 'LE', 'erika.pth')
 
201
  global cur_style
202
 
203
  cur_style = 'line + shadow'
204
+ weight_dtype = torch.float16
205
 
206
  block_out_channels = [128, 128, 256, 512, 512]
207
  MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
 
313
 
314
  @spaces.GPU
315
  def change_ckpt(style):
316
+ weight_dtype = torch.float16
317
 
318
  if style == 'line':
319
  MultiResNetModel_path = os.path.join(model_global_path, 'line_GSRP', 'MultiResNetModel.bin')