Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -174,14 +174,16 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.path.join(model
|
|
174 |
# model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
|
175 |
# model_sketch.eval()
|
176 |
|
177 |
-
|
178 |
global pipeline
|
179 |
global MultiResNetModel
|
180 |
global cur_style
|
181 |
-
cur_style = 'line + shadow'
|
182 |
|
183 |
@spaces.GPU
|
184 |
def load_ckpt():
|
|
|
|
|
|
|
|
|
185 |
weight_dtype = torch.float16
|
186 |
|
187 |
block_out_channels = [128, 128, 256, 512, 512]
|
@@ -291,10 +293,8 @@ def load_ckpt():
|
|
291 |
|
292 |
print('loaded pipeline')
|
293 |
|
294 |
-
return pipeline, MultiResNetModel
|
295 |
-
|
296 |
|
297 |
-
|
298 |
|
299 |
@spaces.GPU
|
300 |
def change_ckpt(style):
|
@@ -311,6 +311,10 @@ def change_ckpt(style):
|
|
311 |
else:
|
312 |
raise ValueError("Invalid style: {}".format(style))
|
313 |
|
|
|
|
|
|
|
|
|
314 |
cur_style = style
|
315 |
|
316 |
MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
|
@@ -349,6 +353,7 @@ def process_multi_images(files):
|
|
349 |
|
350 |
@spaces.GPU
|
351 |
def extract_lines(image):
|
|
|
352 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
353 |
|
354 |
rows = int(np.ceil(src.shape[0] / 16)) * 16
|
@@ -384,6 +389,7 @@ def extract_line_image(query_image_, resolution):
|
|
384 |
|
385 |
@spaces.GPU
|
386 |
def extract_sketch_line_image(query_image_, input_style):
|
|
|
387 |
if input_style != cur_style:
|
388 |
change_ckpt(input_style)
|
389 |
|
@@ -425,6 +431,10 @@ def colorize_image(extracted_line, reference_images, resolution, seed, num_infer
|
|
425 |
reference_images = process_multi_images(reference_images)
|
426 |
fix_random_seeds(seed)
|
427 |
|
|
|
|
|
|
|
|
|
428 |
tar_width, tar_height = resolution
|
429 |
|
430 |
gr.Info("Image retrieval in progress...")
|
|
|
174 |
# model_sketch = create_model_sketch('default').to('cuda') # create a model given opt.model and other options
|
175 |
# model_sketch.eval()
|
176 |
|
|
|
177 |
global pipeline
|
178 |
global MultiResNetModel
|
179 |
global cur_style
|
|
|
180 |
|
181 |
@spaces.GPU
|
182 |
def load_ckpt():
|
183 |
+
global pipeline
|
184 |
+
global MultiResNetModel
|
185 |
+
global cur_style
|
186 |
+
cur_style = 'line + shadow'
|
187 |
weight_dtype = torch.float16
|
188 |
|
189 |
block_out_channels = [128, 128, 256, 512, 512]
|
|
|
293 |
|
294 |
print('loaded pipeline')
|
295 |
|
|
|
|
|
296 |
|
297 |
+
load_ckpt()
|
298 |
|
299 |
@spaces.GPU
|
300 |
def change_ckpt(style):
|
|
|
311 |
else:
|
312 |
raise ValueError("Invalid style: {}".format(style))
|
313 |
|
314 |
+
global pipeline
|
315 |
+
global MultiResNetModel
|
316 |
+
global cur_style
|
317 |
+
|
318 |
cur_style = style
|
319 |
|
320 |
MultiResNetModel.load_state_dict(torch.load(MultiResNetModel_path, map_location='cpu'), strict=True)
|
|
|
353 |
|
354 |
@spaces.GPU
|
355 |
def extract_lines(image):
|
356 |
+
global line_model
|
357 |
src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
|
358 |
|
359 |
rows = int(np.ceil(src.shape[0] / 16)) * 16
|
|
|
389 |
|
390 |
@spaces.GPU
|
391 |
def extract_sketch_line_image(query_image_, input_style):
|
392 |
+
global cur_style
|
393 |
if input_style != cur_style:
|
394 |
change_ckpt(input_style)
|
395 |
|
|
|
431 |
reference_images = process_multi_images(reference_images)
|
432 |
fix_random_seeds(seed)
|
433 |
|
434 |
+
global pipeline
|
435 |
+
global MultiResNetModel
|
436 |
+
global cur_style
|
437 |
+
|
438 |
tar_width, tar_height = resolution
|
439 |
|
440 |
gr.Info("Image retrieval in progress...")
|