ar08 committed
Commit 0040d2b · verified · 1 Parent(s): 114b2a4

Update app.py

Files changed (1): app.py (+72 -56)
app.py CHANGED
@@ -4,30 +4,35 @@ import numpy as np
 from diffusers import StableDiffusionImg2ImgPipeline
 from PIL import Image
 from typing import Generator, List
+import gc
+import os
+
+# Configure CPU optimization
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["MKL_NUM_THREADS"] = "1"
+torch.set_num_threads(1)

 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_id = "nitrosocke/Ghibli-Diffusion"

-# Load the pipeline
+# Memory-optimized pipeline loading
 pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
     model_id,
-    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+    torch_dtype=torch.float32, # Keep float32 for CPU stability
 )
 pipe = pipe.to(device)
-pipe.enable_attention_slicing()
+pipe.enable_attention_slicing(slice_size=4)
+pipe.enable_sequential_cpu_offload() if device == "cuda" else None

 def resize_and_crop(image: Image.Image, target_size: int = 512) -> Image.Image:
-    """Resize and crop the image to the target size while maintaining aspect ratio."""
+    """Optimized image preprocessing with downsampling"""
+    width, height = image.size
+    scale = max(target_size/width, target_size/height)
+    image = image.resize((int(width*scale), int(height*scale)), Image.LANCZOS)
     width, height = image.size
-    if width > height:
-        left = (width - height) // 2
-        right = left + height
-        image = image.crop((left, 0, right, height))
-    elif height > width:
-        top = (height - width) // 2
-        bottom = top + width
-        image = image.crop((0, top, width, bottom))
-    return image.resize((target_size, target_size))
+    left = (width - target_size) // 2
+    top = (height - target_size) // 2
+    return image.crop((left, top, left+target_size, top+target_size))

 def generate_ghibli_style(
     input_image: Image.Image,
@@ -35,90 +40,103 @@ def generate_ghibli_style(
     strength: float = 0.6,
     guidance_scale: float = 7.5
 ) -> Generator[Image.Image, None, None]:
-    """Generator that yields intermediate images at each diffusion step."""
+    """Memory-optimized generator with aggressive cleanup"""
     prompt = "ghibli style, detailed anime portrait, studio ghibli, anime artwork"
     negative_prompt = "blurry, low quality, sketch, cartoon, 3d, deformed, disfigured"

-    # Preprocess image
+    # Preprocess with garbage collection
     input_image = resize_and_crop(input_image)
     init_image = input_image.convert("RGB")
+    del input_image
+    gc.collect()

-    # Prepare latent variables
-    init_image = pipe.image_processor.preprocess(init_image)
-    init_latents = pipe.vae.encode(init_image.to(device)).latent_dist.sample()
+    # Prepare latent variables with memory mapping
+    init_tensor = pipe.image_processor.preprocess(init_image).to(device=device, dtype=torch.float32)
+    init_latents = pipe.vae.encode(init_tensor).latent_dist.sample()
     init_latents = pipe.vae.config.scaling_factor * init_latents
+    del init_tensor
+    gc.collect()

-    # Prepare scheduler
+    # Configure scheduler
     pipe.scheduler.set_timesteps(steps, device=device)
     timesteps = pipe.scheduler.timesteps[int(steps * strength):]
-    noise = torch.randn_like(init_latents)
+    noise = torch.randn_like(init_latents, device=device)
     latents = pipe.scheduler.add_noise(init_latents, noise, timesteps[:1])
+    del init_latents, noise
+    gc.collect()

-    # Prepare text embeddings
+    # Memory-efficient text encoding
     text_inputs = pipe.tokenizer(
         prompt,
         padding="max_length",
         max_length=pipe.tokenizer.model_max_length,
         return_tensors="pt"
     )
-    text_embeddings = pipe.text_encoder(text_inputs.input_ids.to(device))[0]
-
-    # Unconditional embedding
+    text_embeddings = pipe.text_encoder(text_inputs.input_ids.to(device))[0].to(torch.float32)
+
     uncond_input = pipe.tokenizer(
-        [negative_prompt] * init_image.shape[0],
+        [negative_prompt],
         padding="max_length",
         max_length=text_embeddings.shape[1],
         return_tensors="pt"
     )
-    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
-
-    # Classifier-free guidance
+    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0].to(torch.float32)
+
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+    del uncond_embeddings, uncond_input, text_inputs
+    gc.collect()

-    # Diffusion process
+    # Diffusion process with memory cleanup
     for i, t in enumerate(gr.Progress().tqdm(timesteps, desc="Generating")):
-        # Expand latents for classifier-free guidance
-        latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
-
-        # Predict noise
-        noise_pred = pipe.unet(
-            latent_model_input,
-            t,
-            encoder_hidden_states=text_embeddings
-        ).sample
-
-        # Perform guidance
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-        # Compute previous step
-        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
-
-        # Decode and yield image
+        # Memory-optimized UNet inference
+        with torch.inference_mode():
+            latent_model_input = torch.cat([latents] * 2)
+            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+            noise_pred = pipe.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=text_embeddings,
+                return_dict=False,
+            )[0]
+
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
+
+        # Memory-efficient decoding
         with torch.no_grad():
             image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
             image = pipe.image_processor.postprocess(image, output_type="pil")[0]

         yield image
+
+        # Aggressive memory cleanup
+        del latent_model_input, noise_pred, noise_pred_uncond, noise_pred_text
+        gc.collect()
+
+    # Final cleanup
+    del latents, text_embeddings
+    gc.collect()

 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# ✨ Studio Ghibli Style Transformer ✨")
-    gr.Markdown("Upload a portrait photo to transform it into a Studio Ghibli-style artwork!")
+    gr.Markdown("# ✨ Studio Ghibli Style Transformer (CPU Optimized) ✨")
+    gr.Markdown("Upload a portrait photo to transform it into a Studio Ghibli-style artwork (max 10GB RAM usage)!")

     with gr.Row():
         with gr.Column():
             input_image = gr.Image(label="Input Image", type="pil")
-            steps_slider = gr.Slider(10, 50, value=25, label="Number of Steps")
-            strength_slider = gr.Slider(0.1, 0.9, value=0.6, label="Transformation Strength")
+            steps_slider = gr.Slider(10, 40, value=25, step=5, label="Number of Steps")
+            strength_slider = gr.Slider(0.4, 0.8, value=0.6, step=0.1, label="Transformation Strength")
             generate_btn = gr.Button("✨ Transform!", variant="primary")

         with gr.Column():
             gallery = gr.Gallery(
                 label="Generation Progress",
                 show_label=True,
-                columns=5,
+                columns=4,
                 preview=True,
                 object_fit="contain",
                 height=600
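The event wiring that feeds the Transform button into the gallery sits in the unchanged lines between this hunk and the next one, so the diff does not show it; only its trailing `concurrency_limit=1` argument is visible at the top of the final hunk. As a rough orientation only, a streaming handler for this layout typically looks like the sketch below; the wrapper name `transform` and the exact argument list are hypothetical, not code from this commit.

# Hypothetical sketch of the unchanged wiring (not part of this commit):
# accumulate each yielded frame so gr.Gallery shows the whole progression.
def transform(image, steps, strength):
    frames = []
    for frame in generate_ghibli_style(image, int(steps), strength):
        frames.append(frame)
        yield frames

generate_btn.click(
    fn=transform,
    inputs=[input_image, steps_slider, strength_slider],
    outputs=gallery,
    concurrency_limit=1,  # the argument visible at the start of the next hunk
)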
 
@@ -131,8 +149,6 @@ with gr.Blocks() as demo:
         concurrency_limit=1
     )

-
-
 if __name__ == "__main__":
-
+    demo.queue(concurrency_count=1)
     demo.launch()
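For comparison, the hand-rolled loop above exists mainly so that an intermediate preview can be yielded after every denoising step. The same final image (without previews) can be produced with the pipeline's standard img2img call. The snippet below is a minimal sketch reusing the `pipe` and `resize_and_crop` defined in app.py; the input and output file names are placeholders, not part of the repository.

from PIL import Image

# One-shot equivalent of the streaming loop (no per-step previews).
image = resize_and_crop(Image.open("portrait.jpg")).convert("RGB")  # placeholder input
result = pipe(
    prompt="ghibli style, detailed anime portrait, studio ghibli, anime artwork",
    negative_prompt="blurry, low quality, sketch, cartoon, 3d, deformed, disfigured",
    image=image,
    strength=0.6,
    guidance_scale=7.5,
    num_inference_steps=25,
).images[0]
result.save("ghibli_portrait.png")  # placeholder output

Either path picks up the single-thread and attention-slicing settings configured at the top of app.py, so CPU memory behaviour should be comparable.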