Anurag Bhardwaj committed on
Commit de1e61c · verified · 1 Parent(s): fbe554c

Update app.py

Files changed (1)
  1. app.py +195 -108
app.py CHANGED
@@ -1,129 +1,216 @@
- import os
  import sys
  import subprocess
  import importlib.util

- # Required packages including huggingface_hub for downloading files.
  required_packages = {
      "gradio": "gradio",
-     "diffusers": "diffusers",
      "torch": "torch",
      "PIL": "pillow",
-     "transformers": "transformers",
-     "safetensors": "safetensors",
-     "huggingface_hub": "huggingface_hub"
  }

  def install_package(package_name):
-     """Install package using pip."""
      subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])

- # Check and install missing packages.
- for module_name, pip_name in required_packages.items():
-     if importlib.util.find_spec(module_name) is None:
-         print(f"Package {module_name} not found. Installing {pip_name}...")
-         install_package(pip_name)

- # Now import the required packages.
  import gradio as gr
  import torch
  from PIL import Image
- from diffusers import StableDiffusionImg2ImgPipeline
- from safetensors.torch import load_file
- from huggingface_hub import hf_hub_download
-
- def monkeypatch_lora(unet, lora_path, alpha=1.0):
-     """
-     Merge LoRA weights into the UNet model.
-     This function loads a LoRA weights file (safetensors format) and applies the deltas
-     to the corresponding base weights of the UNet.
-     """
-     print(f"Loading LoRA weights from: {lora_path}")
-     lora_state = load_file(lora_path)
-     unet_state = unet.state_dict()
-
-     for key, delta in lora_state.items():
-         if "lora_up" in key or "lora_down" in key:
-             base_key = key.replace("lora_up", "weight").replace("lora_down", "weight")
-             if base_key in unet_state:
-                 unet_state[base_key] = unet_state[base_key] + delta.to(unet_state[base_key].device) * alpha
-                 print(f"Applied LoRA delta for {base_key}")
-             else:
-                 print(f"Warning: Base weight {base_key} not found in UNet state dict.")
-         else:
-             print(f"Skipping key {key} as it does not appear to be a LoRA weight.")
-     unet.load_state_dict(unet_state)
-     print("LoRA merging completed.")
-
- def load_model():
-     """
-     Load the base Stable Diffusion model and apply the FLUX.1-dev LoRA weights.
-     If the LoRA weights file is not found locally, it will be downloaded from the Hugging Face Hub.
-     """
-     base_model_id = "runwayml/stable-diffusion-v1-5"
-     hf_token = os.environ.get("HF_TOKEN")
-     if hf_token is None:
-         raise ValueError("HF_TOKEN environment variable is not set. Please set it to access gated repositories.")
-
-     # Load the base model.
-     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-         base_model_id,
-         torch_dtype=torch.float16,
-         use_auth_token=hf_token
-     )
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     pipe = pipe.to(device)
-
-     # Define the expected local path for the LoRA weights.
-     lora_weights_path = "./flux_ghibsky_lora.safetensors"

-     # If the file does not exist locally, attempt to download it.
-     if not os.path.exists(lora_weights_path):
-         print(f"LoRA weights file not found at {lora_weights_path}. Attempting to download from Hugging Face Hub...")
-         # Download the file from the gated repository.
-         lora_weights_path = hf_hub_download(
-             repo_id="black-forest-labs/FLUX.1-dev",
-             filename="flux_ghibsky_lora.safetensors",
-             use_auth_token=hf_token
          )
-         print(f"Downloaded LoRA weights to {lora_weights_path}.")
-
-     # Apply the LoRA weights to the UNet.
-     monkeypatch_lora(pipe.unet, lora_weights_path, alpha=1.0)
-     print("Base model loaded and FLUX.1-dev LoRA weights merged.")
-     return pipe
-
- # Load the model once at startup.
- pipe = load_model()
-
- def transform_image(image: Image.Image, strength: float, steps: int) -> Image.Image:
-     """
-     Transforms the uploaded image into Ghibli-inspired art.
-     The prompt is prefixed with "GHIBSKY style".
-     """
-     prompt = (
-         "GHIBSKY style, a portrait transformed into dreamy, Ghibli-inspired art, "
-         "featuring serene skies, surreal details, and intricate brush strokes"
-     )
-     result = pipe(prompt=prompt, image=image, strength=strength, num_inference_steps=steps)
-     return result.images[0]
-
- # Create a Gradio interface.
- demo = gr.Interface(
-     fn=transform_image,
-     inputs=[
-         gr.Image(type="pil", label="Upload your portrait image"),
-         gr.Slider(0.1, 0.9, value=0.6, label="Transformation Strength"),
-         gr.Slider(20, 100, step=5, value=50, label="Inference Steps")
-     ],
-     outputs=gr.Image(type="pil", label="Ghibli-Inspired Art"),
-     title="GHIBSKY Art Transformer",
-     description=(
-         "Upload your portrait image and see it transformed into enchanting, Ghibli-inspired art. "
-         "This demo uses a base Stable Diffusion model with FLUX.1-dev LoRA weights merged into it "
-         "to achieve the unique GHIBSKY style. Ensure your HF_TOKEN is set to access gated repositories."
      )
- )

- if __name__ == "__main__":
-     demo.launch()

  import sys
  import subprocess
  import importlib.util
+ import os

+ # List of required packages.
  required_packages = {
      "gradio": "gradio",
+     "numpy": "numpy",
      "torch": "torch",
+     "diffusers": "diffusers",
      "PIL": "pillow",
+     "spaces": "spaces"  # If this is a custom package in your environment.
  }

  def install_package(package_name):
      subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])

+ # Auto-install any missing packages.
+ for mod, pkg in required_packages.items():
+     if importlib.util.find_spec(mod) is None:
+         print(f"Module {mod} not found, installing {pkg}...")
+         install_package(pkg)

+ import random
  import gradio as gr
+ import numpy as np
+ import spaces
  import torch
+ from diffusers import DiffusionPipeline
  from PIL import Image
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Model identifiers.
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ adapter_id = "alvarobartt/ghibli-characters-flux-lora"
+
+ # Retrieve HF token from environment (if required to access gated repositories).
+ hf_token = os.environ.get("HF_TOKEN", None)
+
+ # Load the base model from the repository.
+ pipeline = DiffusionPipeline.from_pretrained(
+     repo_id,
+     torch_dtype=torch.bfloat16,
+     use_auth_token=hf_token  # Only needed if the repo is gated.
+ )
+ pipeline.load_lora_weights(adapter_id)
+ pipeline = pipeline.to(device)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ @spaces.GPU(duration=80)
+ def inference(
+     prompt: str,
+     seed: int,
+     randomize_seed: bool,
+     width: int,
+     height: int,
+     guidance_scale: float,
+     num_inference_steps: int,
+     lora_scale: float,
+     progress: gr.Progress = gr.Progress(track_tqdm=True),
+ ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator(device=device).manual_seed(seed)

+     image = pipeline(
+         prompt=prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+         joint_attention_kwargs={"scale": lora_scale},
+     ).images[0]
+
+     return image, seed
+
+ examples = [
+     (
+         "Ghibli style futuristic stormtrooper with glossy white armor and a sleek helmet,"
+         " standing heroically on a lush alien planet, vibrant flowers blooming around, soft"
+         " sunlight illuminating the scene, a gentle breeze rustling the leaves"
+     ),
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 640px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# FLUX.1 Studio Ghibli LoRA")
+         gr.Markdown(
+             "[alvarobartt/ghibli-characters-flux-lora](https://huggingface.co/alvarobartt/ghibli-characters-flux-lora)"
+             " is a LoRA fine-tune of [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)"
+             " with [alvarobartt/ghibli-characters](https://huggingface.co/datasets/alvarobartt/ghibli-characters)."
          )
+
+         with gr.Accordion("How to generate nice prompts?", open=False):
+             gr.Markdown(
+                 "What worked best for me to generate high-quality prompts of well-known characters,"
+                 " was to prompt either [Claude 3 Haiku](https://claude.ai), [GPT4-o](https://chatgpt.com/),"
+                 " or [Perplexity](https://www.perplexity.ai/) with:\n\nYou are an"
+                 " expert prompt writer for diffusion text to image models, and you've been provided"
+                 " the following prompt template:\n\n\"Ghibli style [character description] with"
+                 " [distinctive features], [action or pose], [environment or background],"
+                 " [lighting or atmosphere], [additional details].\"\n\nCould you create a prompt"
+                 " to generate [CHARACTER NAME] as a Studio Ghibli character following that template? "
+                 "[MORE DETAILS IF NEEDED]\n"
+             )
+
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0)
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=42,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=768,
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=3.5,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=30,
+                 )
+
+                 lora_scale = gr.Slider(
+                     label="LoRA scale",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.1,
+                     value=1.0,
+                 )
+
+         gr.Examples(
+             examples=examples,
+             fn=lambda x: (Image.open("./example.jpg"), 42),
+             inputs=[prompt],
+             outputs=[result, seed],
+             run_on_click=True,
+         )
+
+         gr.Markdown(
+             "### Disclaimer\n\n"
+             "License is non-commercial for both FLUX.1-dev and the Studio Ghibli dataset; "
+             "but free to use for personal and non-commercial purposes."
+         )
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=inference,
+         inputs=[
+             prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+             lora_scale,
+         ],
+         outputs=[result, seed],
      )

+ demo.queue()
+ demo.launch()
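
For reference, the generation path introduced by this commit can also be exercised outside the Gradio UI and the @spaces.GPU decorator. The following is a minimal sketch, not part of the commit: it assumes a CUDA device with enough memory for FLUX.1-dev in bfloat16 and an HF_TOKEN with access to the gated repository. The model and adapter IDs, the sampler defaults, and the LoRA scale passed via joint_attention_kwargs are taken from the new app.py; the standalone main() wrapper, the shortened prompt, and the output filename ghibli_sample.png are illustrative additions.

import os

import torch
from diffusers import DiffusionPipeline

# Same identifiers as in the updated app.py.
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "alvarobartt/ghibli-characters-flux-lora"


def main() -> None:
    # HF_TOKEN is assumed to be set for the gated FLUX.1-dev repository.
    hf_token = os.environ.get("HF_TOKEN")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the base model and attach the Ghibli LoRA, as the app does at startup.
    pipeline = DiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        use_auth_token=hf_token,  # Mirrors app.py; newer diffusers releases prefer token=.
    )
    pipeline.load_lora_weights(adapter_id)
    pipeline = pipeline.to(device)

    # Fixed seed instead of the app's optional seed randomization.
    generator = torch.Generator(device=device).manual_seed(42)

    # Same call as the app's inference() function; the LoRA strength is passed
    # through joint_attention_kwargs["scale"], and the defaults match the UI sliders.
    image = pipeline(
        prompt=(
            "Ghibli style futuristic stormtrooper with glossy white armor and a sleek "
            "helmet, standing heroically on a lush alien planet"
        ),
        guidance_scale=3.5,
        num_inference_steps=30,
        width=1024,
        height=768,
        generator=generator,
        joint_attention_kwargs={"scale": 1.0},
    ).images[0]

    image.save("ghibli_sample.png")  # Illustrative output path, not used by the app.


if __name__ == "__main__":
    main()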