ovi054 committed on
Commit
0b7c365
·
verified ·
1 Parent(s): 2e4ba2a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import HiDreamImagePipeline
4
+ from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
5
+ import random
6
+ import spaces
7
+ import numpy as np
8
+
# Global configuration: bfloat16 weights, staged on the CPU at start-up.
dtype = torch.bfloat16
device = "cpu"  # Initial device for model loading; inference will use GPU

# Load the Llama tokenizer and text encoder that are handed to the HiDream
# pipeline below as its `tokenizer_4` / `text_encoder_4` components.
try:
    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=dtype,
    ).to(device)
except Exception as e:
    # Top-level boundary: surface an actionable hint and keep the original
    # error attached as the explicit cause.
    raise RuntimeError(
        f"Failed to load Llama model: {e}. Ensure you have access to "
        "'meta-llama/Meta-Llama-3.1-8B-Instruct' and are logged in via `huggingface-cli login`."
    ) from e

# Load the HiDreamImagePipeline with the Llama components injected.
try:
    pipe = HiDreamImagePipeline.from_pretrained(
        "HiDream-ai/HiDream-I1-Full",
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=dtype,
    ).to(device)
    pipe.enable_model_cpu_offload()  # Offload to CPU when not in use, critical for Spaces
except Exception as e:
    raise RuntimeError(
        f"Failed to load HiDreamImagePipeline: {e}. Ensure you have access to "
        "'HiDream-ai/HiDream-I1-Full'."
    ) from e

# Upper bounds used by the UI sliders and by seed randomisation.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
40
+
# Inference function with GPU access
@spaces.GPU()
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a single image for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what to avoid; forwarded unchanged.
        seed: Seed for the CUDA random generator; ignored when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width, forwarded to the pipeline as an int.
        height: Output height, forwarded to the pipeline as an int.
        num_inference_steps: Step count, forwarded to the pipeline as an int.
        guidance_scale: Guidance strength, forwarded unchanged.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            not referenced directly in the body.

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output
        (requested with ``output_type="pil"``) and the integer seed that was
        actually used.
    """
    try:
        # Moving the model happens inside the try so that a failed transfer
        # still reaches the cleanup in `finally`.
        pipe.to("cuda")

        # UI/API values may arrive as floats (e.g. 42.0) — manual_seed needs
        # an int, so normalise here.
        seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        generator = torch.Generator("cuda").manual_seed(seed)

        # Generate the image
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )
        return result.images[0], seed
    finally:
        # Always move the model back to CPU and release cached GPU memory,
        # whether generation succeeded or raised.
        pipe.to("cpu")
        torch.cuda.empty_cache()
72
+
# Example rows for the UI: each entry is [prompt, negative prompt].
examples = [
    ['A cat holding a sign that says "Hi-Dreams.ai".', ""],
    ["A futuristic cityscape with flying cars.", "blurry, low quality"],
    ["A serene landscape with mountains and a lake.", ""],
]

# Stylesheet handed to gr.Blocks: centres the main column and themes the
# generate button (selectors match elem_id / elem_classes used in the layout).
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
.generate-btn {
background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
border: none !important;
color: white !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
96
+
# Build the Gradio interface.
# NOTE(review): container nesting below was reconstructed from a flattened
# listing — confirm the intended Row/Column layout.
with gr.Blocks(css=css) as app:
    gr.HTML("<center><h1>HiDreamImage Generator</h1></center>")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Input side: prompt, advanced settings and the generate button.
            with gr.Column():
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=3,
                        elem_id="prompt-text-input",
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        # Output resolution
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                            )
                            height = gr.Slider(
                                label="Height",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                            )
                        # Sampling controls
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                value=50,
                                minimum=1,
                                maximum=100,
                                step=1,
                            )
                            cfg = gr.Slider(
                                label="Guidance Scale",
                                value=5.0,
                                minimum=1,
                                maximum=20,
                                step=0.5,
                            )
                        # Seed controls
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                value=42,
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize Seed",
                                value=True,
                            )
                        with gr.Row():
                            negative_prompt = gr.Textbox(
                                label="Negative Prompt",
                                placeholder="Enter what to avoid (optional)",
                                lines=2,
                            )
                with gr.Row():
                    text_button = gr.Button(
                        "✨ Generate Image",
                        variant="primary",
                        elem_classes=["generate-btn"],
                    )
            # Output side: generated image and the seed that produced it.
            with gr.Column():
                with gr.Row():
                    image_output = gr.Image(
                        type="pil",
                        label="Generated Image",
                        elem_id="gallery",
                    )
                    seed_output = gr.Textbox(
                        label="Seed Used",
                        interactive=False,
                    )

        # Clickable examples fill the prompt and negative-prompt boxes.
        with gr.Column():
            gr.Examples(
                examples=examples,
                inputs=[text_prompt, negative_prompt],
            )

    # Run `infer` on either a button click or Enter in the prompt box; the
    # input order matches infer's positional parameters.
    gr.on(
        triggers=[text_button.click, text_prompt.submit],
        fn=infer,
        inputs=[text_prompt, negative_prompt, seed, randomize_seed, width, height, steps, cfg],
        outputs=[image_output, seed_output],
    )

# Launch the app
app.launch(share=True)