import spaces
import os
import gradio as gr
import torch
import safetensors
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import numpy as np
import subprocess
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
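# Clear any leftover ZeroGPU offload cache from previous runs (Spaces-specific scratch path)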
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
HF_TOKEN = os.getenv("HF_TOKEN")
# Ensure that the minimum required version of diffusers is installed
check_min_version("0.30.2")
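# 4-bit quantization config for the T5 text encoder (reduces VRAM usage)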
quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
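# NOTE: text_encoder_2_4bit is loaded but not attached to the pipeline below
# (the `pipe.text_encoder_2 = text_encoder_2_4bit` assignment is commented out)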
# quant_config = DiffusersBitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# )
transformerx = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
# text_encoder_8bit = T5EncoderModel.from_pretrained(
# "black-forest-labs/FLUX.1-dev",
# subfolder="text_encoder_2",
# quantization_config=quant_config,
# torch_dtype=torch.bfloat16,
# use_safetensors=True,
# token=HF_TOKEN
# )
# Build pipeline
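# The alimama inpainting ControlNet supplies the inpainting conditioning; the base FLUX transformer is loaded separately above in bf16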
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    # subfolder="controlnet",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    # text_encoder_2=text_encoder_8bit,
    transformer=transformerx,
    torch_dtype=torch.bfloat16,
    # device_map="balanced",
    token=HF_TOKEN,
)
# pipe.text_encoder_2 = text_encoder_2_4bit
# pipe.transformer = transformer_4bit
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
pipe.to("cuda")
pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["turbo"], adapter_weights=[0.95])
pipe.fuse_lora(lora_scale=1)
pipe.unload_lora_weights()
# We can utilize the enable_group_offload method for Diffusers model implementations
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
# pipe.push_to_hub("FLUX.1-Inpainting-8step_uncensored", private=True, token=HF_TOKEN)
# pipe.enable_vae_tiling()
# pipe.enable_model_cpu_offload()
# hf_device_map is only set when the pipeline is loaded with a device_map, so guard the lookup
print(getattr(pipe, "hf_device_map", None))
def create_mask_from_editor(editor_value):
    """
    Create a binary mask from the ImageEditor value.

    Args:
        editor_value: EditorValue dict with 'background', 'layers', and 'composite' keys.

    Returns:
        PIL Image that is white where the editor's composite is pure white and black elsewhere.
    """
    # The 'composite' key contains the final image with all layers applied
    composite_image = editor_value['composite']
    # Convert to an RGB numpy array (drops any alpha channel)
    composite_array = np.array(composite_image.convert("RGB"))
    # Create a mask where the composite image is pure white
    mask_array = np.all(composite_array == (255, 255, 255), axis=-1).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask_array)
    return mask_image
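# NOTE: create_mask_from_editor is currently unused; the ImageEditor component is hidden and
# inpaint_image builds its mask directly over the right diptych panel instead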
def create_mask_on_image(image, xyxy):
    """
    Draw a white rectangular mask over an image region.

    Args:
        image: PIL Image.
        xyxy: [x1, y1, x2, y2] coordinates of the rectangle to mask.

    Returns:
        Tuple of (mask, masked_image): the black/white mask, and the image with the masked region filled white.
    """
    # Convert the input image to a numpy array
    img_array = np.array(image)
    # Create an all-black mask and draw a white rectangle over the target region
    mask = Image.new('RGB', image.size, (0, 0, 0))
    draw = ImageDraw.Draw(mask)
    draw.rectangle(xyxy, fill=(255, 255, 255))
    # Fill the masked region of the original image with white
    mask_array = np.array(mask)
    masked_array = np.where(mask_array == 255, 255, img_array)
    return Image.fromarray(mask_array), Image.fromarray(masked_array)
def create_diptych_image(image):
    # Create a diptych image with the original on the left and a black panel on the right
    width, height = image.size
    diptych = Image.new('RGB', (width * 2, height), 'black')
    diptych.paste(image, (0, 0))
    return diptych
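# Diptych prompting: the reference photo sits in the left panel, the right panel is masked and
# inpainted by the ControlNet, and cross-panel attention is boosted via attn_scale_mask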
@spaces.GPU(duration=120)
def inpaint_image(image, prompt, subject, editor_value):
    # Load image and mask
    size = (1536, 768)
    image = load_image(image).convert("RGB").resize((768, 768))
    diptych_image = create_diptych_image(image)
    # mask = load_image(mask_path).convert("RGB").resize(size)
    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
    generator = torch.Generator(device="cuda").manual_seed(24)
    # Build the attention scale mask that boosts cross-attention between the two diptych panels
    attn_scale_factor = 1.5
    # Latent token grid for the diptych (FLUX packs 16x16 pixels into one latent token)
    H, W = size[1] // 16, size[0] // 16
    # Mark the right panel (the region to inpaint) in pixel space, then downsample to token space
    attn_scale_mask = torch.zeros(size[1], size[0])
    attn_scale_mask[:, 768:] = 1.0  # height, width
    attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
    attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)  # 24 attention heads
    # Invert and transpose to mark left-panel keys, then intersect: right-panel queries attending to left-panel keys
    transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
    cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
    cross_attn_region = cross_attn_region * attn_scale_factor
    cross_attn_region[cross_attn_region < 1.0] = 1.0
    # The full mask spans 512 text tokens plus H*W image tokens; only image-to-image attention is scaled
    full_attn_scale_mask = torch.ones(1, 24, 512 + H*W, 512 + H*W)
    full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
    # Convert to bfloat16 to match the model dtype
    full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
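    # Compose the diptych prompt: LEFT describes the reference photo, RIGHT describes the edited target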
    subject_name = subject
    target_text_prompt = prompt
    prompt_final = f'A two side-by-side image of {subject_name}. LEFT: a photo of {subject_name}; RIGHT: a photo of {subject_name} {target_text_prompt}.'
    # Visualize the attention scale mask; clone first so the in-place thresholding below
    # does not overwrite the real mask that is passed to the pipeline
    attn_vis = full_attn_scale_mask[0, 0].clone()
    attn_vis[attn_vis <= 1.0] = 0
    attn_vis[attn_vis > 1.0] = 255
    attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)
    # Convert to a PIL image and save a copy for inspection
    attn_vis_img = Image.fromarray(attn_vis)
    attn_vis_img.save('attention_mask_vis.png')
    with torch.inference_mode():
        result = pipe(
            prompt=prompt_final,
            height=size[1],
            width=size[0],
            control_image=diptych_image,
            control_mask=mask,
            num_inference_steps=12,
            generator=generator,
            controlnet_conditioning_scale=0.7,
            guidance_scale=1,
            negative_prompt="",
            true_guidance_scale=1.0,
            attn_scale_mask=full_attn_scale_mask,
        ).images[0]
    return result, attn_vis_img
# Create Gradio interface with structured layout
with gr.Blocks() as iface:
    gr.Markdown("## FLUX Inpainting with Diptych Prompting")
    gr.Markdown("Upload an image and enter a subject and prompt. The app builds the diptych and mask and generates the inpainted image automatically.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Accordion():
                    input_image = gr.Image(type="filepath", label="Upload Image")
            with gr.Row():
                prompt_preview = gr.Textbox(value="A two side-by-side image of 'subject_name'. LEFT: a photo of 'subject_name'; RIGHT: a photo of 'subject_name' 'target_text_prompt'", interactive=False)
                subject = gr.Textbox(lines=1, placeholder="Enter your subject", label="Subject")
                prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt")
        with gr.Column():
            editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", visible=False)
            inpainted_image = gr.Image(type="pil", label="Inpainted Image")
            attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
    with gr.Row():
        inpaint_button = gr.Button("Inpaint")
    inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image, attn_vis_img])

# Launch the app
iface.launch()