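"""Gradio app for HiDream-I1 text-to-image generation.

Loads meta-llama/Meta-Llama-3.1-8B-Instruct as the pipeline's fourth text
encoder and serves HiDream-ai/HiDream-I1-Fast behind a Gradio Blocks UI.
"""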
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import HiDreamImagePipeline
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
# Set data type
dtype = torch.bfloat16
device = "cpu" # Use CPU for model loading to avoid CUDA initialization
# Load tokenizer and text encoder for Llama
try:
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3.1-8B-Instruct",
output_hidden_states=True,
output_attentions=True,
attn_implementation="eager",
torch_dtype=dtype,
).to(device)
except Exception as e:
raise Exception(f"Failed to load Llama model: {e}. Ensure you have access to 'meta-llama/Meta-Llama-3.1-8B-Instruct' and are logged in via `huggingface-cli login`.")
# Load the HiDreamImagePipeline
try:
pipe = HiDreamImagePipeline.from_pretrained(
"HiDream-ai/HiDream-I1-Fast",
tokenizer_4=tokenizer_4,
text_encoder_4=text_encoder_4,
torch_dtype=dtype,
    ).to(device)
    # Keep the pipeline on CPU at load time; infer() moves it to CUDA under
    # @spaces.GPU. enable_model_cpu_offload() is deliberately omitted, since
    # it conflicts with the explicit pipe.to("cuda") call in infer().
except Exception as e:
raise Exception(f"Failed to load HiDreamImagePipeline: {e}. Ensure you have access to 'HiDream-ai/HiDream-I1-Full'.")
# Define maximum values
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Inference function with GPU access
@spaces.GPU()
def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=16, guidance_scale=3.5, progress=gr.Progress(track_tqdm=True)):
pipe.to("cuda")
try:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator("cuda").manual_seed(seed)
        # Generate the image on the GPU
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="pil",
).images[0]
return image, seed
finally:
# Clear GPU memory
torch.cuda.empty_cache()
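# Minimal usage sketch (assumes this runs in a @spaces.GPU context with a CUDA
# device; arguments mirror the UI defaults):
#   image, used_seed = infer("A cat holding a sign that says \"Hi-Dreams.ai\".",
#                            randomize_seed=True)
#   image.save("out.png")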
# Define examples
examples = [
["A cat holding a sign that says \"Hi-Dreams.ai\".", ""],
["A futuristic cityscape with flying cars.", "blurry, low quality"],
["A serene landscape with mountains and a lake.", ""],
]
# CSS styling
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
.generate-btn {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
border: none !important;
color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
# Create Gradio interface
with gr.Blocks(css=css) as app:
gr.HTML("<center><h1>HiDreamImage Generator</h1></center>")
with gr.Column(elem_id="col-container"):
with gr.Row():
with gr.Column():
with gr.Row():
text_prompt = gr.Textbox(
label="Prompt",
placeholder="Enter a prompt here",
lines=3,
elem_id="prompt-text-input"
)
with gr.Row():
negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Enter what to avoid (optional)",
lines=2
)
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
width = gr.Slider(
label="Width",
value=1024,
minimum=64,
maximum=MAX_IMAGE_SIZE,
step=8
)
height = gr.Slider(
label="Height",
value=1024,
minimum=64,
maximum=MAX_IMAGE_SIZE,
step=8
)
with gr.Row():
steps = gr.Slider(
label="Inference Steps",
value=16,
minimum=1,
maximum=100,
step=1
)
cfg = gr.Slider(
label="Guidance Scale",
value=3.5,
minimum=1,
maximum=20,
step=0.5
)
with gr.Row():
seed = gr.Slider(
label="Seed",
value=42,
minimum=0,
maximum=MAX_SEED,
step=1
)
randomize_seed = gr.Checkbox(
label="Randomize Seed",
value=True
)
with gr.Row():
text_button = gr.Button(
"✨ Generate Image",
variant='primary',
elem_classes=["generate-btn"]
)
with gr.Column():
with gr.Row():
image_output = gr.Image(
type="pil",
label="Generated Image",
elem_id="gallery"
)
seed_output = gr.Textbox(
label="Seed Used",
interactive=False
)
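                # Surfacing the seed that was actually used lets a user
                # reproduce a randomized result by re-entering it with
                # "Randomize Seed" unchecked.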
with gr.Column():
gr.Examples(
examples=examples,
inputs=[text_prompt, negative_prompt],
)
# Connect the button and textbox submit to the inference function
gr.on(
triggers=[text_button.click, text_prompt.submit],
fn=infer,
inputs=[text_prompt, negative_prompt, seed, randomize_seed, width, height, steps, cfg],
outputs=[image_output, seed_output]
)
if __name__ == "__main__":
    app.launch(share=True)