ar08 committed on
Commit
cec94f6
·
verified ·
1 Parent(s): 5205339

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -13
app.py CHANGED
@@ -1,27 +1,62 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionPipeline
 
 
4
 
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
  model_id = "nitrosocke/Ghibli-Diffusion"
7
 
8
- # Load the model once and keep it in memory
9
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
 
 
 
 
10
  pipe.to(device)
11
- pipe.enable_attention_slicing() # Optimize memory usage
 
 
 
 
12
 
13
- def generate_ghibli_style(image):
 
14
  prompt = "ghibli style portrait"
15
- with torch.inference_mode(): # Disables gradient calculations for faster inference
16
- result = pipe(prompt, image=image, strength=0.6, guidance_scale=6.5, num_inference_steps=25).images[0] # Reduced steps & optimized scale
17
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
19
  iface = gr.Interface(
20
  fn=generate_ghibli_style,
21
- inputs=gr.Image(type="pil"),
22
- outputs=gr.Image(),
23
- title="Studio Ghibli Portrait Generator",
24
- description="Upload a photo to generate a Ghibli-style portrait!"
 
 
 
25
  )
26
 
27
- iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionImg2ImgPipeline
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_id = "nitrosocke/Ghibli-Diffusion"
9
 
10
+ # Load the model
11
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
12
+ model_id,
13
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
+ safety_checker=None
15
+ )
16
  pipe.to(device)
17
+ pipe.enable_attention_slicing()
18
+
19
+ # Convert a PIL image to a uint8 NumPy array
20
+ def pil_to_np(image):
21
+ return np.array(image).astype(np.uint8)
22
 
23
+ # Generator with step-wise callback
24
+ def generate_ghibli_style(image, steps=25):
25
  prompt = "ghibli style portrait"
26
+ np_image = pil_to_np(image)
27
+ intermediate_images = []
28
+
29
+ def callback(step: int, timestep: int, latents):
30
+ # Decode latents to image and store for preview
31
+ with torch.no_grad():
32
+ img = pipe.decode_latents(latents)
33
+ img = pipe.numpy_to_pil(img)[0]
34
+ intermediate_images.append(img)
35
+
36
+ # Run the generation
37
+ with torch.inference_mode():
38
+ pipe(
39
+ prompt=prompt,
40
+ image=image,
41
+ strength=0.6,
42
+ guidance_scale=6.0,
43
+ num_inference_steps=steps,
44
+ callback=callback,
45
+ callback_steps=1, # Callback at every step
46
+ )
47
+
48
+ return intermediate_images
49
 
50
+ # Gradio Interface with image gallery preview
51
  iface = gr.Interface(
52
  fn=generate_ghibli_style,
53
+ inputs=[
54
+ gr.Image(type="pil", label="Upload a photo"),
55
+ gr.Slider(minimum=10, maximum=50, value=25, step=1, label="Inference Steps")
56
+ ],
57
+ outputs=gr.Gallery(label="Ghibli-style Generation Progress").style(grid=4),
58
+ title="✨ Studio Ghibli Portrait Generator ✨",
59
+ description="Upload a photo and watch it transform into a Ghibli-style portrait step by step!"
60
  )
61
 
62
+ iface.launch(share=True)