# VisualCloze / app.py — Hugging Face Space (lzyhha), revision 808cfcf, 24.5 kB.
import argparse
import spaces
from visualcloze import VisualClozeModel
import gradio as gr
import demo_tasks
from functools import partial
from data.prefix_instruction import get_layout_instruction
from huggingface_hub import snapshot_download
# Largest grid the UI pre-allocates: rows (in-context examples + the query row) x columns.
max_grid_h, max_grid_w = 5, 5
# Grid shown at start-up: one in-context example row plus the query row, three images per row.
default_grid_h, default_grid_w = 2, 3
# Sampling defaults; process_tasks also falls back to these when an example leaves them unset.
default_upsampling_noise = 0.4
default_steps = 30
# Markdown quick-start text rendered under the page title by create_demo.
# The option names in bold must match the labels of the corresponding UI widgets.
GUIDANCE = """
## 📋 Quick Start Guide:
1. Adjust **Number of In-context Examples**; 0 disables in-context learning.
2. Set **Task Columns**: the number of images involved in a task.
3. Upload images. For in-context examples, upload all images. For the current query, upload all images except the target.
4. Click **Generate** to create the images.
5. Parameters can be fine-tuned under **Advanced options**.
<div style='font-size: 24px; font-weight: bold; color: #FF9999;'>Click a task button at the bottom right to load examples of various tasks.</div>
### 📧 Need help or have questions? Contact us at: lizhongyu [AT] mail.nankai.edu.cn
"""
# Markdown footer (star request, BibTeX citation, license) rendered in the left column.
# The blank line before '---' is required: without it Markdown turns the preceding
# sentence into a setext heading instead of drawing a horizontal rule.
CITATION = r"""
If you find VisualCloze helpful, please consider starring ⭐ the <a href='https://github.com/lzyhha/VisualCloze' target='_blank'>GitHub repo</a>. Thanks!

---
📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@article{li2025visualcloze,
title={VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning},
author={Li, Zhong-Yu and Du, Ruoyi and Yan, Juncheng and Zhuo, Le and Li, Zhen and Gao, Peng and Ma, Zhanyu and Cheng, Ming-Ming},
journal={arXiv preprint arXiv:},
year={2025}
}
```
📋 **License**
<br>
This project is licensed under xxx.
"""
def _row_heading(row, n_rows):
    """Return the HTML heading shown above grid row ``row`` of an ``n_rows``-row grid.

    The last row (``n_rows - 1``) is the query; earlier rows are in-context
    examples numbered from 1.
    """
    title = "Query" if row == n_rows - 1 else f"In-context Example {row + 1}"
    return f"<div style='font-size: 24px; font-weight: bold;'>{title}</div>"


def _image_label(row, col, n_rows):
    """Return the label of the image input at (``row``, ``col``) of an ``n_rows``-row grid."""
    if row == n_rows - 1:
        return f"Query {col + 1}"
    return f"In-context Example {row + 1}/{col + 1}"


def _gallery_update(images):
    """Build an update for the results gallery laying ``images`` out in at most two columns."""
    return gr.update(
        elem_id="output_gallery",
        value=images,
        columns=min(len(images), 2),
        rows=(len(images) + 1) // 2,  # ceil(len / 2)
    )


def create_demo(model):
    """Build (but do not launch) the VisualCloze Gradio Blocks demo.

    The left column holds a pre-allocated ``max_grid_h`` x ``max_grid_w`` grid of
    image inputs (cells outside the current grid size are hidden), the prompt
    boxes, the grid-size sliders and the sampling options. The right column shows
    the generated results and clickable example tasks loaded from ``demo_tasks``.

    Args:
        model: the VisualCloze model. This function calls ``model.set_grid_size``
            and reads ``model.grid_h`` / ``model.grid_w``; the actual generation
            goes through the module-level ``generate`` function.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # Example galleries for the right column: (Markdown heading, [(samples, label, loader)]).
    # ``loader`` expands a clicked sample into the flat list consumed by process_tasks.
    task_sections = [
        ("# Task Examples", [
            (demo_tasks.dense_prediction_text, "Dense Prediction",
             demo_tasks.process_dense_prediction_tasks),
            (demo_tasks.conditional_generation_text, "Conditional Generation",
             demo_tasks.process_conditional_generation_tasks),
            (demo_tasks.image_restoration_text, "Image Restoration",
             demo_tasks.process_image_restoration_tasks),
            (demo_tasks.style_transfer_text, "Style Transfer",
             demo_tasks.process_style_transfer_tasks),
            (demo_tasks.style_condition_fusion_text, "Style Condition Fusion",
             demo_tasks.process_style_condition_fusion_tasks),
            (demo_tasks.tryon_text, "Virtual Try-On",
             demo_tasks.process_tryon_tasks),
            (demo_tasks.relighting_text, "Relighting",
             demo_tasks.process_relighting_tasks),
            (demo_tasks.photodoodle_text, "Photodoodle",
             demo_tasks.process_photodoodle_tasks),
            (demo_tasks.editing_text, "Editing",
             demo_tasks.process_editing_tasks),
            (demo_tasks.unseen_tasks_text, "Unseen Tasks (May produce unstable effects)",
             demo_tasks.process_unseen_tasks),
        ]),
        ("# Subject-driven Tasks Examples", [
            (demo_tasks.subject_driven_text, "Subject-driven Generation",
             demo_tasks.process_subject_driven_tasks),
            (demo_tasks.condition_subject_fusion_text, "Condition+Subject Fusion",
             demo_tasks.process_condition_subject_fusion_tasks),
            (demo_tasks.style_transfer_with_subject_text, "Style Transfer with Subject",
             demo_tasks.process_style_transfer_with_subject_tasks),
            (demo_tasks.condition_subject_style_fusion_text, "Condition+Subject+Style Fusion",
             demo_tasks.process_condition_subject_style_fusion_tasks),
            (demo_tasks.editing_with_subject_text, "Editing with Subject",
             demo_tasks.process_editing_with_subject_tasks),
            (demo_tasks.image_restoration_with_subject_text, "Image Restoration with Subject",
             demo_tasks.process_image_restoration_with_subject_tasks),
        ]),
    ]

    with gr.Blocks(title="VisualCloze Demo") as demo:
        gr.Markdown("# VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="xxx">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="xxx">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="xxx">
<img src='https://img.shields.io/badge/VisualCloze%20checkpoint-HF%20Model-green?logoColor=violet&label=%F0%9F%A4%97%20Checkpoint'>
</a>
<a href="xxx">
<img src='https://img.shields.io/badge/VisualCloze%20datasets-HF%20Dataset-6B88E3?logoColor=violet&label=%F0%9F%A4%97%20Graph200k%20Dataset'>
</a>
</div>
""")
        gr.Markdown(GUIDANCE)

        # Every possible grid cell is created up front; resizing only toggles visibility.
        all_image_inputs = []  # row-major, index = row * max_grid_w + col
        rows = []
        row_texts = []
        task_datasets = []  # (gr.Dataset, loader) pairs, wired to process_tasks below

        with gr.Row():
            # Left column: image grid and prompt inputs.
            with gr.Column(scale=2):
                for i in range(max_grid_h):
                    row_texts.append(gr.Markdown(
                        _row_heading(i, default_grid_h),
                        elem_id=f"row_text_{i}",
                        visible=i < default_grid_h,
                    ))
                    with gr.Row(visible=i < default_grid_h, elem_id=f"row_{i}") as row:
                        rows.append(row)
                        for j in range(max_grid_w):
                            all_image_inputs.append(gr.Image(
                                label=_image_label(i, j, default_grid_h),
                                type="pil",
                                visible=i < default_grid_h and j < default_grid_w,
                                interactive=True,
                                elem_id=f"img_{i}_{j}",
                            ))

                layout_prompt = gr.Textbox(
                    label="Layout Description (Auto-filled, Read-only)",
                    placeholder="Layout description will be automatically filled based on grid size...",
                    value=get_layout_instruction(default_grid_w, default_grid_h),
                    elem_id="layout_prompt",
                    interactive=False,
                )
                task_prompt = gr.Textbox(
                    label="Task Description (Can be modified by referring to examples to perform custom tasks, but may lead to unstable results)",
                    placeholder="Describe what task should be performed...",
                    value="",
                    elem_id="task_prompt",
                )
                content_prompt = gr.Textbox(
                    label="Content Description (Image caption, Editing instructions, etc.)",
                    placeholder="Describe the content requirements...",
                    value="",
                    elem_id="content_prompt",
                )
                generate_btn = gr.Button("Generate", elem_id="generate_btn")
                # grid_h counts in-context rows only; the query row is always added on top (see update_grid).
                grid_h = gr.Slider(minimum=0, maximum=max_grid_h - 1, value=default_grid_h - 1, step=1,
                                   label="Number of In-context Examples", elem_id="grid_h")
                grid_w = gr.Slider(minimum=1, maximum=max_grid_w, value=default_grid_w, step=1,
                                   label="Task Columns", elem_id="grid_w")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Number(label="Seed (0 for random)", value=0, precision=0)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=default_steps, step=1)
                    cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=50.0, value=30, step=1)
                    upsampling_steps = gr.Slider(label="Upsampling steps (SDEdit)", minimum=1, maximum=100.0,
                                                 value=10, step=1)
                    upsampling_noise = gr.Slider(label="Upsampling noise (SDEdit)", minimum=0, maximum=1.0,
                                                 value=default_upsampling_noise, step=0.01)
                gr.Markdown(CITATION)

            # Right column: generated images and clickable task examples.
            with gr.Column(scale=2):
                output_gallery = gr.Gallery(
                    label="Generated Results",
                    show_label=True,
                    elem_id="output_gallery",
                    columns=None,  # None lets the gallery size its layout automatically
                    rows=None,
                    height="auto",
                    allow_preview=True,
                    object_fit="contain",  # show each image in full
                )
                for heading, entries in task_sections:
                    gr.Markdown(heading)
                    for samples, label, loader in entries:
                        hidden_text = gr.Textbox(label="Task", visible=False)
                        dataset = gr.Dataset(
                            samples=samples,
                            label=label,
                            samples_per_page=1000,
                            components=[hidden_text],
                        )
                        task_datasets.append((dataset, loader))

        # NOTE(review): the grid size lives on the shared ``model`` object, so concurrent
        # sessions resizing the grid can affect each other's generate_image calls.
        def update_grid(h, w):
            """Resize the visible grid to ``h`` in-context rows plus the query row, ``w`` columns.

            Returns updates in the order of ``grid_outputs``: every image input
            (row-major), every row container, every row heading, then the layout prompt.
            """
            # Slider payloads are coerced to int because they feed range() in generate_image.
            h, w = int(h), int(w)
            actual_h = h + 1  # in-context rows + the query row
            model.set_grid_size(actual_h, w)
            image_updates = [
                gr.update(
                    label=_image_label(r, c, actual_h),
                    elem_id=f"img_{r}_{c}",
                    visible=r < actual_h and c < w,
                )
                for r in range(max_grid_h)
                for c in range(max_grid_w)
            ]
            row_updates = [gr.update(elem_id=f"row_{i}", visible=i < actual_h) for i in range(max_grid_h)]
            heading_updates = [
                gr.update(elem_id=f"row_text_{i}", visible=i < actual_h, value=_row_heading(i, actual_h))
                for i in range(max_grid_h)
            ]
            layout_update = gr.update(elem_id="layout_prompt", value=get_layout_instruction(w, actual_h))
            return image_updates + row_updates + heading_updates + [layout_update]

        def generate_image(*inputs):
            """Run generation on the currently visible grid and return a gallery update.

            ``inputs`` is every image input (row-major, ``max_grid_h * max_grid_w`` of
            them) followed by seed, cfg, steps, upsampling steps, upsampling noise,
            and the layout, task and content prompts.
            """
            grid = [
                [inputs[r * max_grid_w + c] for c in range(model.grid_w)]
                for r in range(model.grid_h)
            ]
            (seed_value, cfg_value, steps_value, up_steps, up_noise,
             layout_text, task_text, content_text) = inputs[-8:]
            results = generate(
                grid,
                [layout_text, task_text, content_text],
                seed=seed_value, cfg=cfg_value, steps=steps_value,
                upsampling_steps=up_steps, upsampling_noise=up_noise,
            )
            return _gallery_update(results)

        def process_tasks(task, func):
            """Load a clicked example into the gallery, image grid, prompts and settings.

            ``func(task)`` returns a flat sequence: a mask, then the 7 values for
            [rows, columns, layout, task, content, upsampling noise, steps], then the
            grid images, whose last ``len(mask)`` entries (presumably the query row)
            are also shown in the gallery. Entries flagged 1 in the mask are left
            empty in the grid so the model has to generate them.

            Returns values for ``task_outputs``: gallery, every image input, then the
            7 settings (with rows converted to the in-context-example count).
            """
            outputs = func(task)
            mask = outputs[0]
            # Explicit split index: ``outputs[8:-len(mask)]`` would be empty for an empty mask.
            split = len(outputs) - len(mask)
            state = list(outputs[1:8])  # copy so it can be modified regardless of the sequence type
            if state[5] is None:
                state[5] = default_upsampling_noise
            if state[6] is None:
                state[6] = default_steps
            query_images = list(outputs[split:])
            images = list(outputs[8:split])
            images.extend(None if flag == 1 else image for flag, image in zip(mask, query_images))

            state[0] -= 1  # the grid_h slider counts in-context rows, excluding the query row
            n_context, n_cols = state[0], state[1]
            grid_values = [None] * (max_grid_h * max_grid_w)
            for idx, image in enumerate(images):
                if image is not None:
                    grid_values[(idx // n_cols) * max_grid_w + idx % n_cols] = image
            # Sync the model's grid immediately; the slider values returned below drive the UI resize.
            model.set_grid_size(n_context + 1, n_cols)
            return [_gallery_update(query_images)] + grid_values + state

        task_outputs = [output_gallery] + all_image_inputs + [
            grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps]
        for dataset, loader in task_datasets:
            dataset.click(
                partial(process_tasks, func=loader),
                inputs=[dataset],
                outputs=task_outputs,
                show_progress=False,
                queue=False,
            )

        # Initial grid: default_grid_h rows (in-context examples + query) x default_grid_w columns.
        model.set_grid_size(default_grid_h, default_grid_w)
        grid_outputs = all_image_inputs + rows + row_texts + [layout_prompt]
        grid_h.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=grid_outputs)
        grid_w.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=grid_outputs)
        generate_btn.click(
            fn=generate_image,
            inputs=all_image_inputs + [seed, cfg, steps, upsampling_steps, upsampling_noise]
            + [layout_prompt, task_prompt, content_prompt],
            outputs=output_gallery,
        )
    return demo
@spaces.GPU
def generate(
        images,
        prompts,
        seed, cfg, steps,
        upsampling_steps, upsampling_noise):
    """Run VisualCloze inference under the ``spaces.GPU`` decorator.

    Every argument is forwarded unchanged to ``model.process_images``. ``model``
    is the module-level global created in the ``__main__`` block, not a parameter.

    Args:
        images: grid of input images (list of rows), as built by the UI.
        prompts: [layout, task, content] prompt strings.
        seed, cfg, steps: sampling options.
        upsampling_steps, upsampling_noise: SDEdit upsampling options.

    Returns:
        Whatever ``model.process_images`` returns.
    """
    sampling_options = {
        "seed": seed,
        "cfg": cfg,
        "steps": steps,
        "upsampling_steps": upsampling_steps,
        "upsampling_noise": upsampling_noise,
    }
    return model.process_images(images=images, prompts=prompts, **sampling_options)
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original zero-argument call.

    Returns:
        argparse.Namespace with ``model_path``, ``precision`` and ``resolution``,
        which the ``__main__`` block passes to ``VisualClozeModel``.
    """
    parser = argparse.ArgumentParser(description="VisualCloze Gradio demo")
    parser.add_argument("--model_path", type=str, default="models/visualcloze-384-lora.pth",
                        help="checkpoint path passed to VisualClozeModel")
    parser.add_argument("--precision", type=str, choices=["fp32", "bf16", "fp16"], default="bf16",
                        help="numeric precision passed to VisualClozeModel")
    parser.add_argument("--resolution", type=int, default=384,
                        help="resolution passed to VisualClozeModel")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Parse CLI options before the checkpoint download so a bad flag fails fast.
    args = parse_args()
    # Fetch the checkpoints into ./models; the default --model_path points inside it.
    snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="models")
    # `model` must stay a module-level global: generate() looks it up at call time.
    model = VisualClozeModel(
        resolution=args.resolution,
        model_path=args.model_path,
        precision=args.precision,
    )
    demo = create_demo(model)
    demo.launch()