Anurag Bhardwaj committed on
Commit
fbe554c
·
verified ·
1 Parent(s): 8b62135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -35
app.py CHANGED
@@ -3,14 +3,15 @@ import sys
3
  import subprocess
4
  import importlib.util
5
 
6
- # Add 'safetensors' to our required packages.
7
  required_packages = {
8
  "gradio": "gradio",
9
  "diffusers": "diffusers",
10
  "torch": "torch",
11
  "PIL": "pillow",
12
  "transformers": "transformers",
13
- "safetensors": "safetensors"
 
14
  }
15
 
16
  def install_package(package_name):
@@ -29,60 +30,42 @@ import torch
29
  from PIL import Image
30
  from diffusers import StableDiffusionImg2ImgPipeline
31
  from safetensors.torch import load_file
 
32
 
33
  def monkeypatch_lora(unet, lora_path, alpha=1.0):
34
  """
35
- A simplistic implementation to merge LoRA weights into the UNet.
36
- This function loads a LoRA weights file (in safetensors format) and applies the deltas
37
- to the corresponding weights in the UNet. The logic here assumes that the LoRA keys include
38
- either "lora_up" or "lora_down" and that the corresponding base weight can be obtained by replacing
39
- these substrings with "weight".
40
-
41
- Parameters:
42
- - unet: The UNet model of the diffusion pipeline.
43
- - lora_path: Path (or identifier) to the LoRA weights file.
44
- - alpha: A scaling factor for the LoRA weights.
45
  """
46
  print(f"Loading LoRA weights from: {lora_path}")
47
- # Load the LoRA weights (assumed to be in safetensors format).
48
  lora_state = load_file(lora_path)
49
-
50
- # Get the current state dict of the UNet.
51
  unet_state = unet.state_dict()
52
-
53
- # Iterate over the LoRA weights and merge them.
54
  for key, delta in lora_state.items():
55
- # Example mapping: if key contains "lora_up" or "lora_down", map it to a base weight key.
56
  if "lora_up" in key or "lora_down" in key:
57
- # Derive the corresponding base key.
58
  base_key = key.replace("lora_up", "weight").replace("lora_down", "weight")
59
  if base_key in unet_state:
60
- # Merge the LoRA delta scaled by alpha into the base weight.
61
  unet_state[base_key] = unet_state[base_key] + delta.to(unet_state[base_key].device) * alpha
62
  print(f"Applied LoRA delta for {base_key}")
63
  else:
64
  print(f"Warning: Base weight {base_key} not found in UNet state dict.")
65
  else:
66
  print(f"Skipping key {key} as it does not appear to be a LoRA weight.")
67
-
68
- # Load the updated state dict into the UNet.
69
  unet.load_state_dict(unet_state)
70
  print("LoRA merging completed.")
71
 
72
  def load_model():
73
  """
74
  Load the base Stable Diffusion model and apply the FLUX.1-dev LoRA weights.
75
- The base model is runwayml/stable-diffusion-v1-5 which contains all required components.
76
- The FLUX.1-dev weights (LoRA) are then merged into the UNet.
77
  """
78
- # Base model identifier.
79
  base_model_id = "runwayml/stable-diffusion-v1-5"
80
  hf_token = os.environ.get("HF_TOKEN")
81
  if hf_token is None:
82
- raise ValueError("HF_TOKEN environment variable is not set. "
83
- "Please set your Hugging Face token to access the gated repository.")
84
-
85
- # Load the base model with authentication.
86
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
87
  base_model_id,
88
  torch_dtype=torch.float16,
@@ -91,16 +74,22 @@ def load_model():
91
  device = "cuda" if torch.cuda.is_available() else "cpu"
92
  pipe = pipe.to(device)
93
 
94
- # Path to the FLUX.1-dev LoRA weights.
95
- # You can either use a local file path (e.g., "./flux_ghibsky_lora.safetensors") or
96
- # download from the gated repository if permitted. Here, we assume a local file.
97
  lora_weights_path = "./flux_ghibsky_lora.safetensors"
 
 
98
  if not os.path.exists(lora_weights_path):
99
- raise FileNotFoundError(f"LoRA weights file not found at {lora_weights_path}.")
 
 
 
 
 
 
 
100
 
101
  # Apply the LoRA weights to the UNet.
102
  monkeypatch_lora(pipe.unet, lora_weights_path, alpha=1.0)
103
-
104
  print("Base model loaded and FLUX.1-dev LoRA weights merged.")
105
  return pipe
106
 
@@ -110,7 +99,7 @@ pipe = load_model()
110
  def transform_image(image: Image.Image, strength: float, steps: int) -> Image.Image:
111
  """
112
  Transforms the uploaded image into Ghibli-inspired art.
113
- The prompt is automatically prefixed with "GHIBSKY style".
114
  """
115
  prompt = (
116
  "GHIBSKY style, a portrait transformed into dreamy, Ghibli-inspired art, "
 
3
  import subprocess
4
  import importlib.util
5
 
6
+ # Required packages including huggingface_hub for downloading files.
7
  required_packages = {
8
  "gradio": "gradio",
9
  "diffusers": "diffusers",
10
  "torch": "torch",
11
  "PIL": "pillow",
12
  "transformers": "transformers",
13
+ "safetensors": "safetensors",
14
+ "huggingface_hub": "huggingface_hub"
15
  }
16
 
17
  def install_package(package_name):
 
30
  from PIL import Image
31
  from diffusers import StableDiffusionImg2ImgPipeline
32
  from safetensors.torch import load_file
33
+ from huggingface_hub import hf_hub_download
34
 
35
  def monkeypatch_lora(unet, lora_path, alpha=1.0):
36
  """
37
+ Merge LoRA weights into the UNet model.
38
+ This function loads a LoRA weights file (safetensors format) and applies the deltas
39
+ to the corresponding base weights of the UNet.
 
 
 
 
 
 
 
40
  """
41
  print(f"Loading LoRA weights from: {lora_path}")
 
42
  lora_state = load_file(lora_path)
 
 
43
  unet_state = unet.state_dict()
44
+
 
45
  for key, delta in lora_state.items():
 
46
  if "lora_up" in key or "lora_down" in key:
 
47
  base_key = key.replace("lora_up", "weight").replace("lora_down", "weight")
48
  if base_key in unet_state:
 
49
  unet_state[base_key] = unet_state[base_key] + delta.to(unet_state[base_key].device) * alpha
50
  print(f"Applied LoRA delta for {base_key}")
51
  else:
52
  print(f"Warning: Base weight {base_key} not found in UNet state dict.")
53
  else:
54
  print(f"Skipping key {key} as it does not appear to be a LoRA weight.")
 
 
55
  unet.load_state_dict(unet_state)
56
  print("LoRA merging completed.")
57
 
58
  def load_model():
59
  """
60
  Load the base Stable Diffusion model and apply the FLUX.1-dev LoRA weights.
61
+ If the LoRA weights file is not found locally, it will be downloaded from the Hugging Face Hub.
 
62
  """
 
63
  base_model_id = "runwayml/stable-diffusion-v1-5"
64
  hf_token = os.environ.get("HF_TOKEN")
65
  if hf_token is None:
66
+ raise ValueError("HF_TOKEN environment variable is not set. Please set it to access gated repositories.")
67
+
68
+ # Load the base model.
 
69
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
70
  base_model_id,
71
  torch_dtype=torch.float16,
 
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
75
  pipe = pipe.to(device)
76
 
77
+ # Define the expected local path for the LoRA weights.
 
 
78
  lora_weights_path = "./flux_ghibsky_lora.safetensors"
79
+
80
+ # If the file does not exist locally, attempt to download it.
81
  if not os.path.exists(lora_weights_path):
82
+ print(f"LoRA weights file not found at {lora_weights_path}. Attempting to download from Hugging Face Hub...")
83
+ # Download the file from the gated repository.
84
+ lora_weights_path = hf_hub_download(
85
+ repo_id="black-forest-labs/FLUX.1-dev",
86
+ filename="flux_ghibsky_lora.safetensors",
87
+ use_auth_token=hf_token
88
+ )
89
+ print(f"Downloaded LoRA weights to {lora_weights_path}.")
90
 
91
  # Apply the LoRA weights to the UNet.
92
  monkeypatch_lora(pipe.unet, lora_weights_path, alpha=1.0)
 
93
  print("Base model loaded and FLUX.1-dev LoRA weights merged.")
94
  return pipe
95
 
 
99
  def transform_image(image: Image.Image, strength: float, steps: int) -> Image.Image:
100
  """
101
  Transforms the uploaded image into Ghibli-inspired art.
102
+ The prompt is prefixed with "GHIBSKY style".
103
  """
104
  prompt = (
105
  "GHIBSKY style, a portrait transformed into dreamy, Ghibli-inspired art, "