JunhaoZhuang commited on
Commit
6fc8df6
·
verified ·
1 Parent(s): 2005201

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -174,14 +174,16 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
174
  # model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
175
  # model_sketch.eval()
176
 
177
-
178
  global pipeline
179
  global MultiResNetModel
180
  global cur_style
181
- cur_style = 'line + shadow'
182
 
183
  @spaces.GPU
184
  def load_ckpt():
 
 
 
 
185
  weight_dtype = torch.float16
186
 
187
  block_out_channels = [128, 128, 256, 512, 512]
@@ -291,10 +293,8 @@ def load_ckpt():
291
 
292
  print('loaded pipeline')
293
 
294
- return pipeline, MultiResNetModel
295
-
296
 
297
- pipeline, MultiResNetModel = load_ckpt()
298
 
299
  @spaces.GPU
300
  def change_ckpt(style):
@@ -311,6 +311,10 @@ def change_ckpt(style):
311
  else:
312
  raise ValueError("Invalid style: {}".format(style))
313
 
 
 
 
 
314
  cur_style = style
315
 
316
  MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
@@ -349,6 +353,7 @@ def process_multi_images(files):
349
 
350
  @spaces.GPU
351
  def extract_lines(image):
 
352
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
353
 
354
  rows = int(np.ceil(src.shape[0] / 16)) * 16
@@ -384,6 +389,7 @@ def extract_line_image(query_image_, resolution):
384
 
385
  @spaces.GPU
386
  def extract_sketch_line_image(query_image_, input_style):
 
387
  if input_style != cur_style:
388
  change_ckpt(input_style)
389
 
@@ -425,6 +431,10 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
425
  reference_images = process_multi_images(reference_images)
426
  fix_random_seeds(seed)
427
 
 
 
 
 
428
  tar_width, tar_height = resolution
429
 
430
  gr.Info("Image retrieval in progress...")
 
174
  # model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
175
  # model_sketch.eval()
176
 
 
177
  global pipeline
178
  global MultiResNetModel
179
  global cur_style
 
180
 
181
  @spaces.GPU
182
  def load_ckpt():
183
+ global pipeline
184
+ global MultiResNetModel
185
+ global cur_style
186
+ cur_style = 'line + shadow'
187
  weight_dtype = torch.float16
188
 
189
  block_out_channels = [128, 128, 256, 512, 512]
 
293
 
294
  print('loaded pipeline')
295
 
 
 
296
 
297
+ load_ckpt()
298
 
299
  @spaces.GPU
300
  def change_ckpt(style):
 
311
  else:
312
  raise ValueError("Invalid style: {}".format(style))
313
 
314
+ global pipeline
315
+ global MultiResNetModel
316
+ global cur_style
317
+
318
  cur_style = style
319
 
320
  MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
 
353
 
354
  @spaces.GPU
355
  def extract_lines(image):
356
+ global line_model
357
  src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
358
 
359
  rows = int(np.ceil(src.shape[0] / 16)) * 16
 
389
 
390
  @spaces.GPU
391
  def extract_sketch_line_image(query_image_, input_style):
392
+ global cur_style
393
  if input_style != cur_style:
394
  change_ckpt(input_style)
395
 
 
431
  reference_images = process_multi_images(reference_images)
432
  fix_random_seeds(seed)
433
 
434
+ global pipeline
435
+ global MultiResNetModel
436
+ global cur_style
437
+
438
  tar_width, tar_height = resolution
439
 
440
  gr.Info("Image retrieval in progress...")