UNO-FLUX / app.py
itembox's picture
Korean
d8ed60f
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
from pathlib import Path
import gradio as gr
import torch
import spaces
from uno.flux.pipeline import UNOPipeline
def get_examples(examples_dir: str = "assets/examples") -> list:
examples = Path(examples_dir)
ans = []
for example in examples.iterdir():
if not example.is_dir():
continue
with open(example / "config.json") as f:
example_dict = json.load(f)
example_list = []
example_list.append(example_dict["useage"]) # case for
example_list.append(example_dict["prompt"]) # prompt
for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]:
if key in example_dict:
example_list.append(str(example / example_dict[key]))
else:
example_list.append(None)
example_list.append(example_dict["seed"])
ans.append(example_list)
return ans
def create_demo(
model_type: str,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
offload: bool = False,
):
pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
badges_text = r"""
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
<a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
<a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
<a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
</div>
""".strip()
with gr.Blocks() as demo:
gr.Markdown(f"# UNO by UNO 팀")
gr.Markdown(badges_text)
gr.Markdown("""
## UNO (Unified Numerous Objects) 저장소에 오신 것을 환영합니다!
UNO는 Bytedance에서 개발한 최첨단 이미지 생성 모델입니다. 이 데모 페이지에서는 **UNO-FLUX** 모델을 사용하여 텍스트 프롬프트와 최대 4개의 참조 이미지를 기반으로 이미지를 생성할 수 있습니다.
**주요 기능:**
* **텍스트-이미지 변환:** 입력한 텍스트 설명을 바탕으로 이미지를 생성합니다.
* **참조 이미지 활용:** 하나 이상의 참조 이미지를 제공하여 생성될 이미지의 스타일, 객체, 또는 구성을 제어할 수 있습니다.
* **다양한 옵션:** 생성 이미지의 크기, 스텝 수, 가이던스 강도 등을 조절하여 원하는 결과를 얻을 수 있습니다.
위의 뱃지들을 클릭하여 GitHub 저장소, 프로젝트 페이지, 관련 논문, Hugging Face 모델 및 데모 페이지로 이동할 수 있습니다.
""")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="프롬프트", value="도시 속 잘생긴 여자")
with gr.Row():
image_prompt1 = gr.Image(label="참조 이미지 1", visible=True, interactive=True, type="pil")
image_prompt2 = gr.Image(label="참조 이미지 2", visible=True, interactive=True, type="pil")
image_prompt3 = gr.Image(label="참조 이미지 3", visible=True, interactive=True, type="pil")
image_prompt4 = gr.Image(label="참조 이미지 4", visible=True, interactive=True, type="pil")
with gr.Row():
with gr.Column():
width = gr.Slider(512, 2048, 512, step=16, label="생성 넓이")
height = gr.Slider(512, 2048, 512, step=16, label="생성 높이")
with gr.Column():
gr.Markdown("📌 모델은 512x512 해상도에서 학습되었습니다.\n512에 가까운 크기가 더 안정적이며, 더 높은 크기는 시각적 효과는 좋지만 안정성은 떨어집니다.")
with gr.Accordion("고급 옵션", open=False):
with gr.Row():
num_steps = gr.Slider(1, 50, 25, step=1, label="스텝 수")
guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="가이던스", interactive=True)
seed = gr.Number(-1, label="시드 (-1이면 무작위)")
generate_btn = gr.Button("생성하기")
with gr.Column():
output_image = gr.Image(label="생성된 이미지")
download_btn = gr.File(label="전체 해상도 다운로드", type="filepath", interactive=False)
inputs = [
prompt, width, height, guidance, num_steps,
seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
]
generate_btn.click(
fn=pipeline.gradio_generate,
inputs=inputs,
outputs=[output_image, download_btn],
)
example_text = gr.Text("", visible=False, label="사용 사례:")
examples = get_examples("./assets/examples")
gr.Examples(
examples=examples,
inputs=[
example_text, prompt,
image_prompt1, image_prompt2, image_prompt3, image_prompt4,
seed, output_image
],
)
return demo
if __name__ == "__main__":
from typing import Literal
from transformers import HfArgumentParser
@dataclasses.dataclass
class AppArgs:
name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
offload: bool = dataclasses.field(
default=False,
metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
)
port: int = 7860
parser = HfArgumentParser([AppArgs])
args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
args = args_tuple[0]
demo = create_demo(args.name, args.device, args.offload)
demo.launch(server_port=args.port, ssr_mode=False)