# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# app.py
"""Gradio demo: ChatGPT-enhanced prompts rendered with the UNO pipeline.

The user's prompt is expanded into several prompt variants by
``enhance_prompt_with_chatgpt``; each variant is rendered by
``UNOPipeline.gradio_generate``, optionally conditioned on up to four
reference images supplied in the UI.
"""

import dataclasses
import json
import os
from pathlib import Path

# PyTorch reads this variable when its CUDA caching allocator is initialised.
# Setting it before torch (and the uno modules that import torch) are loaded
# ensures it is seen; ``setdefault`` keeps an explicit user setting intact.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import gradio as gr  # noqa: E402
import openai  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import login  # noqa: E402

from uno.flux.pipeline import UNOPipeline  # noqa: E402
from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt  # noqa: E402

openai.api_key = os.getenv("OPENAI_API_KEY")

# ``login(token=None)`` falls back to an interactive token prompt, which would
# block a headless server; only authenticate when a token is actually provided.
_hf_token = os.getenv("HUGGINGFACE_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Number of (image, enhanced-prompt) output slot pairs laid out in the UI.
MAX_OUTPUTS = 5
# Optional reference-image keys read from each example's config.json.
REF_IMAGE_KEYS = ("image_ref1", "image_ref2", "image_ref3", "image_ref4")


def get_examples(examples_dir: str = "assets/examples") -> list:
    """Load Gradio example rows from ``<examples_dir>/<case>/config.json``.

    Every sub-directory of *examples_dir* is one example. Its ``config.json``
    must contain ``useage`` (the key name spelled this way by the example
    files), ``prompt`` and ``seed``, and may contain ``image_ref1`` ...
    ``image_ref4`` as paths relative to that sub-directory.

    Args:
        examples_dir: Directory holding one sub-directory per example.

    Returns:
        One list per example, ordered
        ``[useage, prompt, ref1, ref2, ref3, ref4, seed]``, with ``None`` for
        absent reference images. Examples are sorted by directory path so the
        gallery order does not depend on filesystem iteration order.

    Raises:
        FileNotFoundError: if *examples_dir* or an example's config is missing.
        KeyError: if a config lacks ``useage``, ``prompt`` or ``seed``.
    """
    rows = []
    for example in sorted(Path(examples_dir).iterdir()):
        if not example.is_dir():
            continue
        with open(example / "config.json", encoding="utf-8") as f:
            config = json.load(f)
        refs = [str(example / config[key]) if key in config else None for key in REF_IMAGE_KEYS]
        rows.append([config["useage"], config["prompt"], *refs, config["seed"]])
    return rows


def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
):
    """Build the Gradio Blocks UI around a single ``UNOPipeline`` instance.

    Args:
        model_type: Model name forwarded to ``UNOPipeline`` (the CLI below
            offers ``flux-dev``, ``flux-dev-fp8`` and ``flux-schnell``).
        device: Torch device string forwarded to ``UNOPipeline``.
        offload: Forwarded to ``UNOPipeline``; the CLI help describes it as
            sequentially offloading unused models to CPU.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)

    with gr.Blocks() as demo:
        gr.Markdown("# UNO by UNO team")
        gr.Markdown(
            """
Build Build Build
"""
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                with gr.Row():
                    image_prompt1 = gr.Image(label="Ref Img1", type="pil")
                    image_prompt2 = gr.Image(label="Ref Img2", type="pil")
                    image_prompt3 = gr.Image(label="Ref Img3", type="pil")
                    image_prompt4 = gr.Image(label="Ref Img4", type="pil")

                with gr.Row():
                    with gr.Column():
                        width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                    with gr.Column():
                        gr.Markdown("📌 Trained on 512x512. Larger size = better quality, but less stable.")

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance")
                        seed = gr.Number(-1, label="Seed (-1 for random)")

                num_outputs = gr.Slider(
                    1, MAX_OUTPUTS, MAX_OUTPUTS, step=1, label="Number of Enhanced Prompts / Images"
                )
                generate_btn = gr.Button("Generate Enhanced Images")

            with gr.Column():
                # Flat list alternating image / prompt-text components; the
                # handler below must return values in exactly this order.
                outputs = []
                for i in range(MAX_OUTPUTS):
                    outputs.append(gr.Image(label=f"Image {i+1}"))
                    outputs.append(gr.Textbox(label=f"Enhanced Prompt {i+1}"))

        def run_generation(prompt, width, height, guidance, num_steps, seed, img1, img2, img3, img4, num_outputs):
            """Enhance *prompt* and render one image per enhanced prompt.

            Returns:
                A list of ``len(outputs)`` values alternating image / prompt
                text, matching ``outputs``. A failed image yields ``None`` plus
                a warning string; unused slots are padded with ``None`` (image)
                and ``""`` (text).
            """
            # Guard against float slider values and never produce more
            # results than there are output slots.
            num_outputs = max(1, min(int(num_outputs), MAX_OUTPUTS))
            uploaded_images = [img for img in (img1, img2, img3, img4) if img is not None]
            print(f"\n📥 [DEBUG] User prompt: {prompt}")

            try:
                enhanced = enhance_prompt_with_chatgpt(
                    user_prompt=prompt,
                    num_prompts=num_outputs,
                    reference_images=uploaded_images,
                )
            except Exception as e:
                # Best effort: an enhancement failure (network, API key, ...)
                # degrades to rendering the user's own prompt.
                print(f"❌ [ERROR] Prompt enhancement failed, using the original prompt: {e}")
                enhanced = None
            # Copy so the padding below never mutates a list owned by the enhancer.
            prompts = list(enhanced or [])

            print(f"\n🧠 [DEBUG] Final Prompt List (len={len(prompts)}):")
            for idx, p in enumerate(prompts):
                print(f" [{idx+1}] {p}")

            # Fill any shortfall with the user's original prompt.
            prompts += [prompt] * (num_outputs - len(prompts))

            # -1 means "new random seed per image"; a cleared number box
            # (None) is treated the same way instead of failing every image.
            fixed_seed = None if seed is None or seed == -1 else int(seed)

            results = []
            for i, current_prompt in enumerate(prompts[:num_outputs]):
                try:
                    seed_val = fixed_seed if fixed_seed is not None else torch.randint(0, 10**8, (1,)).item()
                    print(f"🧪 [DEBUG] Using seed: {seed_val} for image {i+1}")
                    gen_image, _ = pipeline.gradio_generate(
                        prompt=current_prompt,
                        width=width,
                        height=height,
                        guidance=guidance,
                        num_steps=num_steps,
                        seed=seed_val,
                        image_prompt1=img1,
                        image_prompt2=img2,
                        image_prompt3=img3,
                        image_prompt4=img4,
                    )
                except Exception as e:
                    # Best effort: one failed image must not abort the batch.
                    print(f"❌ [ERROR] Failed to generate image {i+1}: {e}")
                    results += [None, f"⚠️ Failed to generate: {e}"]
                    continue
                print(f"✅ [DEBUG] Image {i+1} generated using prompt: {current_prompt}")
                results += [gen_image, current_prompt]

            # Pad the remaining slots: even index -> image (None), odd -> text ("").
            while len(results) < len(outputs):
                results.append(None if len(results) % 2 == 0 else "")
            return results

        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, width, height, guidance, num_steps, seed,
                image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                num_outputs,
            ],
            outputs=outputs,
        )

        # Hidden field receiving each example's "useage" (case description).
        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/examples")
        # Column order must match get_examples():
        # [useage, prompt, ref1, ref2, ref3, ref4, seed].
        gr.Examples(
            examples=examples,
            inputs=[
                example_text, prompt,
                image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                seed,
            ],
        )

    return demo


if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    @dataclasses.dataclass
    class AppArgs:
        # Model variant forwarded to create_demo / UNOPipeline.
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload unused models to CPU"},
        )
        # Port passed to demo.launch(server_port=...).
        port: int = 7860

    parser = HfArgumentParser([AppArgs])
    args = parser.parse_args_into_dataclasses()[0]

    demo = create_demo(args.name, args.device, args.offload)
    demo.launch(server_port=args.port)