ar08 committed
Commit ee9363c · verified · 1 Parent(s): 0040d2b

Update app.py

Files changed (1)
  1. app.py +41 -137
app.py CHANGED
@@ -1,154 +1,58 @@
  import gradio as gr
  import torch
- import numpy as np
  from diffusers import StableDiffusionImg2ImgPipeline
  from PIL import Image
- from typing import Generator, List
- import gc
- import os
-
- # Configure CPU optimization
- os.environ["OMP_NUM_THREADS"] = "1"
- os.environ["MKL_NUM_THREADS"] = "1"
- torch.set_num_threads(1)

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model_id = "nitrosocke/Ghibli-Diffusion"

- # Memory-optimized pipeline loading
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
      model_id,
-     torch_dtype=torch.float32,  # Keep float32 for CPU stability
  )
- pipe = pipe.to(device)
- pipe.enable_attention_slicing(slice_size=4)
- pipe.enable_sequential_cpu_offload() if device == "cuda" else None
-
- def resize_and_crop(image: Image.Image, target_size: int = 512) -> Image.Image:
-     """Optimized image preprocessing with downsampling"""
-     width, height = image.size
-     scale = max(target_size/width, target_size/height)
-     image = image.resize((int(width*scale), int(height*scale)), Image.LANCZOS)
-     width, height = image.size
-     left = (width - target_size) // 2
-     top = (height - target_size) // 2
-     return image.crop((left, top, left+target_size, top+target_size))
-
- def generate_ghibli_style(
-     input_image: Image.Image,
-     steps: int = 25,
-     strength: float = 0.6,
-     guidance_scale: float = 7.5
- ) -> Generator[Image.Image, None, None]:
-     """Memory-optimized generator with aggressive cleanup"""
-     prompt = "ghibli style, detailed anime portrait, studio ghibli, anime artwork"
-     negative_prompt = "blurry, low quality, sketch, cartoon, 3d, deformed, disfigured"
-
-     # Preprocess with garbage collection
-     input_image = resize_and_crop(input_image)
-     init_image = input_image.convert("RGB")
-     del input_image
-     gc.collect()
-
-     # Prepare latent variables with memory mapping
-     init_tensor = pipe.image_processor.preprocess(init_image).to(device=device, dtype=torch.float32)
-     init_latents = pipe.vae.encode(init_tensor).latent_dist.sample()
-     init_latents = pipe.vae.config.scaling_factor * init_latents
-     del init_tensor
-     gc.collect()
-
-     # Configure scheduler
-     pipe.scheduler.set_timesteps(steps, device=device)
-     timesteps = pipe.scheduler.timesteps[int(steps * strength):]
-     noise = torch.randn_like(init_latents, device=device)
-     latents = pipe.scheduler.add_noise(init_latents, noise, timesteps[:1])
-     del init_latents, noise
-     gc.collect()

-     # Memory-efficient text encoding
-     text_inputs = pipe.tokenizer(
-         prompt,
-         padding="max_length",
-         max_length=pipe.tokenizer.model_max_length,
-         return_tensors="pt"
-     )
-     text_embeddings = pipe.text_encoder(text_inputs.input_ids.to(device))[0].to(torch.float32)
-
-     uncond_input = pipe.tokenizer(
-         [negative_prompt],
-         padding="max_length",
-         max_length=text_embeddings.shape[1],
-         return_tensors="pt"
-     )
-     uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0].to(torch.float32)
-
-     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-     del uncond_embeddings, uncond_input, text_inputs
-     gc.collect()

-     # Diffusion process with memory cleanup
-     for i, t in enumerate(gr.Progress().tqdm(timesteps, desc="Generating")):
-         # Memory-optimized UNet inference
-         with torch.inference_mode():
-             latent_model_input = torch.cat([latents] * 2)
-             latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
-
-             noise_pred = pipe.unet(
-                 latent_model_input,
-                 t,
-                 encoder_hidden_states=text_embeddings,
-                 return_dict=False,
-             )[0]

-             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-             latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
-
-         # Memory-efficient decoding
          with torch.no_grad():
-             image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
-             image = pipe.image_processor.postprocess(image, output_type="pil")[0]
-
-         yield image
-
-         # Aggressive memory cleanup
-         del latent_model_input, noise_pred, noise_pred_uncond, noise_pred_text
-         gc.collect()
-
-     # Final cleanup
-     del latents, text_embeddings
-     gc.collect()
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# ✨ Studio Ghibli Style Transformer (CPU Optimized) ✨")
-     gr.Markdown("Upload a portrait photo to transform it into a Studio Ghibli-style artwork (max 10GB RAM usage)!")

-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(label="Input Image", type="pil")
-             steps_slider = gr.Slider(10, 40, value=25, step=5, label="Number of Steps")
-             strength_slider = gr.Slider(0.4, 0.8, value=0.6, step=0.1, label="Transformation Strength")
-             generate_btn = gr.Button("✨ Transform!", variant="primary")
-
-         with gr.Column():
-             gallery = gr.Gallery(
-                 label="Generation Progress",
-                 show_label=True,
-                 columns=4,
-                 preview=True,
-                 object_fit="contain",
-                 height=600
-             )
-
-     generate_btn.click(
-         fn=generate_ghibli_style,
-         inputs=[input_image, steps_slider, strength_slider],
-         outputs=gallery,
-         concurrency_limit=1
-     )

- if __name__ == "__main__":
-     demo.queue(concurrency_count=1)
-     demo.launch()
 
  import gradio as gr
  import torch
  from diffusers import StableDiffusionImg2ImgPipeline
  from PIL import Image
+ import numpy as np

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model_id = "nitrosocke/Ghibli-Diffusion"

+ # Load the model (keep safety_checker to avoid warning)
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
      model_id,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
  )
+ pipe.to(device)
+ pipe.enable_attention_slicing()

+ # Function to convert PIL image to latent-compatible numpy
+ def pil_to_np(image):
+     return np.array(image).astype(np.uint8)

+ # Generator with step-wise callback
+ def generate_ghibli_style(image, steps=25):
+     prompt = "ghibli style portrait"
+     intermediate_images = []

+     def callback(step: int, timestep: int, latents):
          with torch.no_grad():
+             img = pipe.decode_latents(latents)
+             img = pipe.numpy_to_pil(img)[0]
+             intermediate_images.append(img)
+
+     with torch.inference_mode():
+         pipe(
+             prompt=prompt,
+             image=image,
+             strength=0.6,
+             guidance_scale=6.0,
+             num_inference_steps=steps,
+             callback=callback,
+             callback_steps=1,
+         )

+     return intermediate_images
+
+ # Gradio Interface without deprecated style()
+ iface = gr.Interface(
+     fn=generate_ghibli_style,
+     inputs=[
+         gr.Image(type="pil", label="Upload a photo"),
+         gr.Slider(minimum=10, maximum=50, value=25, step=1, label="Inference Steps")
+     ],
+     outputs=gr.Gallery(label="Ghibli-style Generation Progress"),
+     title="✨ Studio Ghibli Portrait Generator ✨",
+     description="Upload a photo and watch it transform into a Ghibli-style portrait step by step!"
+ )

+ iface.launch()
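
As a quick local sanity check of the new step-wise callback, the updated generate_ghibli_style can be called directly, bypassing the Gradio UI. The sketch below is illustrative and not part of this commit: it assumes the lines are pasted into app.py just above iface.launch(), and "portrait.jpg" is a placeholder path.

    # Illustrative local check (not part of this commit); "portrait.jpg" is a placeholder path.
    sample = Image.open("portrait.jpg").convert("RGB").resize((512, 512))
    frames = generate_ghibli_style(sample, steps=15)  # one frame collected per callback invocation
    for i, frame in enumerate(frames):
        frame.save(f"step_{i:02d}.png")  # inspect the denoising progression frame by frame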