chengzeyi commited on
Commit
ccb88d2
·
verified ·
1 Parent(s): 4c35428

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +7 -5
  2. app.py +592 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Hidream Arena
3
- emoji: 📚
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.25.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: HiDream Arena
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.0.1
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ short_description: Arena for HiDream-I1-dev / -full and FLUX-dev
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import asyncio
3
+ import aiohttp
4
+ import time
5
+ from datetime import datetime
6
+ import plotly.graph_objects as go
7
+ from typing import Dict, List
8
+ import os
9
+ from dotenv import load_dotenv
10
+ import json
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ import uuid
13
+ import threading
14
+
15
+ # Load environment variables first
16
+ load_dotenv()
17
+
18
+ # Constants
19
+ API_BASE_URL = "https://api.wavespeed.ai/api/v2"
20
+ API_KEY = os.getenv("WAVESPEED_API_KEY") # Move API_KEY to global scope
21
+
22
+ if not API_KEY:
23
+ raise ValueError("WAVESPEED_API_KEY not found in environment variables")
24
+
25
+ # Rest of constants
26
+ BACKENDS = {
27
+ "flux-dev": {
28
+ "endpoint": f"{API_BASE_URL}/wavespeed-ai/flux-dev-ultra-fast",
29
+ "name": "Flux-dev",
30
+ "color": "#FF9800",
31
+ },
32
+ "hidream-dev": {
33
+ "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-dev",
34
+ "name": "HiDream-dev",
35
+ "color": "#2196F3",
36
+ },
37
+ "hidream-full": {
38
+ "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-full",
39
+ "name": "HiDream-full",
40
+ "color": "#4CAF50",
41
+ },
42
+ }
43
+
44
+
45
+ class BackendStatus:
46
+
47
+ def __init__(self):
48
+ self.reset()
49
+ self.history: List[Dict] = []
50
+
51
+ def reset(self):
52
+ self.status = "idle"
53
+ self.progress = 0
54
+ self.start_time = None
55
+ self.end_time = None
56
+
57
+ def start(self):
58
+ self.status = "processing"
59
+ self.progress = 0
60
+ self.start_time = time.time()
61
+ self.end_time = None
62
+
63
+ def complete(self):
64
+ self.status = "completed"
65
+ self.progress = 100
66
+ self.end_time = time.time()
67
+ self.history.append({
68
+ "timestamp": datetime.now(),
69
+ "duration": self.end_time - self.start_time
70
+ })
71
+
72
+ def fail(self):
73
+ self.status = "failed"
74
+ self.end_time = time.time()
75
+
76
+
77
+ class SessionManager:
78
+ _instances = {}
79
+ _lock = threading.Lock()
80
+
81
+ @classmethod
82
+ def get_manager(cls, session_id=None):
83
+ if session_id is None:
84
+ session_id = str(uuid.uuid4())
85
+
86
+ with cls._lock:
87
+ if session_id not in cls._instances:
88
+ cls._instances[session_id] = GenerationManager()
89
+ return session_id, cls._instances[session_id]
90
+
91
+ @classmethod
92
+ def cleanup_old_sessions(cls, max_age=3600): # 1 hour default
93
+ current_time = time.time()
94
+ with cls._lock:
95
+ to_remove = []
96
+ for session_id, manager in cls._instances.items():
97
+ if (hasattr(manager, "last_activity")
98
+ and current_time - manager.last_activity > max_age):
99
+ to_remove.append(session_id)
100
+
101
+ for session_id in to_remove:
102
+ del cls._instances[session_id]
103
+
104
+
105
+ class GenerationManager:
106
+
107
+ def __init__(self):
108
+ self.backend_statuses = {
109
+ backend: BackendStatus()
110
+ for backend in BACKENDS
111
+ }
112
+ self.last_activity = time.time()
113
+
114
+ def update_activity(self):
115
+ self.last_activity = time.time()
116
+
117
+ def get_performance_plot(self):
118
+ fig = go.Figure()
119
+
120
+ has_data = False
121
+
122
+ for backend, status in self.backend_statuses.items():
123
+ durations = [h["duration"] for h in status.history]
124
+ if durations:
125
+ has_data = True
126
+ avg_duration = sum(durations) / len(durations)
127
+ # Use bar chart instead of box plot
128
+ fig.add_trace(
129
+ go.Bar(
130
+ y=[avg_duration], # Average duration
131
+ x=[BACKENDS[backend]["name"]], # Backend name
132
+ name=BACKENDS[backend]["name"],
133
+ marker_color=BACKENDS[backend]["color"],
134
+ text=[f"{avg_duration:.2f}s"], # Show time in seconds
135
+ textposition="auto",
136
+ width=[0.5], # Make bars narrower
137
+ ))
138
+
139
+ # Set a minimum y-axis range if we have data
140
+ if has_data:
141
+ max_duration = max([
142
+ max([h["duration"] for h in status.history] or [0])
143
+ for status in self.backend_statuses.values()
144
+ ])
145
+ # Add 20% padding to the top
146
+ y_max = max_duration * 1.2
147
+ # Ensure the y-axis always starts at 0
148
+ fig.update_yaxes(range=[0, y_max])
149
+
150
+ fig.update_layout(
151
+ title="Average Generation Time",
152
+ yaxis_title="Seconds",
153
+ xaxis_title="",
154
+ showlegend=False,
155
+ template="simple_white",
156
+ height=400, # Increase height
157
+ margin=dict(l=50, r=50, t=50, b=50), # Add margins
158
+ font=dict(size=14), # Larger font
159
+ )
160
+
161
+ # Make sure we have a valid figure even if no data
162
+ if not has_data:
163
+ fig.add_annotation(
164
+ text="No timing data available yet",
165
+ xref="paper",
166
+ yref="paper",
167
+ x=0.5,
168
+ y=0.5,
169
+ showarrow=False,
170
+ font=dict(size=16),
171
+ )
172
+
173
+ return fig
174
+
175
+ async def submit_task(self, backend: str, prompt: str) -> str:
176
+ status = self.backend_statuses[backend]
177
+ status.start()
178
+
179
+ try:
180
+ url = BACKENDS[backend]["endpoint"]
181
+ headers = {
182
+ "Content-Type": "application/json",
183
+ "Authorization": f"Bearer {API_KEY}",
184
+ }
185
+ payload = {
186
+ "prompt": prompt,
187
+ "enable_safety_checker": False,
188
+ "enable_base64_output": True, # Enable base64 output
189
+ "size": "1024*1024",
190
+ "seed": -1,
191
+ }
192
+
193
+ if backend == "flux-dev":
194
+ payload.update({
195
+ "guidance_scale": 3.5,
196
+ "num_images": 1,
197
+ "num_inference_steps": 28,
198
+ "strength": 0.8,
199
+ })
200
+
201
+ print(f"Submitting task to {backend}")
202
+ print(f"URL: {url}")
203
+ print(f"Payload: {json.dumps(payload, indent=2)}")
204
+
205
+ # Use aiohttp instead of requests for async
206
+ async with aiohttp.ClientSession() as session:
207
+ async with session.post(url, headers=headers,
208
+ json=payload) as response:
209
+ if response.status == 200:
210
+ result = await response.json()
211
+ request_id = result["data"]["id"]
212
+ print(
213
+ f"Task submitted successfully. Request ID: {request_id}"
214
+ )
215
+ return request_id
216
+ else:
217
+ text = await response.text()
218
+ raise Exception(
219
+ f"API error: {response.status}, {text}")
220
+
221
+ except Exception as e:
222
+ status.fail()
223
+ raise Exception(f"Failed to submit task: {str(e)}")
224
+
225
+ # Add this method to reset history
226
+ def reset_history(self):
227
+ """Reset history for all backends"""
228
+ for status in self.backend_statuses.values():
229
+ status.history = [] # Clear history data
230
+ return self
231
+
232
+
233
+ # Helper function to create error images as data URIs
234
+ def create_error_image(backend, error_message):
235
+ try:
236
+ import base64
237
+ from io import BytesIO
238
+
239
+ # Create an in-memory image
240
+ img = Image.new("RGB", (512, 512), color="#ffdddd")
241
+ draw = ImageDraw.Draw(img)
242
+ try:
243
+ font = ImageFont.truetype("Arial", 20)
244
+ except:
245
+ font = ImageFont.load_default()
246
+
247
+ # Wrap and draw error message
248
+ words = error_message.split(" ")
249
+ lines = []
250
+ line = ""
251
+ for word in words:
252
+ if len(line + word) < 40:
253
+ line += word + " "
254
+ else:
255
+ lines.append(line)
256
+ line = word + " "
257
+ if line:
258
+ lines.append(line)
259
+
260
+ y_position = 100
261
+ for line in lines:
262
+ draw.text((50, y_position), line, fill="black", font=font)
263
+ y_position += 30
264
+
265
+ # Save to a BytesIO object instead of a file
266
+ buffer = BytesIO()
267
+ img.save(buffer, format="PNG")
268
+ img_bytes = buffer.getvalue()
269
+
270
+ # Convert to base64 and return as data URI
271
+ return f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}"
272
+ except Exception as e:
273
+ print(f"Failed to create error image: {e}")
274
+ # Return a simple error message as fallback
275
+ return "Error: " + error_message
276
+
277
+
278
+ # Fix the poll_once function to accept a manager parameter
279
+ async def poll_once(manager, backend, request_id):
280
+ """Poll once and return result if complete, otherwise None"""
281
+ headers = {"Authorization": f"Bearer {API_KEY}"}
282
+ url = f"{API_BASE_URL}/predictions/{request_id}/result"
283
+
284
+ async with aiohttp.ClientSession() as session:
285
+ async with session.get(url, headers=headers) as response:
286
+ if response.status == 200:
287
+ result = await response.json()
288
+ data = result["data"]
289
+ current_status = data["status"]
290
+
291
+ if current_status == "completed":
292
+ # IMPORTANT: Update status BEFORE returning - using the passed manager
293
+ manager.backend_statuses[backend].complete()
294
+ manager.update_activity()
295
+
296
+ # Handle base64 output
297
+ output = data["outputs"][0]
298
+
299
+ # Check if it's a base64 string or URL
300
+ if isinstance(output, str) and output.startswith("http"):
301
+ # It's a URL - return as is
302
+ return output
303
+ else:
304
+ # It's base64 data - format it as a data URI if needed
305
+ try:
306
+ # Format as data URI for Gradio to display directly
307
+ if isinstance(
308
+ output, str
309
+ ) and not output.startswith("data:image"):
310
+ # Convert raw base64 to data URI format
311
+ return f"data:image/png;base64,{output}"
312
+ else:
313
+ # Already in data URI format
314
+ return output
315
+ except Exception as e:
316
+ print(f"Error processing base64 image: {e}")
317
+ raise Exception(
318
+ f"Failed to process base64 image: {str(e)}")
319
+
320
+ elif current_status == "failed":
321
+ manager.backend_statuses[backend].fail()
322
+ manager.update_activity()
323
+ error = data.get("error", "Unknown error")
324
+ raise Exception(error)
325
+
326
+ # Still processing
327
+ return None
328
+ else:
329
+ raise Exception(f"Poll error: {response.status}")
330
+
331
+
332
+ # Use a state variable to store session ID
333
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
334
+ session_id = gr.State(None) # Add this to store session ID
335
+
336
+ gr.Markdown("# 🌊 WaveSpeed AI Image Generator")
337
+
338
+ # Add the introduction with link to WaveSpeedAI
339
+ gr.Markdown(
340
+ "[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation."
341
+ )
342
+ gr.Markdown(
343
+ "Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries."
344
+ )
345
+
346
+ with gr.Row():
347
+ with gr.Column(scale=3):
348
+ input_text = gr.Textbox(
349
+ label="Enter your prompt",
350
+ placeholder="Type here...",
351
+ lines=3,
352
+ )
353
+ with gr.Column(scale=1):
354
+ generate_btn = gr.Button("Generate", variant="primary")
355
+
356
+ # Two status boxes - small (default) and big (during generation)
357
+ small_status_box = gr.Markdown("Ready to generate images",
358
+ elem_id="small-status")
359
+
360
+ # Big status box in its own row with styling
361
+ with gr.Row(elem_id="big-status-row"):
362
+ big_status_box = gr.Markdown("",
363
+ elem_id="big-status",
364
+ visible=False,
365
+ elem_classes="big-status-box")
366
+
367
+ with gr.Row():
368
+ with gr.Column():
369
+ draft_output = gr.Image(label="Flux-dev")
370
+ with gr.Column():
371
+ quick_output = gr.Image(label="HiDream-dev")
372
+ with gr.Column():
373
+ best_output = gr.Image(label="HiDream-full")
374
+
375
+ performance_plot = gr.Plot(label="Performance Metrics")
376
+
377
+ # Add custom CSS for the big status box
378
+ css = """
379
+ #big-status-row {
380
+ margin: 20px 0;
381
+ }
382
+ #big-status {
383
+ font-size: 28px; /* Even larger font size */
384
+ font-weight: bold;
385
+ padding: 30px; /* More padding */
386
+ background-color: #0D47A1; /* Deeper blue background */
387
+ color: white; /* White text */
388
+ border-radius: 10px;
389
+ text-align: center;
390
+ margin: 0 auto;
391
+ box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2); /* Stronger shadow */
392
+ animation: deep-breath 3s infinite; /* Slower, deeper breathing animation */
393
+ width: 100%; /* Full width */
394
+ max-width: 800px; /* Maximum width */
395
+ transition: all 0.3s ease; /* Smooth transitions */
396
+ border-left: 6px solid #64B5F6; /* Add a colored border */
397
+ border-right: 6px solid #64B5F6; /* Add a colored border */
398
+ }
399
+
400
+ /* Deeper breathing animation */
401
+ @keyframes deep-breath {
402
+ 0% {
403
+ opacity: 0.7;
404
+ transform: scale(0.98);
405
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
406
+ }
407
+ 50% {
408
+ opacity: 1;
409
+ transform: scale(1.01);
410
+ box-shadow: 0 8px 16px rgba(0, 0, 0, 0.3);
411
+ }
412
+ 100% {
413
+ opacity: 0.7;
414
+ transform: scale(0.98);
415
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
416
+ }
417
+ }
418
+ """
419
+ gr.HTML(f"<style>{css}</style>")
420
+
421
+ # Update the generation function to use session manager
422
+ async def generate_all_backends_with_status_boxes(prompt,
423
+ current_session_id):
424
+ """Generate images with big status box during generation"""
425
+ # Get or create a session manager
426
+ session_id, manager = SessionManager.get_manager(current_session_id)
427
+ manager.update_activity()
428
+
429
+ # IMPORTANT: Reset history when starting a new generation
430
+ if prompt and prompt.strip() != "":
431
+ manager.reset_history() # Clear previous performance metrics
432
+
433
+ if not prompt or prompt.strip() == "":
434
+ # Handle empty prompt case
435
+ yield (
436
+ "⚠️ Please enter a prompt first",
437
+ "⚠️ Please enter a prompt first",
438
+ gr.update(visible=True),
439
+ gr.update(visible=False),
440
+ None,
441
+ None,
442
+ None,
443
+ None,
444
+ session_id, # Return the session ID
445
+ )
446
+ return
447
+
448
+ # Status message
449
+ status_message = f"🔄 PROCESSING: '{prompt}'"
450
+
451
+ # Initial state - clear all images, show big status box
452
+ yield (
453
+ status_message,
454
+ status_message,
455
+ gr.update(visible=True),
456
+ gr.update(visible=False),
457
+ None,
458
+ None,
459
+ None,
460
+ None,
461
+ session_id, # Return the session ID
462
+ )
463
+
464
+ # For production mode:
465
+ completed_backends = set()
466
+ results = {"flux-dev": None, "hidream-dev": None, "hidream-full": None}
467
+
468
+ try:
469
+ # Submit all tasks
470
+ request_ids = {}
471
+ for backend in BACKENDS:
472
+ try:
473
+ request_id = await manager.submit_task(backend, prompt)
474
+ request_ids[backend] = request_id
475
+ except Exception as e:
476
+ # Handle submission error
477
+ print(f"Error submitting task for {backend}: {e}")
478
+ results[backend] = create_error_image(backend, str(e))
479
+ completed_backends.add(backend)
480
+
481
+ # Poll all backends until they complete
482
+ max_poll_attempts = 300
483
+ poll_attempt = 0
484
+
485
+ # Main polling loop
486
+ while len(completed_backends
487
+ ) < 3 and poll_attempt < max_poll_attempts:
488
+ poll_attempt += 1
489
+
490
+ # Poll each pending backend
491
+ for backend in list(BACKENDS.keys()):
492
+ if backend in completed_backends:
493
+ continue
494
+
495
+ try:
496
+ # Only do actual API calls every few attempts to reduce load
497
+ if poll_attempt % 2 == 0 or backend == "flux-dev":
498
+ # Use the session manager instead of global manager
499
+ result = await poll_once(manager, backend,
500
+ request_ids[backend])
501
+ if result: # Backend completed
502
+ results[backend] = result
503
+ completed_backends.add(backend)
504
+
505
+ # Yield updated state when an image completes
506
+ yield (
507
+ status_message,
508
+ status_message,
509
+ gr.update(visible=True),
510
+ gr.update(visible=False),
511
+ results["flux-dev"],
512
+ results["hidream-dev"],
513
+ results["hidream-full"],
514
+ (manager.get_performance_plot()
515
+ if any(completed_backends) else None),
516
+ session_id,
517
+ )
518
+ except Exception as e:
519
+ print(f"Error polling {backend}: {str(e)}")
520
+
521
+ # Wait between poll attempts
522
+ await asyncio.sleep(0.1)
523
+
524
+ # Final status
525
+ final_status = ("✅ All generations completed!"
526
+ if len(completed_backends) == 3 else
527
+ "⚠️ Some generations timed out")
528
+
529
+ # Final yield
530
+ yield (
531
+ final_status,
532
+ final_status,
533
+ gr.update(visible=False),
534
+ gr.update(visible=True),
535
+ results["flux-dev"],
536
+ results["hidream-dev"],
537
+ results["hidream-full"],
538
+ manager.get_performance_plot(),
539
+ session_id,
540
+ )
541
+
542
+ except Exception as e:
543
+ # Error handling
544
+ error_message = f"❌ Error: {str(e)}"
545
+ yield (
546
+ error_message,
547
+ error_message,
548
+ gr.update(visible=False),
549
+ gr.update(visible=True),
550
+ None,
551
+ None,
552
+ None,
553
+ None,
554
+ session_id,
555
+ )
556
+
557
+ # Schedule periodic cleanup of old sessions
558
+ def cleanup_task():
559
+ SessionManager.cleanup_old_sessions()
560
+ # Schedule the next cleanup
561
+ threading.Timer(3600, cleanup_task).start() # Run every hour
562
+
563
+ # Start the cleanup task
564
+ cleanup_task()
565
+
566
+ # Update the click handler to include session_id
567
+ generate_btn.click(
568
+ fn=generate_all_backends_with_status_boxes,
569
+ inputs=[input_text, session_id],
570
+ outputs=[
571
+ small_status_box,
572
+ big_status_box,
573
+ big_status_box, # visibility
574
+ small_status_box, # visibility
575
+ draft_output,
576
+ quick_output,
577
+ best_output,
578
+ performance_plot,
579
+ session_id, # Update the session ID
580
+ ],
581
+ api_name="generate",
582
+ max_batch_size=10, # Process up to 10 requests at once
583
+ concurrency_limit=20, # Allow up to 20 concurrent requests
584
+ concurrency_id="generation", # Group concurrent requests under this ID
585
+ )
586
+
587
+ # Launch with increased max_threads
588
+ if __name__ == "__main__":
589
+ demo.queue(max_size=50).launch(
590
+ server_name="0.0.0.0",
591
+ max_threads=16, # Increase thread count for better concurrency
592
+ )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ aiohttp
3
+ plotly
4
+ python-dotenv