JunhaoZhuang commited on
Commit
2af336b
·
verified ·
1 Parent(s): 9b67bab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -308,7 +308,6 @@ def change_ckpt(style):
308
  global MultiResNetModel
309
  global cur_style
310
 
311
- cur_style = style
312
 
313
  MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
314
  MultiResNetModel.to('cuda', dtype=weight_dtype)
@@ -427,6 +426,7 @@ def colorize_image(input_style, extracted_line, reference_images, resolution, se
427
  if input_style != cur_style:
428
  gr.Info("Loading the model...")
429
  change_ckpt(input_style)
 
430
 
431
  tar_width, tar_height = resolution
432
 
 
308
  global MultiResNetModel
309
  global cur_style
310
 
 
311
 
312
  MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
313
  MultiResNetModel.to('cuda', dtype=weight_dtype)
 
426
  if input_style != cur_style:
427
  gr.Info("Loading the model...")
428
  change_ckpt(input_style)
429
+ cur_style = input_style
430
 
431
  tar_width, tar_height = resolution
432