Files changed (1) hide show
  1. app.py +104 -42
app.py CHANGED
@@ -17,44 +17,51 @@ MODELS = {
17
  "Juggernaut-XL-V9-GE-RDPhoto2": "AiWise/Juggernaut-XL-V9-GE-RDPhoto2-Lightning_4S",
18
  "SatPony-Lightning": "John6666/satpony-lightning-v2-sdxl"
19
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- config_file = hf_hub_download(
22
- "xinsir/controlnet-union-sdxl-1.0",
23
- filename="config_promax.json",
24
- )
25
- config = ControlNetModel_Union.load_config(config_file)
26
- controlnet_model = ControlNetModel_Union.from_config(config)
27
- model_file = hf_hub_download(
28
- "xinsir/controlnet-union-sdxl-1.0",
29
- filename="diffusion_pytorch_model_promax.safetensors",
30
- )
31
- state_dict = load_state_dict(model_file)
32
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
33
- controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
34
- )
35
- model.to(device="cuda", dtype=torch.float16)
36
- vae = AutoencoderKL.from_pretrained(
37
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
38
- ).to("cuda")
39
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
40
- "SG161222/RealVisXL_V5.0_Lightning",
41
- torch_dtype=torch.float16,
42
- vae=vae,
43
- controlnet=model,
44
- variant="fp16",
45
- )
46
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
47
- "GraydientPlatformAPI/lustify-lightning",
48
- torch_dtype=torch.float16,
49
- vae=vae,
50
- controlnet=model,
51
- )
52
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
53
- pipe.to("cuda")
54
 
55
  @spaces.GPU(duration=12)
56
  def fill_image(prompt, image, model_selection, paste_back):
57
  print(f"Received image: {image}")
 
 
58
  if image is None:
59
  yield None, None
60
  return
@@ -191,13 +198,7 @@ def preview_image_and_mask(image, width, height, overlap_percentage, resize_opti
191
  @spaces.GPU(duration=12)
192
  def inpaint(prompt, image, inpaint_model, paste_back):
193
  global pipe
194
- if pipe.config.model_name != MODELS[model_name]:
195
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
196
- MODELS[model_name],
197
- torch_dtype=torch.float16,
198
- vae=vae,
199
- controlnet=model,
200
- ).to("cuda")
201
  mask = Image.fromarray(image["mask"]).convert("L")
202
  image = Image.fromarray(image["image"])
203
  inpaint_final_prompt = f"score_9, score_8_up, score_7_up, {prompt}"
@@ -235,6 +236,8 @@ def outpaint(image, width, height, overlap_percentage, num_inference_steps, resi
235
 
236
  @spaces.GPU(duration=12)
237
  def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
 
 
238
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
239
  if not can_expand(background.width, background.height, width, height, alignment):
240
  alignment = "Middle"
@@ -259,7 +262,7 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
259
  image = image.convert("RGBA")
260
  cnet_image.paste(image, (0, 0), mask)
261
  yield background, cnet_image
262
-
263
  def use_output_as_input(output_image):
264
  return gr.update(value=output_image[1])
265
 
@@ -360,6 +363,8 @@ with gr.Blocks(css=css, fill_height=True) as demo:
360
  label="Generated Image",
361
  )
362
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
 
 
363
  use_as_input_button.click(
364
  fn=use_output_as_input, inputs=[result], outputs=[input_image]
365
  )
@@ -371,10 +376,24 @@ with gr.Blocks(css=css, fill_height=True) as demo:
371
  fn=lambda: gr.update(visible=False),
372
  inputs=None,
373
  outputs=use_as_input_button,
 
 
 
 
374
  ).then(
375
  fn=fill_image,
376
  inputs=[prompt, input_image, model_selection, paste_back],
377
  outputs=[result],
 
 
 
 
 
 
 
 
 
 
378
  ).then(
379
  fn=lambda: gr.update(visible=True),
380
  inputs=None,
@@ -388,10 +407,24 @@ with gr.Blocks(css=css, fill_height=True) as demo:
388
  fn=lambda: gr.update(visible=False),
389
  inputs=None,
390
  outputs=use_as_input_button,
 
 
 
 
391
  ).then(
392
  fn=fill_image,
393
  inputs=[prompt, input_image, model_selection, paste_back],
394
  outputs=[result],
 
 
 
 
 
 
 
 
 
 
395
  ).then(
396
  fn=lambda: gr.update(visible=True),
397
  inputs=None,
@@ -487,6 +520,8 @@ with gr.Blocks(css=css, fill_height=True) as demo:
487
  use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
488
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
489
  preview_image = gr.Image(label="Preview")
 
 
490
 
491
  target_ratio.change(
492
  fn=preload_presets,
@@ -525,16 +560,30 @@ with gr.Blocks(css=css, fill_height=True) as demo:
525
  fn=clear_result,
526
  inputs=None,
527
  outputs=result_outpaint,
 
 
 
 
528
  ).then(
529
  fn=infer,
530
  inputs=[input_image_outpaint, width_slider, height_slider, overlap_percentage, num_inference_steps,
531
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
532
  overlap_left, overlap_right, overlap_top, overlap_bottom],
533
  outputs=[result_outpaint],
 
 
 
 
 
534
  ).then(
535
  fn=lambda x, history: update_history(x[1], history),
536
  inputs=[result_outpaint, history_gallery],
537
  outputs=history_gallery,
 
 
 
 
 
538
  ).then(
539
  fn=lambda: gr.update(visible=True),
540
  inputs=None,
@@ -544,16 +593,30 @@ with gr.Blocks(css=css, fill_height=True) as demo:
544
  fn=clear_result,
545
  inputs=None,
546
  outputs=result_outpaint,
 
 
 
 
547
  ).then(
548
  fn=infer,
549
  inputs=[input_image_outpaint, width_slider, height_slider, overlap_percentage, num_inference_steps,
550
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
551
  overlap_left, overlap_right, overlap_top, overlap_bottom],
552
  outputs=[result_outpaint],
 
 
 
 
 
553
  ).then(
554
  fn=lambda x, history: update_history(x[1], history),
555
  inputs=[result_outpaint, history_gallery],
556
  outputs=history_gallery,
 
 
 
 
 
557
  ).then(
558
  fn=lambda: gr.update(visible=True),
559
  inputs=None,
@@ -566,5 +629,4 @@ with gr.Blocks(css=css, fill_height=True) as demo:
566
  outputs=[preview_image],
567
  queue=False
568
  )
569
-
570
  demo.launch(show_error=True)
 
17
  "Juggernaut-XL-V9-GE-RDPhoto2": "AiWise/Juggernaut-XL-V9-GE-RDPhoto2-Lightning_4S",
18
  "SatPony-Lightning": "John6666/satpony-lightning-v2-sdxl"
19
  }
20
def init_pipeline(model_name):
    """Build a StableDiffusionXLFillPipeline for *model_name*.

    The ControlNet-Union weights and the fp16-fixed VAE are identical for
    every checkpoint in ``MODELS``, so they are loaded once and cached as
    function attributes; switching models only fetches the base checkpoint.
    (The original re-downloaded and re-built both on every model switch.)

    Args:
        model_name: Key into the module-level ``MODELS`` mapping.

    Returns:
        The assembled pipeline on CUDA with a ``TCDScheduler`` installed.
    """
    if not hasattr(init_pipeline, "_controlnet"):
        # Shared ControlNet-Union (promax variant) — load exactly once.
        config_file = hf_hub_download(
            "xinsir/controlnet-union-sdxl-1.0",
            filename="config_promax.json",
        )
        config = ControlNetModel_Union.load_config(config_file)
        controlnet_model = ControlNetModel_Union.from_config(config)
        model_file = hf_hub_download(
            "xinsir/controlnet-union-sdxl-1.0",
            filename="diffusion_pytorch_model_promax.safetensors",
        )
        state_dict = load_state_dict(model_file)
        controlnet, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model,
            state_dict,
            model_file,
            "xinsir/controlnet-union-sdxl-1.0",
        )
        controlnet.to(device="cuda", dtype=torch.float16)
        init_pipeline._controlnet = controlnet
    if not hasattr(init_pipeline, "_vae"):
        # Shared fp16-safe SDXL VAE — load exactly once.
        init_pipeline._vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        ).to("cuda")
    try:
        # Prefer the fp16 weight variant when the checkpoint ships one.
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            MODELS[model_name],
            torch_dtype=torch.float16,
            vae=init_pipeline._vae,
            controlnet=init_pipeline._controlnet,
            variant="fp16",
        )
    except (OSError, ValueError):
        # Not every checkpoint publishes an "fp16" variant; the original
        # crashed here for those. Fall back to the default weights.
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            MODELS[model_name],
            torch_dtype=torch.float16,
            vae=init_pipeline._vae,
            controlnet=init_pipeline._controlnet,
        )
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    return pipe
49
+
50
# Bootstrap the app with the default checkpoint so a pipeline exists the
# moment the UI comes up; `update_pipeline` swaps it on demand afterwards.
default_model_name = "RealVisXL V5.0 Lightning"
pipe = init_pipeline(default_model_name)
53
 
54
def update_pipeline(model_selection):
    """Ensure the global ``pipe`` is loaded from the selected checkpoint.

    Bug fixed: the original compared ``pipe.config.model_name`` — an
    attribute diffusers pipelines do not define — so the check raised
    ``AttributeError`` on every call. ``DiffusionPipeline.name_or_path``
    (backed by ``config._name_or_path``) is the documented way to ask
    which checkpoint a pipeline was loaded from; ``getattr`` guards
    against older diffusers versions lacking the property.

    Args:
        model_selection: Key into the module-level ``MODELS`` mapping.

    Returns:
        The (possibly freshly rebuilt) global pipeline.
    """
    global pipe
    current = getattr(pipe, "name_or_path", None)
    if current != MODELS[model_selection]:
        pipe = init_pipeline(model_selection)
    return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @spaces.GPU(duration=12)
61
  def fill_image(prompt, image, model_selection, paste_back):
62
  print(f"Received image: {image}")
63
+ global pipe
64
+ update_pipeline(model_selection)
65
  if image is None:
66
  yield None, None
67
  return
 
198
  @spaces.GPU(duration=12)
199
  def inpaint(prompt, image, inpaint_model, paste_back):
200
  global pipe
201
+ update_pipeline(inpaint_model)
 
 
 
 
 
 
202
  mask = Image.fromarray(image["mask"]).convert("L")
203
  image = Image.fromarray(image["image"])
204
  inpaint_final_prompt = f"score_9, score_8_up, score_7_up, {prompt}"
 
236
 
237
  @spaces.GPU(duration=12)
238
  def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
239
+ global pipe
240
+ update_pipeline(model_selection) # Added this line
241
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
242
  if not can_expand(background.width, background.height, width, height, alignment):
243
  alignment = "Middle"
 
262
  image = image.convert("RGBA")
263
  cnet_image.paste(image, (0, 0), mask)
264
  yield background, cnet_image
265
+
266
  def use_output_as_input(output_image):
267
  return gr.update(value=output_image[1])
268
 
 
363
  label="Generated Image",
364
  )
365
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
366
+ loading_message = gr.Label(label="Status", value="", visible=False) # Added loading message label
367
+
368
  use_as_input_button.click(
369
  fn=use_output_as_input, inputs=[result], outputs=[input_image]
370
  )
 
376
  fn=lambda: gr.update(visible=False),
377
  inputs=None,
378
  outputs=use_as_input_button,
379
+ ).then(
380
+ fn=lambda: gr.update(value="Loading Model...", visible=True), # Show loading message
381
+ inputs=None,
382
+ outputs=[loading_message, use_as_input_button]
383
  ).then(
384
  fn=fill_image,
385
  inputs=[prompt, input_image, model_selection, paste_back],
386
  outputs=[result],
387
+ ).then(
388
+ fn=lambda: gr.update(value="Model Loaded", visible=True), # Show loaded message
389
+ inputs=None,
390
+ outputs=[loading_message],
391
+ queue=False
392
+ ).then(
393
+ fn=lambda: gr.update(value="", visible=False), # Hide loading message
394
+ inputs=None,
395
+ outputs=[loading_message],
396
+ queue=False
397
  ).then(
398
  fn=lambda: gr.update(visible=True),
399
  inputs=None,
 
407
  fn=lambda: gr.update(visible=False),
408
  inputs=None,
409
  outputs=use_as_input_button,
410
+ ).then(
411
+ fn=lambda: gr.update(value="Loading Model...", visible=True), # Show loading message
412
+ inputs=None,
413
+ outputs=[loading_message, use_as_input_button]
414
  ).then(
415
  fn=fill_image,
416
  inputs=[prompt, input_image, model_selection, paste_back],
417
  outputs=[result],
418
+ ).then(
419
+ fn=lambda: gr.update(value="Model Loaded", visible=True), # Show loaded message
420
+ inputs=None,
421
+ outputs=[loading_message],
422
+ queue=False
423
+ ).then(
424
+ fn=lambda: gr.update(value="", visible=False), # Hide loading message
425
+ inputs=None,
426
+ outputs=[loading_message],
427
+ queue=False
428
  ).then(
429
  fn=lambda: gr.update(visible=True),
430
  inputs=None,
 
520
  use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
521
  history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
522
  preview_image = gr.Image(label="Preview")
523
+ loading_message_outpaint = gr.Label(label="Status", value="", visible=False) # Added loading message label
524
+
525
 
526
  target_ratio.change(
527
  fn=preload_presets,
 
560
  fn=clear_result,
561
  inputs=None,
562
  outputs=result_outpaint,
563
+ ).then(
564
+ fn=lambda: gr.update(value="Loading Model...", visible=True), # Show loading message
565
+ inputs=None,
566
+ outputs=[loading_message_outpaint, use_as_input_button_outpaint]
567
  ).then(
568
  fn=infer,
569
  inputs=[input_image_outpaint, width_slider, height_slider, overlap_percentage, num_inference_steps,
570
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
571
  overlap_left, overlap_right, overlap_top, overlap_bottom],
572
  outputs=[result_outpaint],
573
+ ).then(
574
+ fn=lambda: gr.update(value="Model Loaded", visible=True), # Show loaded message
575
+ inputs=None,
576
+ outputs=[loading_message_outpaint],
577
+ queue=False
578
  ).then(
579
  fn=lambda x, history: update_history(x[1], history),
580
  inputs=[result_outpaint, history_gallery],
581
  outputs=history_gallery,
582
+ ).then(
583
+ fn=lambda: gr.update(value="", visible=False), # Hide loading message
584
+ inputs=None,
585
+ outputs=[loading_message_outpaint],
586
+ queue=False
587
  ).then(
588
  fn=lambda: gr.update(visible=True),
589
  inputs=None,
 
593
  fn=clear_result,
594
  inputs=None,
595
  outputs=result_outpaint,
596
+ ).then(
597
+ fn=lambda: gr.update(value="Loading Model...", visible=True), # Show loading message
598
+ inputs=None,
599
+ outputs=[loading_message_outpaint, use_as_input_button_outpaint]
600
  ).then(
601
  fn=infer,
602
  inputs=[input_image_outpaint, width_slider, height_slider, overlap_percentage, num_inference_steps,
603
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
604
  overlap_left, overlap_right, overlap_top, overlap_bottom],
605
  outputs=[result_outpaint],
606
+ ).then(
607
+ fn=lambda: gr.update(value="Model Loaded", visible=True), # Show loaded message
608
+ inputs=None,
609
+ outputs=[loading_message_outpaint],
610
+ queue=False
611
  ).then(
612
  fn=lambda x, history: update_history(x[1], history),
613
  inputs=[result_outpaint, history_gallery],
614
  outputs=history_gallery,
615
+ ).then(
616
+ fn=lambda: gr.update(value="", visible=False), # Hide loading message
617
+ inputs=None,
618
+ outputs=[loading_message_outpaint],
619
+ queue=False
620
  ).then(
621
  fn=lambda: gr.update(visible=True),
622
  inputs=None,
 
629
  outputs=[preview_image],
630
  queue=False
631
  )
 
632
  demo.launch(show_error=True)