import gradio as gr
import spaces
import torch
from diffusers import HiDreamImagePipeline
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
import random
import numpy as np

# Set data type; keep everything on CPU at startup to avoid initializing CUDA in the main process
dtype = torch.bfloat16
device = "cpu"

# Load tokenizer and text encoder for Llama
try:
    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        output_hidden_states=True,
        output_attentions=True,
        attn_implementation="eager",
        torch_dtype=dtype,
    ).to(device)
except Exception as e:
    raise RuntimeError(
        f"Failed to load Llama model: {e}. Ensure you have access to "
        "'meta-llama/Meta-Llama-3.1-8B-Instruct' and are logged in via `huggingface-cli login`."
    ) from e

# Load the HiDreamImagePipeline
try:
    pipe = HiDreamImagePipeline.from_pretrained(
        "HiDream-ai/HiDream-I1-Fast",
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=dtype,
    ).to(device)
    pipe.enable_model_cpu_offload()  # Offload to CPU; submodules are moved to the GPU on demand
except Exception as e:
    raise RuntimeError(
        f"Failed to load HiDreamImagePipeline: {e}. Ensure you have access to 'HiDream-ai/HiDream-I1-Fast'."
    ) from e

# Define maximum values
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048


# Inference function with GPU access
@spaces.GPU()
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=16,
    guidance_scale=3.5,
    progress=gr.Progress(track_tqdm=True),
):
    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator("cuda").manual_seed(seed)
        # Generate the image; model CPU offload handles device placement, so no explicit
        # pipe.to("cuda") is needed (moving the pipeline manually conflicts with offloading)
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        ).images[0]
        return image, seed
    finally:
        # Clear GPU memory
        torch.cuda.empty_cache()


# Define examples
examples = [
    ["A cat holding a sign that says \"Hi-Dreams.ai\".", ""],
    ["A futuristic cityscape with flying cars.", "blurry, low quality"],
    ["A serene landscape with mountains and a lake.", ""],
]

# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""

# Create Gradio interface
with gr.Blocks(css=css) as app:
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>HiDreamImage Generator</h1>
        </div>
        """
    )
") with gr.Column(elem_id="col-container"): with gr.Row(): with gr.Column(): with gr.Row(): text_prompt = gr.Textbox( label="Prompt", placeholder="Enter a prompt here", lines=3, elem_id="prompt-text-input" ) with gr.Row(): negative_prompt = gr.Textbox( label="Negative Prompt", placeholder="Enter what to avoid (optional)", lines=2 ) with gr.Row(): with gr.Accordion("Advanced Settings", open=False): with gr.Row(): width = gr.Slider( label="Width", value=1024, minimum=64, maximum=MAX_IMAGE_SIZE, step=8 ) height = gr.Slider( label="Height", value=1024, minimum=64, maximum=MAX_IMAGE_SIZE, step=8 ) with gr.Row(): steps = gr.Slider( label="Inference Steps", value=16, minimum=1, maximum=100, step=1 ) cfg = gr.Slider( label="Guidance Scale", value=3.5, minimum=1, maximum=20, step=0.5 ) with gr.Row(): seed = gr.Slider( label="Seed", value=42, minimum=0, maximum=MAX_SEED, step=1 ) randomize_seed = gr.Checkbox( label="Randomize Seed", value=True ) with gr.Row(): text_button = gr.Button( "✨ Generate Image", variant='primary', elem_classes=["generate-btn"] ) with gr.Column(): with gr.Row(): image_output = gr.Image( type="pil", label="Generated Image", elem_id="gallery" ) seed_output = gr.Textbox( label="Seed Used", interactive=False ) with gr.Column(): gr.Examples( examples=examples, inputs=[text_prompt, negative_prompt], ) # Connect the button and textbox submit to the inference function gr.on( triggers=[text_button.click, text_prompt.submit], fn=infer, inputs=[text_prompt, negative_prompt, seed, randomize_seed, width, height, steps, cfg], outputs=[image_output, seed_output] ) app.launch(share=True)