import spaces
import os
import gradio as gr
import torch
import safetensors
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import numpy as np
import subprocess
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
# Clear any stale files from the ZeroGPU offload directory left over from previous runs
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

HF_TOKEN = os.getenv("HF_TOKEN")

# Ensure that the minimal required version of diffusers is installed
check_min_version("0.30.2")
quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)
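# With no bnb_4bit_quant_type set, bitsandbytes falls back to its fp4 default;
# double quantization additionally compresses the quantization constants, so
# the 4-bit T5-XXL encoder ends up roughly 4x smaller than its bf16 weights.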
# 4-bit T5 text encoder. Note: this is only attached to the pipeline if the
# commented-out assignment below (pipe.text_encoder_2 = ...) is re-enabled.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
# quant_config = DiffusersBitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
# )

transformerx = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

# text_encoder_8bit = T5EncoderModel.from_pretrained(
#     "black-forest-labs/FLUX.1-dev",
#     subfolder="text_encoder_2",
#     quantization_config=quant_config,
#     torch_dtype=torch.bfloat16,
#     use_safetensors=True,
#     token=HF_TOKEN,
# )
# Build pipeline
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    # subfolder="controlnet",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    # text_encoder_2=text_encoder_8bit,
    transformer=transformerx,
    torch_dtype=torch.bfloat16,
    # device_map="balanced",
    token=HF_TOKEN,
)
# pipe.text_encoder_2 = text_encoder_2_4bit
# pipe.transformer = transformer_4bit
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
pipe.to("cuda")

pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["turbo"], adapter_weights=[0.95])
pipe.fuse_lora(lora_scale=1)
pipe.unload_lora_weights()
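# Fusing the Turbo LoRA bakes its (0.95-scaled) deltas directly into the base
# weights, so unload_lora_weights() can then drop the adapter modules without
# losing their effect; inference afterwards pays no per-step LoRA overhead.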
# We can utilize the enable_group_offload method for Diffusers model implementations
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
# pipe.push_to_hub("FLUX.1-Inpainting-8step_uncensored", private=True, token=HF_TOKEN)
# pipe.enable_vae_tiling()
# pipe.enable_model_cpu_offload()

# hf_device_map is only set when the pipeline is loaded with a device_map,
# so guard the lookup to avoid an AttributeError here.
print(getattr(pipe, "hf_device_map", None))
def create_mask_from_editor(editor_value):
    """
    Create a binary mask from the ImageEditor value.

    Args:
        editor_value: EditorValue dict with 'background', 'layers', and 'composite'
    Returns:
        PIL Image where painted (pure white) pixels are 255 and the rest are 0
    """
    # The 'composite' key contains the final image with all layers applied
    composite_image = editor_value['composite']
    # Convert to RGB first so the comparison against a 3-tuple also works
    # for RGBA composites
    composite_array = np.array(composite_image.convert('RGB'))
    # Mark pixels where the composite is pure white
    mask_array = np.all(composite_array == (255, 255, 255), axis=-1).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask_array)
    return mask_image
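# Hypothetical usage (this helper is not wired into the UI below, since the
# diptych mask is generated programmatically instead):
#   mask = create_mask_from_editor(editor_value)  # editor_value: gr.ImageEditor value dict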
def create_mask_on_image(image, xyxy):
    """
    Create a white rectangular mask on the image given xyxy coordinates.

    Args:
        image: PIL Image
        xyxy: List of [x1, y1, x2, y2] coordinates
    Returns:
        Tuple of (mask, masked image) as PIL Images: the binary mask itself,
        and the input image with the masked region filled in white
    """
    # Convert to numpy array
    img_array = np.array(image)
    # Create mask
    mask = Image.new('RGB', image.size, (0, 0, 0))
    draw = ImageDraw.Draw(mask)
    # Draw white rectangle
    draw.rectangle(xyxy, fill=(255, 255, 255))
    # Convert mask to array
    mask_array = np.array(mask)
    # Paint the masked region white on the original image
    masked_array = np.where(mask_array == 255, 255, img_array)
    return Image.fromarray(mask_array), Image.fromarray(masked_array)
def create_diptych_image(image):
    """Create a diptych: the original image on the left, black on the right."""
    width, height = image.size
    diptych = Image.new('RGB', (width * 2, height), 'black')
    diptych.paste(image, (0, 0))
    return diptych
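# e.g. the 768x768 input used below yields a 1536x768 diptych whose right
# half is the region to be inpainted.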
# On ZeroGPU Spaces the GPU is only attached inside functions decorated with
# @spaces.GPU; the `import spaces` and zerogpu-offload cleanup above suggest
# this Space targets ZeroGPU, so the decorator is required here.
@spaces.GPU
def inpaint_image(image, prompt, subject, editor_value):
    # Load the image and build the diptych control image
    size = (1536, 768)
    image = load_image(image).convert("RGB").resize((768, 768))
    diptych_image = create_diptych_image(image)
    # mask = load_image(mask_path).convert("RGB").resize(size)
    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
    generator = torch.Generator(device="cuda").manual_seed(24)
    # Calculate the attention scale mask
    attn_scale_factor = 1.5
    # Latent token grid: FLUX uses 16x16 pixels per latent token
    H, W = size[1] // 16, size[0] // 16
    # Start from zeros over the full diptych and mark the right (generated) half
    attn_scale_mask = torch.zeros(size[1], size[0])
    attn_scale_mask[:, 768:] = 1.0  # height, width
    # Downsample the pixel-space mask to the latent token grid and flatten it
    attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
    attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)
    # Invert (1.0 -> left half) and transpose to select along the key dimension
    transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
    # Cross-attention region: right-half queries attending to left-half keys
    cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
    cross_attn_region = cross_attn_region * attn_scale_factor
    cross_attn_region[cross_attn_region < 1.0] = 1.0
    full_attn_scale_mask = torch.ones(1, 24, 512 + H*W, 512 + H*W)
    full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
    # Convert to bfloat16 to match the model dtype
    full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
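    # Resulting layout: the first 512 rows/columns correspond to the T5 text
    # tokens (left untouched at 1.0); the remaining H*W entries are image
    # tokens, where attention from right-half queries to left-half keys is
    # boosted by attn_scale_factor so generation tracks the reference image.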
    subject_name = subject
    target_text_prompt = prompt
    prompt_final = f'A two side-by-side image of {subject_name}. LEFT: a photo of {subject_name}; RIGHT: a photo of {subject_name} {target_text_prompt}.'
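    # e.g. subject="a cat", prompt="wearing a christmas hat" gives:
    # "A two side-by-side image of a cat. LEFT: a photo of a cat;
    #  RIGHT: a photo of a cat wearing a christmas hat."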
    # Visualize the attention scale mask as a binary PIL image.
    # Take the first head's full mask, shape (512 + H*W, 512 + H*W).
    # Clone first: full_attn_scale_mask[0, 0] is a view, and the in-place
    # thresholding below would otherwise corrupt the mask passed to the pipe.
    attn_vis = full_attn_scale_mask[0, 0].clone()
    attn_vis[attn_vis <= 1.0] = 0
    attn_vis[attn_vis > 1.0] = 255
    attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)
    # Convert to PIL Image
    attn_vis_img = Image.fromarray(attn_vis)
    attn_vis_img.save('attention_mask_vis.png')
    with torch.inference_mode():
        result = pipe(
            prompt=prompt_final,
            height=size[1],
            width=size[0],
            control_image=diptych_image,
            control_mask=mask,
            num_inference_steps=12,
            generator=generator,
            controlnet_conditioning_scale=0.7,
            guidance_scale=1,
            negative_prompt="",
            true_guidance_scale=1.0,
            attn_scale_mask=full_attn_scale_mask,
        ).images[0]
    return result, attn_vis_img
# Create Gradio interface with structured layout
with gr.Blocks() as iface:
    gr.Markdown("## FLUX Inpainting with Diptych Prompting")
    gr.Markdown("Upload an image, then specify a subject and a prompt. The app builds the diptych mask automatically and generates the inpainted image.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Accordion("Input"):
                    input_image = gr.Image(type="filepath", label="Upload Image")
            with gr.Row():
                prompt_preview = gr.Textbox(value="A two side-by-side image of 'subject_name'. LEFT: a photo of 'subject_name'; RIGHT: a photo of 'subject_name' 'target_text_prompt'", label="Prompt Template", interactive=False)
                subject = gr.Textbox(lines=1, placeholder="Enter your subject", label="Subject")
                prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt")
        with gr.Column():
            editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources=["upload"], visible=False)
            inpainted_image = gr.Image(type="pil", label="Inpainted Image")
            attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
    with gr.Row():
        inpaint_button = gr.Button("Inpaint")
    inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image, attn_vis_img])

# Launch the app
iface.launch()