Spaces:
Running
Running
Anurag Bhardwaj
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -3,14 +3,15 @@ import sys
|
|
3 |
import subprocess
|
4 |
import importlib.util
|
5 |
|
6 |
-
#
|
7 |
required_packages = {
|
8 |
"gradio": "gradio",
|
9 |
"diffusers": "diffusers",
|
10 |
"torch": "torch",
|
11 |
"PIL": "pillow",
|
12 |
"transformers": "transformers",
|
13 |
-
"safetensors": "safetensors"
|
|
|
14 |
}
|
15 |
|
16 |
def install_package(package_name):
|
@@ -29,60 +30,42 @@ import torch
|
|
29 |
from PIL import Image
|
30 |
from diffusers import StableDiffusionImg2ImgPipeline
|
31 |
from safetensors.torch import load_file
|
|
|
32 |
|
33 |
def monkeypatch_lora(unet, lora_path, alpha=1.0):
|
34 |
"""
|
35 |
-
|
36 |
-
This function loads a LoRA weights file (
|
37 |
-
to the corresponding weights
|
38 |
-
either "lora_up" or "lora_down" and that the corresponding base weight can be obtained by replacing
|
39 |
-
these substrings with "weight".
|
40 |
-
|
41 |
-
Parameters:
|
42 |
-
- unet: The UNet model of the diffusion pipeline.
|
43 |
-
- lora_path: Path (or identifier) to the LoRA weights file.
|
44 |
-
- alpha: A scaling factor for the LoRA weights.
|
45 |
"""
|
46 |
print(f"Loading LoRA weights from: {lora_path}")
|
47 |
-
# Load the LoRA weights (assumed to be in safetensors format).
|
48 |
lora_state = load_file(lora_path)
|
49 |
-
|
50 |
-
# Get the current state dict of the UNet.
|
51 |
unet_state = unet.state_dict()
|
52 |
-
|
53 |
-
# Iterate over the LoRA weights and merge them.
|
54 |
for key, delta in lora_state.items():
|
55 |
-
# Example mapping: if key contains "lora_up" or "lora_down", map it to a base weight key.
|
56 |
if "lora_up" in key or "lora_down" in key:
|
57 |
-
# Derive the corresponding base key.
|
58 |
base_key = key.replace("lora_up", "weight").replace("lora_down", "weight")
|
59 |
if base_key in unet_state:
|
60 |
-
# Merge the LoRA delta scaled by alpha into the base weight.
|
61 |
unet_state[base_key] = unet_state[base_key] + delta.to(unet_state[base_key].device) * alpha
|
62 |
print(f"Applied LoRA delta for {base_key}")
|
63 |
else:
|
64 |
print(f"Warning: Base weight {base_key} not found in UNet state dict.")
|
65 |
else:
|
66 |
print(f"Skipping key {key} as it does not appear to be a LoRA weight.")
|
67 |
-
|
68 |
-
# Load the updated state dict into the UNet.
|
69 |
unet.load_state_dict(unet_state)
|
70 |
print("LoRA merging completed.")
|
71 |
|
72 |
def load_model():
|
73 |
"""
|
74 |
Load the base Stable Diffusion model and apply the FLUX.1-dev LoRA weights.
|
75 |
-
|
76 |
-
The FLUX.1-dev weights (LoRA) are then merged into the UNet.
|
77 |
"""
|
78 |
-
# Base model identifier.
|
79 |
base_model_id = "runwayml/stable-diffusion-v1-5"
|
80 |
hf_token = os.environ.get("HF_TOKEN")
|
81 |
if hf_token is None:
|
82 |
-
raise ValueError("HF_TOKEN environment variable is not set. "
|
83 |
-
|
84 |
-
|
85 |
-
# Load the base model with authentication.
|
86 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
87 |
base_model_id,
|
88 |
torch_dtype=torch.float16,
|
@@ -91,16 +74,22 @@ def load_model():
|
|
91 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
92 |
pipe = pipe.to(device)
|
93 |
|
94 |
-
#
|
95 |
-
# You can either use a local file path (e.g., "./flux_ghibsky_lora.safetensors") or
|
96 |
-
# download from the gated repository if permitted. Here, we assume a local file.
|
97 |
lora_weights_path = "./flux_ghibsky_lora.safetensors"
|
|
|
|
|
98 |
if not os.path.exists(lora_weights_path):
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# Apply the LoRA weights to the UNet.
|
102 |
monkeypatch_lora(pipe.unet, lora_weights_path, alpha=1.0)
|
103 |
-
|
104 |
print("Base model loaded and FLUX.1-dev LoRA weights merged.")
|
105 |
return pipe
|
106 |
|
@@ -110,7 +99,7 @@ pipe = load_model()
|
|
110 |
def transform_image(image: Image.Image, strength: float, steps: int) -> Image.Image:
|
111 |
"""
|
112 |
Transforms the uploaded image into Ghibli-inspired art.
|
113 |
-
The prompt is
|
114 |
"""
|
115 |
prompt = (
|
116 |
"GHIBSKY style, a portrait transformed into dreamy, Ghibli-inspired art, "
|
|
|
3 |
import subprocess
|
4 |
import importlib.util
|
5 |
|
6 |
+
# Required packages including huggingface_hub for downloading files.
|
7 |
required_packages = {
|
8 |
"gradio": "gradio",
|
9 |
"diffusers": "diffusers",
|
10 |
"torch": "torch",
|
11 |
"PIL": "pillow",
|
12 |
"transformers": "transformers",
|
13 |
+
"safetensors": "safetensors",
|
14 |
+
"huggingface_hub": "huggingface_hub"
|
15 |
}
|
16 |
|
17 |
def install_package(package_name):
|
|
|
30 |
from PIL import Image
|
31 |
from diffusers import StableDiffusionImg2ImgPipeline
|
32 |
from safetensors.torch import load_file
|
33 |
+
from huggingface_hub import hf_hub_download
|
34 |
|
35 |
def monkeypatch_lora(unet, lora_path, alpha=1.0):
|
36 |
"""
|
37 |
+
Merge LoRA weights into the UNet model.
|
38 |
+
This function loads a LoRA weights file (safetensors format) and applies the deltas
|
39 |
+
to the corresponding base weights of the UNet.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
"""
|
41 |
print(f"Loading LoRA weights from: {lora_path}")
|
|
|
42 |
lora_state = load_file(lora_path)
|
|
|
|
|
43 |
unet_state = unet.state_dict()
|
44 |
+
|
|
|
45 |
for key, delta in lora_state.items():
|
|
|
46 |
if "lora_up" in key or "lora_down" in key:
|
|
|
47 |
base_key = key.replace("lora_up", "weight").replace("lora_down", "weight")
|
48 |
if base_key in unet_state:
|
|
|
49 |
unet_state[base_key] = unet_state[base_key] + delta.to(unet_state[base_key].device) * alpha
|
50 |
print(f"Applied LoRA delta for {base_key}")
|
51 |
else:
|
52 |
print(f"Warning: Base weight {base_key} not found in UNet state dict.")
|
53 |
else:
|
54 |
print(f"Skipping key {key} as it does not appear to be a LoRA weight.")
|
|
|
|
|
55 |
unet.load_state_dict(unet_state)
|
56 |
print("LoRA merging completed.")
|
57 |
|
58 |
def load_model():
|
59 |
"""
|
60 |
Load the base Stable Diffusion model and apply the FLUX.1-dev LoRA weights.
|
61 |
+
If the LoRA weights file is not found locally, it will be downloaded from the Hugging Face Hub.
|
|
|
62 |
"""
|
|
|
63 |
base_model_id = "runwayml/stable-diffusion-v1-5"
|
64 |
hf_token = os.environ.get("HF_TOKEN")
|
65 |
if hf_token is None:
|
66 |
+
raise ValueError("HF_TOKEN environment variable is not set. Please set it to access gated repositories.")
|
67 |
+
|
68 |
+
# Load the base model.
|
|
|
69 |
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
70 |
base_model_id,
|
71 |
torch_dtype=torch.float16,
|
|
|
74 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
75 |
pipe = pipe.to(device)
|
76 |
|
77 |
+
# Define the expected local path for the LoRA weights.
|
|
|
|
|
78 |
lora_weights_path = "./flux_ghibsky_lora.safetensors"
|
79 |
+
|
80 |
+
# If the file does not exist locally, attempt to download it.
|
81 |
if not os.path.exists(lora_weights_path):
|
82 |
+
print(f"LoRA weights file not found at {lora_weights_path}. Attempting to download from Hugging Face Hub...")
|
83 |
+
# Download the file from the gated repository.
|
84 |
+
lora_weights_path = hf_hub_download(
|
85 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
86 |
+
filename="flux_ghibsky_lora.safetensors",
|
87 |
+
use_auth_token=hf_token
|
88 |
+
)
|
89 |
+
print(f"Downloaded LoRA weights to {lora_weights_path}.")
|
90 |
|
91 |
# Apply the LoRA weights to the UNet.
|
92 |
monkeypatch_lora(pipe.unet, lora_weights_path, alpha=1.0)
|
|
|
93 |
print("Base model loaded and FLUX.1-dev LoRA weights merged.")
|
94 |
return pipe
|
95 |
|
|
|
99 |
def transform_image(image: Image.Image, strength: float, steps: int) -> Image.Image:
|
100 |
"""
|
101 |
Transforms the uploaded image into Ghibli-inspired art.
|
102 |
+
The prompt is prefixed with "GHIBSKY style".
|
103 |
"""
|
104 |
prompt = (
|
105 |
"GHIBSKY style, a portrait transformed into dreamy, Ghibli-inspired art, "
|