Files changed (1) hide show
  1. app.py +88 -32
app.py CHANGED
@@ -3,56 +3,112 @@ import torch
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from PIL import Image
5
  import numpy as np
 
6
 
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_id = "nitrosocke/Ghibli-Diffusion"
9
 
10
- # Load the model (keep safety_checker to avoid warning)
11
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
12
  model_id,
13
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
  )
15
- pipe.to(device)
16
  pipe.enable_attention_slicing()
17
 
18
- # Function to convert PIL image to latent-compatible numpy
19
- def pil_to_np(image):
20
- return np.array(image).astype(np.uint8)
21
-
22
- # Generator with step-wise callback
23
- def generate_ghibli_style(image, steps=25):
24
- prompt = "ghibli style portrait"
 
 
 
 
 
 
25
  intermediate_images = []
26
 
27
- def callback(step: int, timestep: int, latents):
28
  with torch.no_grad():
29
- img = pipe.decode_latents(latents)
30
- img = pipe.numpy_to_pil(img)[0]
31
- intermediate_images.append(img)
 
 
 
 
 
32
 
 
33
  with torch.inference_mode():
34
- pipe(
 
35
  prompt=prompt,
36
- image=image,
37
- strength=0.6,
38
- guidance_scale=6.0,
 
39
  num_inference_steps=steps,
40
  callback=callback,
41
- callback_steps=1,
42
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- return intermediate_images
45
-
46
- # Gradio Interface without deprecated style()
47
- iface = gr.Interface(
48
- fn=generate_ghibli_style,
49
- inputs=[
50
- gr.Image(type="pil", label="Upload a photo"),
51
- gr.Slider(minimum=10, maximum=50, value=25, step=1, label="Inference Steps")
52
- ],
53
- outputs=gr.Gallery(label="Ghibli-style Generation Progress"),
54
- title="✨ Studio Ghibli Portrait Generator ✨",
55
- description="Upload a photo and watch it transform into a Ghibli-style portrait step by step!"
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- iface.launch()
 
 
 
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from PIL import Image
5
  import numpy as np
6
+ from typing import Generator, List
7
 
8
# Choose the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "nitrosocke/Ghibli-Diffusion"

# Half-precision weights on CUDA, full precision on CPU.
_torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load the img2img pipeline, move it onto the device, and enable attention slicing.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id,
    torch_dtype=_torch_dtype,
)
pipe = pipe.to(device)
pipe.enable_attention_slicing()
19
 
20
def generate_ghibli_style(
    input_image: Image.Image,
    steps: int = 25,
    strength: float = 0.6,
    guidance_scale: float = 7.0,
    progress: gr.Progress = gr.Progress()
) -> Generator[List[Image.Image], None, None]:
    """Stream Ghibli-style img2img results, yielding the gallery after every step.

    The diffusion pipeline call is blocking, so it runs in a worker thread.
    Its per-step callback decodes the current latents and hands the preview
    to this generator through a queue; the generator yields the growing list
    of previews and, last, the final image.

    Args:
        input_image: Source photo as a PIL image.
        steps: Value passed as ``num_inference_steps``; coerced to ``int``.
        strength: Value passed as the pipeline's ``strength``.
        guidance_scale: Value passed as the pipeline's ``guidance_scale``.
        progress: Gradio progress tracker; updated once per decoded step.

    Yields:
        A new list holding every preview decoded so far; the final yield also
        ends with the pipeline's output image.

    Raises:
        gr.Error: If no input image was provided.
        Exception: Any exception raised by the pipeline in the worker thread
            is re-raised here.
    """
    # Function-scope imports keep this block self-contained.
    import queue
    import threading

    if input_image is None:
        raise gr.Error("Please upload a photo before generating.")

    # UI sliders may deliver floats; the step count must be an integer.
    steps = int(steps)

    prompt = "ghibli style, high quality, detailed portrait"
    negative_prompt = "low quality, blurry, bad anatomy"

    # Events from the worker: ("step", index, image), ("done", None, image)
    # or ("error", None, exception).
    events: "queue.Queue" = queue.Queue()

    def callback(step: int, timestep: int, latents: torch.Tensor) -> None:
        # Must stay a plain function: a `yield` in here would make it a
        # generator function whose body never runs when the pipeline calls it.
        with torch.no_grad():
            preview = pipe.decode_latents(latents)
            preview = pipe.numpy_to_pil(preview)[0]
        events.put(("step", step, preview))

    def run_pipeline() -> None:
        try:
            # Entered here rather than in the caller so that the mode is
            # active in the thread that actually runs the pipeline.
            with torch.inference_mode():
                result = pipe(
                    prompt=prompt,
                    image=input_image,
                    negative_prompt=negative_prompt,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=steps,
                    callback=callback,
                    callback_steps=1,  # Call after every step
                )
            events.put(("done", None, result.images[0]))
        except Exception as exc:  # forwarded and re-raised in the consumer
            events.put(("error", None, exc))

    worker = threading.Thread(target=run_pipeline, daemon=True)
    worker.start()

    # NOTE(review): img2img presumably runs about int(steps * strength)
    # denoising steps rather than `steps` -- confirm against the installed
    # diffusers version. The ratio below is clamped, so a wrong estimate only
    # affects how the progress bar looks.
    expected_steps = max(1, min(int(steps * strength), steps))

    intermediate_images: List[Image.Image] = []
    while True:
        kind, step, payload = events.get()
        if kind == "error":
            worker.join()
            raise payload

        intermediate_images.append(payload)
        if kind == "step":
            # The tracker is updated from this generator, not from the worker.
            progress(min(1.0, (step + 1) / expected_steps), desc="Generating...")

        # Yield a snapshot so later appends cannot alter an emitted value.
        yield list(intermediate_images)

        if kind == "done":
            break

    worker.join()
64
+
65
# Custom CSS for the results gallery. The rules use the id selector
# "#gallery" because the Gallery component is tagged with elem_id="gallery";
# a ".gallery" class selector would not match that id.
css = """
#gallery {
    min-height: 500px;
}
#gallery img {
    max-height: 400px;
    object-fit: contain;
}
"""
75
+
76
# Gradio interface: inputs in the left column, results gallery in the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ✨ Studio Ghibli Portrait Generator ✨")
    gr.Markdown("Upload a photo and watch it transform into a Ghibli-style portrait in real-time!")

    with gr.Row():
        # Input side: photo upload, the two tuning sliders and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Photo")
            steps_slider = gr.Slider(
                minimum=10,
                maximum=50,
                value=25,
                step=1,
                label="Inference Steps",
            )
            strength_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.6,
                step=0.05,
                label="Transformation Strength",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        # Output side: gallery that receives the list of generated images.
        with gr.Column():
            gallery = gr.Gallery(
                elem_id="gallery",
                label="Generation Progress",
                show_label=True,
                preview=True,
            )

    # The same three components feed both the examples and the click handler.
    event_inputs = [input_image, steps_slider, strength_slider]

    # Example rows: image path, inference steps, transformation strength.
    gr.Examples(
        examples=[
            ["examples/portrait1.jpg", 25, 0.6],
            ["examples/portrait2.jpg", 30, 0.5],
        ],
        inputs=event_inputs,
        label="Try these examples!",
    )

    generate_btn.click(
        fn=generate_ghibli_style,
        inputs=event_inputs,
        outputs=gallery,
    )
111
 
112
# Launch the app
if __name__ == "__main__":
    # `concurrency_count` is not passed: Gradio 4.x rejects that argument,
    # and queue()'s default is already one job at a time on 3.x and 4.x
    # (NOTE(review): confirm against the pinned gradio version).
    demo.queue().launch(share=True)