import torch
import gradio as gr

from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

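# Per-variant generation settings: checkpoint path, guidance scale, number of
# inference steps, timestep shift, and the scheduler class to instantiate.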
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

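# Output resolutions offered in the UI, labelled as "width × height (orientation)".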
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]


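# Build the pipeline for a model variant: Llama-3.1 tokenizer and text encoder,
# the variant's transformer weights, and its configured scheduler.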
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    # Instantiate the scheduler class configured for this variant instead of a hard-coded one.
    scheduler = config["scheduler"](
        num_train_timesteps=1000,
        shift=config["shift"],
        use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    # Attach the separately loaded transformer to the pipeline.
    pipe.transformer = transformer

    return pipe, config


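# Map a resolution label from RESOLUTION_OPTIONS to pixel dimensions.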
def parse_resolution(resolution_str):
    # Labels read "width × height (orientation)"; return (height, width) as the
    # pipeline call expects. Fall back to a 1024 × 1024 square if parsing fails.
    try:
        width, height = (int(part) for part in resolution_str.split("(")[0].split("×"))
        return height, width
    except ValueError:
        return 1024, 1024


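# Gradio callback: generate a single image for the selected variant, prompt,
# resolution, and seed, and return the image together with the seed used.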
def generate_image(model_type, prompt, resolution, seed):
    global pipe, current_model

    # Reload the pipeline if the user picked a different model variant.
    if model_type != current_model:
        print(f"Switching model to {model_type}...")
        del pipe
        torch.cuda.empty_cache()
        pipe, _ = load_models(model_type)
        current_model = model_type

    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    height, width = parse_resolution(resolution)

    # A seed of -1 means "pick a random seed" so it can be reported back to the UI.
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed


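# Load the default model once at startup; generate_image() reloads the pipeline
# if a different variant is selected later.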
print("Loading default model (full)...")
current_model = "full"
pipe, _ = load_models(current_model)
print("Model loaded successfully!")


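# Gradio UI: controls in the left column, generated image in the right column.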
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("# HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Select model variant"
            )

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                lines=3
            )

            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
                info="Select image resolution"
            )

            seed = gr.Number(
                label="Seed (use -1 for random)",
                value=-1,
                precision=0
            )

            generate_btn = gr.Button("Generate Image")
            seed_used = gr.Number(label="Seed Used", interactive=False)

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

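    # Wire the button to the generation callback; outputs fill the image and seed fields.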
    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used]
    )


if __name__ == "__main__":
    demo.launch()