# Hugging Face Space (running on ZeroGPU hardware)
import gradio as gr | |
import spaces | |
from PIL import Image, ImageDraw, ImageOps | |
import base64, json | |
from io import BytesIO | |
import torch.nn.functional as F | |
import json | |
from typing import List | |
from dataclasses import dataclass, field | |
from dreamfuse_inference import DreamFuseInference, InferenceConfig | |
import numpy as np | |
import os | |
from transformers import AutoModelForImageSegmentation | |
from torchvision import transforms | |
import torch | |
import subprocess | |
import base64 | |
# Clear the ZeroGPU offload directory before loading any models.
# NOTE(review): the glob is expanded by the shell (shell=True) and the command
# runs with an empty environment; the return code is not checked, so a failure
# here is silently ignored.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# NOTE(review): not referenced in the visible source; kept in case other
# modules import it.
generated_images = []
# Background-removal model (BRIA RMBG-2.0), loaded once at import time and
# moved to the GPU; used by remove_bg() to build the foreground alpha mask.
RMBG_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
RMBG_model = RMBG_model.to("cuda")
# Preprocessing for RMBG input: resize to 1024x1024, convert to a tensor and
# normalize with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def remove_bg(image):
    """Predict a foreground mask for *image* with the RMBG model.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through ``RMBG_model`` on the GPU without
    gradient tracking. The sigmoid of the model's last output becomes a PIL
    image, resized back to the input's width and height.

    Args:
        image: PIL image to segment (any mode; converted to RGB).

    Returns:
        A single-channel PIL mask image with the same size as *image*.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to("cuda")
    with torch.no_grad():
        last_output = RMBG_model(batch)[-1]
        probabilities = last_output.sigmoid().cpu()[0].squeeze()
    to_pil = transforms.ToPILImage()
    return to_pil(probabilities).resize(rgb.size)
def get_base64_logo(path="examples/logo.png"):
    """Return the image at *path* as a PNG ``data:`` URL for inline HTML.

    The image is converted to RGBA and re-encoded as PNG, so any source format
    Pillow can read is accepted.

    Args:
        path: Filesystem path of the logo image.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    # Open in a context manager so the file handle is released as soon as the
    # pixels have been read (convert() loads the data into a new image).
    with Image.open(path) as logo:
        rgba = logo.convert("RGBA")
    buffered = BytesIO()
    rgba.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
class DreamFuseGUI:
    """Gradio front end for DreamFuse foreground/background fusion.

    The user uploads a background and a foreground image. The foreground is
    cut out with ``remove_bg`` and placed on an HTML canvas, where it can be
    dragged and resized. The resulting placement is then passed to the
    DreamFuse pipeline together with the generation parameters.
    """

    def __init__(self):
        # Example pairs shown under the editor. The first path of each pair
        # fills the "Background Image" input and the second the "Foreground
        # Image" input (see the gr.Examples inputs in create_gui).
        self.examples = [
            ["./examples/placement_000_1.png", "./examples/placement_000_0.png"],
            ["./examples/handheld_000_1.png", "./examples/handheld_000_0.png"],
            ["./examples/030_1.webp", "./examples/030_0.webp"],
            ["./examples/handheld_001_1.png", "./examples/handheld_001_0.png"],
            ["./examples/style_000_1.jpg", "./examples/style_000_0.jpg"],
            ["./examples/wear_000_1.jpg", "./examples/wear_000_0.jpg"],
        ]
        self.examples = [[Image.open(x) for x in example] for example in self.examples]
        self.css_style = self._get_css_style()
        self.js_script = self._get_js_script()

    def _get_css_style(self):
        """Return the page-level CSS passed to ``gr.Blocks(css=...)``."""
        return """
        input[type="file"] {
            border: 1px solid #ccc !important;
            background-color: #f9f9f9 !important;
            padding: 8px !important;
            border-radius: 6px !important;
        }
        body {
            background-color: #ffffff;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            color: #212121;
        }
        .gradio-container {
            max-width: 1200px;
            margin: auto;
            background: #ffffff;
            border-radius: 12px;
            padding: 24px;
            box-shadow: 0px 4px 16px rgba(0, 0, 0, 0.05);
        }
        h1, h2 {
            text-align: center;
            color: #222;
        }
        #canvas_preview {
            min-height: 420px;
            border: 2px dashed #ccc;
            background-color: #fafafa;
            border-radius: 8px;
            padding: 10px;
            display: flex;
            justify-content: center;
            align-items: center;
            color: #999;
            font-size: 16px;
        }
        .gr-button {
            background-color: #2196f3;
            border: 1px solid #1976d2;
            color: #ffffff;
            padding: 10px 20px;
            border-radius: 6px;
            font-size: 15px;
            cursor: pointer;
            transition: background-color 0.2s ease;
        }
        .gr-button:hover {
            background-color: #1976d2;
        }
        #small-examples {
            width: 200px;
            margin: 10px 0;
            border: 1px solid #ddd;
            border-radius: 8px;
            overflow: hidden;
            background: #fff;
            box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        }
        .markdown-text {
            color: #333;
            font-size: 15px;
            line-height: 1.6;
        }
        #canvas-preview-container {
            background: #fafafa;
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 10px;
            margin-top: 10px;
        }
        [id^="section-"] {
            background-color: #ffffff;
            border: 1px solid #dddddd;
            border-radius: 10px;
            padding: 16px;
            box-shadow: 0 2px 6px rgba(0, 0, 0, 0.04);
            margin-bottom: 16px;
        }
        .svelte-1ipelgc {
            flex-wrap: nowrap !important;
            gap: 24px !important;
        }
        """

    def _get_js_script(self):
        """Return the JS run on page load that defines the canvas helpers.

        ``window.updateTransformation`` writes the foreground's current
        position and size, as JSON, into the hidden ``#transformation_info``
        textbox. ``globalThis.initializeDrag`` wires mouse dragging and slider
        resizing onto the canvas produced by ``generate_html``. It is safe to
        call repeatedly: the previous canvas's global listeners are removed
        first.
        """
        return r"""
async () => {
    window.updateTransformation = function() {
        const img = document.getElementById('draggable-img');
        const container = document.getElementById('canvas-container');
        if (!img || !container) return;
        const left = parseFloat(img.style.left) || 0;
        const top = parseFloat(img.style.top) || 0;
        const data_original_width = parseFloat(img.getAttribute('data-original-width'));
        const data_original_height = parseFloat(img.getAttribute('data-original-height'));
        const scale_ratio = img.clientWidth / data_original_width;
        const transformation = {
            drag_left: left,
            drag_top: top,
            drag_width: img.clientWidth,
            drag_height: img.clientHeight,
            data_original_width: data_original_width,
            data_original_height: data_original_height,
            scale_ratio: scale_ratio
        };
        const transInput = document.querySelector("#transformation_info textarea");
        if (transInput) {
            const newValue = JSON.stringify(transformation);
            // Use the native setter so the framework's input listener sees the change.
            const nativeSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
            nativeSetter.call(transInput, newValue);
            transInput.dispatchEvent(new Event('input', { bubbles: true }));
            console.log("Transformation info updated: ", newValue);
        } else {
            console.log("找不到 transformation_info 的 textarea 元素");
        }
    };

    globalThis.initializeDrag = () => {
        const oldImg = document.getElementById('draggable-img');
        const container = document.getElementById('canvas-container');
        const slider = document.getElementById('scale-slider');
        if (!oldImg || !container || !slider) {
            return;
        }
        // Remove the document/window listeners left by a previous canvas so
        // repeated "Generate Canvas" clicks do not pile up handlers.
        if (typeof globalThis.teardownDrag === 'function') {
            globalThis.teardownDrag();
        }
        // Swap the <img> for a clone to strip any listeners bound earlier.
        const img = oldImg.cloneNode(true);
        oldImg.replaceWith(img);
        img.ondragstart = (e) => { e.preventDefault(); return false; };

        let offsetX = 0, offsetY = 0;
        let isDragging = false;
        let scaleAnchor = null;

        img.addEventListener('mousedown', (e) => {
            isDragging = true;
            img.style.cursor = 'grabbing';
            const imgRect = img.getBoundingClientRect();
            const containerRect = container.getBoundingClientRect();
            offsetX = e.clientX - imgRect.left;
            offsetY = e.clientY - imgRect.top;
            // Pin the image where it is currently drawn before dropping the
            // centring transform. offsetLeft/offsetTop ignore CSS transforms,
            // so using them made the image jump on the first click.
            img.style.transform = "none";
            img.style.left = (imgRect.left - containerRect.left) + "px";
            img.style.top = (imgRect.top - containerRect.top) + "px";
            console.log("mousedown: left=", img.style.left, "top=", img.style.top);
        });

        const onMouseMove = (e) => {
            if (!isDragging) return;
            e.preventDefault();
            const containerRect = container.getBoundingClientRect();
            let left = e.clientX - containerRect.left - offsetX;
            let top = e.clientY - containerRect.top - offsetY;
            // Keep at least 1/8 of the image's width/height inside the canvas.
            const minLeft = -img.clientWidth * (7/8);
            const maxLeft = containerRect.width - img.clientWidth * (1/8);
            const minTop = -img.clientHeight * (7/8);
            const maxTop = containerRect.height - img.clientHeight * (1/8);
            if (left < minLeft) left = minLeft;
            if (left > maxLeft) left = maxLeft;
            if (top < minTop) top = minTop;
            if (top > maxTop) top = maxTop;
            img.style.left = left + "px";
            img.style.top = top + "px";
        };

        const onMouseUp = (e) => {
            if (isDragging) {
                isDragging = false;
                img.style.cursor = 'grab';
                const containerRect = container.getBoundingClientRect();
                const bgWidth = parseFloat(container.dataset.bgWidth);
                const bgHeight = parseFloat(container.dataset.bgHeight);
                const offsetLeft = (containerRect.width - bgWidth) / 2;
                const offsetTop = (containerRect.height - bgHeight) / 2;
                const absoluteLeft = parseFloat(img.style.left);
                const absoluteTop = parseFloat(img.style.top);
                const relativeX = absoluteLeft - offsetLeft;
                const relativeY = absoluteTop - offsetTop;
                document.getElementById("coordinate").textContent =
                    `Location: (x=${relativeX.toFixed(2)}, y=${relativeY.toFixed(2)})`;
                updateTransformation();
            }
            scaleAnchor = null;
        };

        document.addEventListener('mousemove', onMouseMove);
        window.addEventListener('mouseup', onMouseUp);
        globalThis.teardownDrag = () => {
            document.removeEventListener('mousemove', onMouseMove);
            window.removeEventListener('mouseup', onMouseUp);
        };

        slider.addEventListener('mousedown', (e) => {
            // Freeze the current centre so a whole slider drag scales about one point.
            const containerRect = container.getBoundingClientRect();
            const imgRect = img.getBoundingClientRect();
            scaleAnchor = {
                x: imgRect.left + imgRect.width/2 - containerRect.left,
                y: imgRect.top + imgRect.height/2 - containerRect.top
            };
            console.log("Slider mousedown, captured scaleAnchor: ", scaleAnchor);
        });

        slider.addEventListener('input', (e) => {
            const scale = parseFloat(e.target.value);
            const originalWidth = parseFloat(img.getAttribute('data-original-width'));
            const originalHeight = parseFloat(img.getAttribute('data-original-height'));
            const newWidth = originalWidth * scale;
            const newHeight = originalHeight * scale;
            const containerRect = container.getBoundingClientRect();
            let centerX, centerY;
            if (scaleAnchor) {
                centerX = scaleAnchor.x;
                centerY = scaleAnchor.y;
            } else {
                const imgRect = img.getBoundingClientRect();
                centerX = imgRect.left + imgRect.width/2 - containerRect.left;
                centerY = imgRect.top + imgRect.height/2 - containerRect.top;
            }
            const newLeft = centerX - newWidth/2;
            const newTop = centerY - newHeight/2;
            // left/top are the image's top-left corner, so the initial
            // translate(-50%, -50%) centring must be cleared. Otherwise the image
            // is drawn half its size away from the position that
            // updateTransformation reports.
            img.style.transform = "none";
            img.style.width = newWidth + "px";
            img.style.height = newHeight + "px";
            img.style.left = newLeft + "px";
            img.style.top = newTop + "px";
            console.log("slider: scale=", scale, "newWidth=", newWidth, "newHeight=", newHeight);
            updateTransformation();
        });

        slider.addEventListener('mouseup', (e) => {
            scaleAnchor = null;
        });

        console.log("✅ 拖拽和缩放事件已绑定");
    };
}
"""

    def get_next_sequence(self, folder_path):
        """Return the next zero-padded sequence number for files in *folder_path*.

        File names are expected to start with a numeric prefix followed by
        ``_`` (for example ``007_result.png``). Names whose prefix is not all
        digits are ignored.

        Returns:
            The maximum existing prefix plus one, formatted with at least three
            digits (``"000"`` when no numbered file exists).
        """
        # List every file name in the folder.
        filenames = os.listdir(folder_path)
        # Keep the numeric prefix (the text before the first "_") of each name.
        sequences = [int(name.split('_')[0]) for name in filenames if name.split('_')[0].isdigit()]
        # Largest sequence number seen so far (-1 when there is none).
        max_sequence = max(sequences, default=-1)
        # Next number, zero-padded to three digits (e.g. "002").
        return f"{max_sequence + 1:03d}"

    def pil_to_base64(self, img):
        """Encode a PIL image as an RGBA PNG ``data:`` URL ("" for None)."""
        if img is None:
            return ""
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        buffered = BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        img_bytes = buffered.getvalue()
        base64_str = base64.b64encode(img_bytes).decode()
        return f"data:image/png;base64,{base64_str}"

    @staticmethod
    def _fit_within(img, max_size):
        """Downscale *img* to fit in a max_size square, keeping aspect ratio.

        Images already within the limit are returned unchanged; None passes
        through.
        """
        if img is None:
            return None
        w, h = img.size
        if w > max_size or h > max_size:
            ratio = min(max_size / w, max_size / h)
            # max(1, ...) keeps extreme aspect ratios from collapsing to 0 px.
            new_w, new_h = max(1, int(w * ratio)), max(1, int(h * ratio))
            img = img.resize((new_w, new_h), Image.LANCZOS)
        return img

    def resize_background_image(self, img, max_size=400):
        """Shrink the background so it fits the canvas (see ``_fit_within``)."""
        return self._fit_within(img, max_size)

    def resize_draggable_image(self, img, max_size=400):
        """Shrink the foreground so it fits the canvas (see ``_fit_within``)."""
        return self._fit_within(img, max_size)

    def generate_html(self, background_img_b64, bg_width, bg_height, draggable_img_b64, draggable_width, draggable_height, canvas_size=400):
        """Build the drag-and-resize canvas HTML.

        Args:
            background_img_b64: ``data:`` URL drawn as the canvas background.
            bg_width, bg_height: Displayed background size; stored in
                ``data-bg-*`` attributes for the JS coordinate readout.
            draggable_img_b64: ``data:`` URL of the foreground image.
            draggable_width, draggable_height: Displayed foreground size;
                stored in ``data-original-*`` attributes for the JS scaling.
            canvas_size: Side length of the square canvas (and slider width) in px.

        Returns:
            An HTML document string; the element ids match those the page JS
            looks up (``canvas-container``, ``draggable-img``, ``scale-slider``,
            ``coordinate``).
        """
        html_code = f"""
        <html>
        <head>
        <style>
        body {{
            margin: 0;
            padding: 0;
            text-align: center;
            font-family: sans-serif;
            background: transparent;
            color: #fff;
        }}
        h2 {{
            margin-top: 1rem;
        }}
        #scale-control {{
            margin: 1rem auto;
            width: {canvas_size}px;
            text-align: left;
        }}
        #scale-control label {{
            font-size: 1rem;
            margin-right: 0.5rem;
        }}
        #canvas-container {{
            position: relative;
            width: {canvas_size}px;
            height: {canvas_size}px;
            margin: 0 auto;
            border: 1px dashed rgba(255,255,255,0.5);
            overflow: hidden;
            background-image: url('{background_img_b64}');
            background-repeat: no-repeat;
            background-position: center;
            background-size: contain;
            border-radius: 8px;
        }}
        #draggable-img {{
            position: absolute;
            cursor: grab;
            left: 50%;
            top: 50%;
            transform: translate(-50%, -50%);
            background-color: transparent;
        }}
        #coordinate {{
            color: #fff;
            margin-top: 1rem;
            font-weight: bold;
        }}
        </style>
        </head>
        <body>
        <h2> 3️⃣ Drag and Resize</h2>
        <div id="scale-control">
            <label for="scale-slider">Resize FG:</label>
            <input type="range" id="scale-slider" min="0.1" max="2" step="0.01" value="1">
        </div>
        <div id="canvas-container" data-bg-width="{bg_width}" data-bg-height="{bg_height}">
            <img id="draggable-img"
                src="{draggable_img_b64}"
                alt="Draggable Image"
                draggable="false"
                data-original-width="{draggable_width}"
                data-original-height="{draggable_height}"
            />
        </div>
        <p id="coordinate">location: (x=?, y=?)</p>
        </body>
        </html>
        """
        return html_code

    def on_upload(self, background_img, draggable_img):
        """Cut out the foreground and build the interactive placement canvas.

        Args:
            background_img: PIL background image, or None if not uploaded.
            draggable_img: PIL foreground image, or None if not uploaded.

        Returns:
            tuple: ``(html, foreground)``, one value for each of the two output
            components wired in ``create_gui``. ``html`` is the canvas markup
            and ``foreground`` is the full-resolution RGBA foreground whose
            alpha channel is the RMBG mask. If either input is missing,
            ``html`` is an error message and ``foreground`` is None.
        """
        if background_img is None or draggable_img is None:
            # Return one value per output component. A bare string here makes
            # Gradio reject the event, because two outputs are wired up.
            return (
                "<p style='color:red;'>Please upload the background and foreground images。</p>",
                None,
            )
        if draggable_img.mode != "RGB":
            draggable_img = draggable_img.convert("RGB")
        # Use the predicted mask as the foreground's alpha channel.
        draggable_img_mask = remove_bg(draggable_img)
        alpha_channel = draggable_img_mask.convert("L")
        draggable_img = draggable_img.convert("RGBA")
        draggable_img.putalpha(alpha_channel)

        # Only the on-canvas previews are downscaled; the full-size RGBA
        # foreground is what goes into the state.
        resized_bg = self.resize_background_image(background_img, max_size=400)
        bg_w, bg_h = resized_bg.size
        resized_fg = self.resize_draggable_image(draggable_img, max_size=400)
        draggable_width, draggable_height = resized_fg.size
        background_img_b64 = self.pil_to_base64(resized_bg)
        draggable_img_b64 = self.pil_to_base64(resized_fg)
        return self.generate_html(
            background_img_b64, bg_w, bg_h,
            draggable_img_b64, draggable_width, draggable_height,
            canvas_size=400,
        ), draggable_img

    def create_gui(self):
        """Build and return the DreamFuse ``gr.Blocks`` demo.

        This creates the DreamFuse inference pipeline (LoRA ``LL3RD/DreamFuse``)
        and wraps its ``gradio_generate`` with ``spaces.GPU``. It then lays out
        the upload, preview, parameter, result and example sections and wires
        up their events.
        """
        config = InferenceConfig()
        config.lora_id = 'LL3RD/DreamFuse'
        pipeline = DreamFuseInference(config)
        # Request a GPU slot of up to 120 s for each generation call.
        pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)

        with gr.Blocks(css=self.css_style) as demo:
            # Background-removed RGBA foreground produced by on_upload.
            modified_fg_state = gr.State()
            logo_data_url = get_base64_logo()
            gr.HTML(
                f"""
                <div style="display: flex; align-items: center; justify-content: center; gap: 12px; margin-bottom: 20px;">
                    <img src="{logo_data_url}" style="height: 80px;" />
                    <h1 style="margin: 0; font-size: 32px;">DreamFuse</h1>
                </div>
                """
            )
            gr.Markdown('## 📌 4 Easy Steps to Create Your Fusion Image:')
            gr.Markdown(
                """
                1. Upload the foreground and background images you want to fuse.
                2. Click 'Generate Canvas' to preview the result.
                3. Drag and resize the foreground object to position it as you like.
                4. Click 'Run Model' to create the final fused image.
                """,
                elem_classes=["markdown-text"]
            )

            with gr.Row():
                with gr.Column(scale=1, elem_id="section-upload"):
                    gr.Markdown("### 1️⃣ FG&BG Image Upload")
                    with gr.Row():
                        with gr.Column(scale=1):
                            background_img_in = gr.Image(label="Background Image", type="pil", height=240, width=200)
                        with gr.Column(scale=1):
                            draggable_img_in = gr.Image(label="Foreground Image", type="pil", image_mode="RGBA", height=240, width=200)
                    generate_btn = gr.Button("2️⃣ Generate Canvas")
                with gr.Column(scale=1, elem_id="section-preview"):
                    gr.Markdown("### Preview Region")
                    html_out = gr.HTML(
                        value="<p style='text-align:center; color:#aaa;'>Waiting for generating canvas...</p>",
                        label="drag and resize",
                        elem_id="canvas_preview"
                    )

            with gr.Row():
                with gr.Column(scale=1, elem_id="section-parameters"):
                    gr.Markdown("### Parameters")
                    seed_slider = gr.Slider(minimum=-1, maximum=100000, step=1, label="Seed", value=12345)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="Resolution (Higher resolution improves quality, but slows down generation.)",
                    )
                    prompt_text = gr.Textbox(label="Prompt", placeholder="text prompt", value="")
                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength (Improve text strength to increase responsiveness)", value=1, visible=True)
                    enable_gui = gr.Checkbox(label="GUI", value=True, visible=False)
                    enable_truecfg = gr.Checkbox(label="TrueCFG", value=False, visible=False)
                with gr.Column(scale=1, elem_id="section-results"):
                    gr.Markdown("### Model Result")
                    model_generate_btn = gr.Button("4️⃣ Run Model")
                    # Filled by the page JS (updateTransformation) with the
                    # foreground's position/size as JSON.
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
                    model_output = gr.Image(label="Model Output", type="pil", height=512, width=512)

            # Example pairs, two per row, in the original display order.
            for row_indices in ((0, 2), (1, 3), (4, 5)):
                with gr.Row():
                    for example_idx in row_indices:
                        with gr.Column(scale=1):
                            gr.Examples(
                                examples=[self.examples[example_idx]],
                                inputs=[background_img_in, draggable_img_in],
                            )

            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            ).then(
                # Bind the drag/resize handlers once the new canvas is rendered.
                fn=None,
                inputs=None,
                outputs=None,
                js="initializeDrag",
            )
            model_generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=[
                    background_img_in, modified_fg_state, transformation_text, seed_slider,
                    prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg,
                ],
                outputs=model_output,
            )
            # Define window.updateTransformation / globalThis.initializeDrag on page load.
            demo.load(None, None, None, js=self.js_script)
        return demo
if __name__ == "__main__":
    # Build the interface, then serve it with request queuing enabled.
    # Blocks.queue() returns the same Blocks object, so the calls chain.
    gui = DreamFuseGUI()
    demo = gui.create_gui()
    demo.queue().launch()