# HiDream-I1 image generator (Hugging Face Space, ZeroGPU hardware)
import gradio as gr
import spaces
import torch
from diffusers import HiDreamImagePipeline
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
import random
import numpy as np

# Set data type
dtype = torch.bfloat16
device = "cpu"  # Load models on CPU to avoid initializing CUDA at startup
# Load tokenizer and text encoder for Llama
try:
    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        output_hidden_states=True,
        output_attentions=True,
        attn_implementation="eager",
        torch_dtype=dtype,
    ).to(device)
except Exception as e:
    raise RuntimeError(
        f"Failed to load Llama model: {e}. Ensure you have access to "
        "'meta-llama/Meta-Llama-3.1-8B-Instruct' and are logged in via `huggingface-cli login`."
    ) from e
# Load the HiDreamImagePipeline
try:
    pipe = HiDreamImagePipeline.from_pretrained(
        "HiDream-ai/HiDream-I1-Fast",
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=dtype,
    ).to(device)
    pipe.enable_model_cpu_offload()  # Keep weights on CPU; each sub-model moves to the GPU only while it runs
except Exception as e:
    raise RuntimeError(
        f"Failed to load HiDreamImagePipeline: {e}. Ensure you have access to "
        "'HiDream-ai/HiDream-I1-Fast'."
    ) from e
# Define maximum values
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Inference function with GPU access (on ZeroGPU, @spaces.GPU allocates a GPU for the duration of the call)
@spaces.GPU
def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=16, guidance_scale=3.5, progress=gr.Progress(track_tqdm=True)):
    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator("cuda").manual_seed(seed)
        # Generate the image; CPU offload handles device placement, so no explicit
        # pipe.to("cuda") (calling .to() on an offloaded pipeline raises an error)
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        ).images[0]
        return image, seed
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()
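# Quick local smoke test (hypothetical; assumes a CUDA GPU outside of Spaces, where
# the @spaces.GPU decorator is a no-op):
#   image, used_seed = infer("A cat holding a sign that says \"Hi-Dreams.ai\".", randomize_seed=True)
#   image.save("sample.png")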
# Define examples
examples = [
    ["A cat holding a sign that says \"Hi-Dreams.ai\".", ""],
    ["A futuristic cityscape with flying cars.", "blurry, low quality"],
    ["A serene landscape with mountains and a lake.", ""],
]
# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
# Create Gradio interface
with gr.Blocks(css=css) as app:
    gr.HTML("<center><h1>HiDream Image Generator</h1></center>")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=3,
                        elem_id="prompt-text-input"
                    )
                with gr.Row():
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter what to avoid (optional)",
                        lines=2
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8
                            )
                            height = gr.Slider(
                                label="Height",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                value=16,
                                minimum=1,
                                maximum=100,
                                step=1
                            )
                            cfg = gr.Slider(
                                label="Guidance Scale",
                                value=3.5,
                                minimum=1,
                                maximum=20,
                                step=0.5
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                value=42,
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize Seed",
                                value=True
                            )
                with gr.Row():
                    text_button = gr.Button(
                        "✨ Generate Image",
                        variant='primary',
                        elem_classes=["generate-btn"]
                    )
            with gr.Column():
                with gr.Row():
                    image_output = gr.Image(
                        type="pil",
                        label="Generated Image",
                        elem_id="gallery"
                    )
                seed_output = gr.Textbox(
                    label="Seed Used",
                    interactive=False
                )
        with gr.Column():
            gr.Examples(
                examples=examples,
                inputs=[text_prompt, negative_prompt],
            )

    # Connect the button and the prompt textbox's submit event to the inference function
    gr.on(
        triggers=[text_button.click, text_prompt.submit],
        fn=infer,
        inputs=[text_prompt, negative_prompt, seed, randomize_seed, width, height, steps, cfg],
        outputs=[image_output, seed_output]
    )
app.launch()  # share=True is not supported on Spaces; the Space itself serves the public URL