chengzeyi committed on
Commit
6380d03
·
1 Parent(s): 67635cf

add more limit

Browse files
Files changed (1) hide show
  1. app.py +143 -69
app.py CHANGED
@@ -64,9 +64,10 @@ class BackendStatus:
64
  self.status = "completed"
65
  self.progress = 100
66
  self.end_time = time.time()
67
- self.history.append(
68
- {"timestamp": datetime.now(), "duration": self.end_time - self.start_time}
69
- )
 
70
 
71
  def fail(self):
72
  self.status = "failed"
@@ -93,10 +94,8 @@ class SessionManager:
93
  with cls._lock:
94
  to_remove = []
95
  for session_id, manager in cls._instances.items():
96
- if (
97
- hasattr(manager, "last_activity")
98
- and current_time - manager.last_activity > max_age
99
- ):
100
  to_remove.append(session_id)
101
 
102
  for session_id in to_remove:
@@ -106,7 +105,10 @@ class SessionManager:
106
  class GenerationManager:
107
 
108
  def __init__(self):
109
- self.backend_statuses = {backend: BackendStatus() for backend in BACKENDS}
 
 
 
110
  self.last_activity = time.time()
111
  self.request_timestamps = [] # Track timestamps of requests
112
 
@@ -116,7 +118,8 @@ class GenerationManager:
116
  def add_request_timestamp(self):
117
  self.request_timestamps.append(time.time())
118
 
119
- def has_exceeded_limit(self, limit=10): # Default limit: 10 requests per hour
 
120
  current_time = time.time()
121
  # Filter timestamps to only include those within the last hour
122
  self.request_timestamps = [
@@ -144,17 +147,14 @@ class GenerationManager:
144
  text=[f"{avg_duration:.2f}s"], # Show time in seconds
145
  textposition="auto",
146
  width=[0.5], # Make bars narrower
147
- )
148
- )
149
 
150
  # Set a minimum y-axis range if we have data
151
  if has_data:
152
- max_duration = max(
153
- [
154
- max([h["duration"] for h in status.history] or [0])
155
- for status in self.backend_statuses.values()
156
- ]
157
- )
158
  # Add 20% padding to the top
159
  y_max = max_duration * 1.2
160
  # Ensure the y-axis always starts at 0
@@ -209,15 +209,19 @@ class GenerationManager:
209
 
210
  # Use aiohttp instead of requests for async
211
  async with aiohttp.ClientSession() as session:
212
- async with session.post(url, headers=headers, json=payload) as response:
 
213
  if response.status == 200:
214
  result = await response.json()
215
  request_id = result["data"]["id"]
216
- print(f"Task submitted successfully. Request ID: {request_id}")
 
 
217
  return request_id
218
  else:
219
  text = await response.text()
220
- raise Exception(f"API error: {response.status}, {text}")
 
221
 
222
  except Exception as e:
223
  status.fail()
@@ -231,6 +235,59 @@ class GenerationManager:
231
  return self
232
 
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  # Helper function to create error images as data URIs
235
  def create_error_image(backend, error_message):
236
  try:
@@ -305,9 +362,9 @@ async def poll_once(manager, backend, request_id):
305
  # It's base64 data - format it as a data URI if needed
306
  try:
307
  # Format as data URI for Gradio to display directly
308
- if isinstance(output, str) and not output.startswith(
309
- "data:image"
310
- ):
311
  # Convert raw base64 to data URI format
312
  return f"data:image/jpeg;base64,{output}"
313
  else:
@@ -315,7 +372,8 @@ async def poll_once(manager, backend, request_id):
315
  return output
316
  except Exception as e:
317
  print(f"Error processing base64 image: {e}")
318
- raise Exception(f"Failed to process base64 image: {str(e)}")
 
319
 
320
  elif current_status == "failed":
321
  manager.backend_statuses[backend].fail()
@@ -347,17 +405,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
347
  gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
348
 
349
  # Add the introduction with link to WaveSpeedAI
350
- gr.Markdown(
351
- """
352
  [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
353
  Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
354
- """
355
- )
356
- gr.Markdown(
357
- """
358
  This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
359
- """
360
- )
361
 
362
  with gr.Row():
363
  with gr.Column(scale=3):
@@ -375,18 +429,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
375
  with gr.Column(scale=1):
376
  generate_btn = gr.Button("Generate", variant="primary")
377
 
378
- example_dropdown.change(
379
- lambda ex: ex, inputs=[example_dropdown], outputs=[input_text]
380
- )
381
 
382
  # Two status boxes - small (default) and big (during generation)
383
- small_status_box = gr.Markdown("Ready to generate images", elem_id="small-status")
 
384
 
385
  # Big status box in its own row with styling
386
  with gr.Row(elem_id="big-status-row"):
387
- big_status_box = gr.Markdown(
388
- "", elem_id="big-status", visible=False, elem_classes="big-status-box"
389
- )
 
390
 
391
  with gr.Row():
392
  with gr.Column():
@@ -399,27 +455,27 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
399
  performance_plot = gr.Plot(label="Performance Metrics")
400
 
401
  with gr.Accordion("Recent Generations (last 16)", open=False):
402
- recent_gallery = gr.Gallery(
403
- label="Prompt and Output", columns=3, interactive=False
404
- )
405
 
406
  def get_recent_gallery_items():
407
  gallery_items = []
408
  for r in reversed(recent_generations):
409
  gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
410
- gallery_items.append((r["hidream-dev"], f"HiDream-dev: {r['prompt']}"))
411
- gallery_items.append((r["hidream-full"], f"HiDream-full: {r['prompt']}"))
 
 
412
  return gallery_items
413
 
414
  def update_recent_gallery(prompt, results):
415
- recent_generations.append(
416
- {
417
- "prompt": prompt,
418
- "flux-dev": results["flux-dev"],
419
- "hidream-dev": results["hidream-dev"],
420
- "hidream-full": results["hidream-full"],
421
- }
422
- )
423
  if len(recent_generations) > 16:
424
  recent_generations.pop(0)
425
  gallery_items = get_recent_gallery_items()
@@ -470,7 +526,30 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
470
  gr.HTML(f"<style>{css}</style>")
471
 
472
  # Update the generation function to use session manager
473
- async def generate_all_backends_with_status_boxes(prompt, current_session_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  """Generate images with big status box during generation"""
475
  # Get or create a session manager
476
  session_id, manager = SessionManager.get_manager(current_session_id)
@@ -478,8 +557,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
478
 
479
  # Check if the user has exceeded the request limit
480
  if manager.has_exceeded_limit(
481
- limit=10
482
- ): # Set the limit to 10 requests per hour
483
  error_message = "❌ You have exceeded the limit of 10 requests per hour. Please try again later."
484
  yield (
485
  error_message,
@@ -557,7 +635,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
557
  poll_attempt = 0
558
 
559
  # Main polling loop
560
- while len(completed_backends) < 3 and poll_attempt < max_poll_attempts:
 
561
  poll_attempt += 1
562
 
563
  # Poll each pending backend
@@ -569,9 +648,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
569
  # Only do actual API calls every few attempts to reduce load
570
  if poll_attempt % 2 == 0 or backend == "flux-dev":
571
  # Use the session manager instead of global manager
572
- result = await poll_once(
573
- manager, backend, request_ids[backend]
574
- )
575
  if result: # Backend completed
576
  results[backend] = result
577
  completed_backends.add(backend)
@@ -585,11 +663,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
585
  results["flux-dev"],
586
  results["hidream-dev"],
587
  results["hidream-full"],
588
- (
589
- manager.get_performance_plot()
590
- if any(completed_backends)
591
- else None
592
- ),
593
  session_id,
594
  None,
595
  )
@@ -600,11 +675,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
600
  await asyncio.sleep(0.1)
601
 
602
  # Final status
603
- final_status = (
604
- "✅ All generations completed!"
605
- if len(completed_backends) == 3
606
- else "⚠️ Some generations timed out"
607
- )
608
 
609
  gallery_update = update_recent_gallery(prompt, results)
610
 
@@ -641,6 +714,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
641
  # Schedule periodic cleanup of old sessions
642
  def cleanup_task():
643
  SessionManager.cleanup_old_sessions()
 
644
  # Schedule the next cleanup
645
  threading.Timer(3600, cleanup_task).start() # Run every hour
646
 
 
64
  self.status = "completed"
65
  self.progress = 100
66
  self.end_time = time.time()
67
+ self.history.append({
68
+ "timestamp": datetime.now(),
69
+ "duration": self.end_time - self.start_time
70
+ })
71
 
72
  def fail(self):
73
  self.status = "failed"
 
94
  with cls._lock:
95
  to_remove = []
96
  for session_id, manager in cls._instances.items():
97
+ if (hasattr(manager, "last_activity")
98
+ and current_time - manager.last_activity > max_age):
 
 
99
  to_remove.append(session_id)
100
 
101
  for session_id in to_remove:
 
105
  class GenerationManager:
106
 
107
  def __init__(self):
108
+ self.backend_statuses = {
109
+ backend: BackendStatus()
110
+ for backend in BACKENDS
111
+ }
112
  self.last_activity = time.time()
113
  self.request_timestamps = [] # Track timestamps of requests
114
 
 
118
  def add_request_timestamp(self):
119
  self.request_timestamps.append(time.time())
120
 
121
+ def has_exceeded_limit(self,
122
+ limit=10): # Default limit: 10 requests per hour
123
  current_time = time.time()
124
  # Filter timestamps to only include those within the last hour
125
  self.request_timestamps = [
 
147
  text=[f"{avg_duration:.2f}s"], # Show time in seconds
148
  textposition="auto",
149
  width=[0.5], # Make bars narrower
150
+ ))
 
151
 
152
  # Set a minimum y-axis range if we have data
153
  if has_data:
154
+ max_duration = max([
155
+ max([h["duration"] for h in status.history] or [0])
156
+ for status in self.backend_statuses.values()
157
+ ])
 
 
158
  # Add 20% padding to the top
159
  y_max = max_duration * 1.2
160
  # Ensure the y-axis always starts at 0
 
209
 
210
  # Use aiohttp instead of requests for async
211
  async with aiohttp.ClientSession() as session:
212
+ async with session.post(url, headers=headers,
213
+ json=payload) as response:
214
  if response.status == 200:
215
  result = await response.json()
216
  request_id = result["data"]["id"]
217
+ print(
218
+ f"Task submitted successfully. Request ID: {request_id}"
219
+ )
220
  return request_id
221
  else:
222
  text = await response.text()
223
+ raise Exception(
224
+ f"API error: {response.status}, {text}")
225
 
226
  except Exception as e:
227
  status.fail()
 
235
  return self
236
 
237
 
238
class ClientManager:
    """Process-wide registry mapping a client id to its per-client manager.

    All reads and writes of the registry go through ``_lock``.
    """

    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_manager(cls, client_id=None):
        """Return the manager registered for ``client_id``, creating it on first use.

        A falsy ``client_id`` is replaced by a freshly generated UUID string
        before the lookup.
        """
        if not client_id:
            client_id = str(uuid.uuid4())

        with cls._lock:
            if client_id not in cls._instances:
                cls._instances[client_id] = ClientGenerationManager()
            return cls._instances[client_id]

    @classmethod
    def cleanup_old_clients(cls, max_age=3600):  # 1 hour default
        """Drop every manager whose ``last_activity`` is older than ``max_age`` seconds.

        Managers that have no ``last_activity`` attribute are left in place.
        """
        current_time = time.time()
        with cls._lock:
            # Collect the ids first: deleting while iterating the dict would
            # raise RuntimeError.
            stale_ids = [
                client_id
                for client_id, manager in cls._instances.items()
                if hasattr(manager, "last_activity")
                and current_time - manager.last_activity > max_age
            ]
            for client_id in stale_ids:
                del cls._instances[client_id]
264
+
265
+
266
class ClientGenerationManager:
    """Per-client request tracker used for rate limiting.

    Holds the time of the client's most recent activity and the timestamps
    of its requests. Every access to that state is made while holding
    ``self.lock``.
    """

    def __init__(self):
        self.lock = threading.Lock()
        # Initialise the tracked state up front: the methods below read these
        # attributes, so leaving them unset raises AttributeError on the
        # first has_exceeded_limit() / add_request_timestamp() call.
        self.last_activity = time.time()
        self.request_timestamps = []  # time.time() values, one per request

    def update_activity(self):
        """Record the current time as this client's most recent activity."""
        with self.lock:
            self.last_activity = time.time()

    def add_request_timestamp(self):
        """Append the current time to the request history."""
        with self.lock:
            self.request_timestamps.append(time.time())

    def has_exceeded_limit(self,
                           limit=100):  # Default limit: 100 requests per hour
        """Return True when at least ``limit`` requests fall within the last hour.

        Timestamps older than 3600 seconds are discarded as a side effect.
        """
        with self.lock:
            current_time = time.time()
            # Filter timestamps to only include those within the last hour
            self.request_timestamps = [
                ts for ts in self.request_timestamps
                if current_time - ts <= 3600
            ]
            return len(self.request_timestamps) >= limit
289
+
290
+
291
  # Helper function to create error images as data URIs
292
  def create_error_image(backend, error_message):
293
  try:
 
362
  # It's base64 data - format it as a data URI if needed
363
  try:
364
  # Format as data URI for Gradio to display directly
365
+ if isinstance(
366
+ output, str
367
+ ) and not output.startswith("data:image"):
368
  # Convert raw base64 to data URI format
369
  return f"data:image/jpeg;base64,{output}"
370
  else:
 
372
  return output
373
  except Exception as e:
374
  print(f"Error processing base64 image: {e}")
375
+ raise Exception(
376
+ f"Failed to process base64 image: {str(e)}")
377
 
378
  elif current_status == "failed":
379
  manager.backend_statuses[backend].fail()
 
405
  gr.Markdown("# 🌊 WaveSpeedAI HiDream Arena")
406
 
407
  # Add the introduction with link to WaveSpeedAI
408
+ gr.Markdown("""
 
409
  [WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation.
410
  Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries.
411
+ """)
412
+ gr.Markdown("""
 
 
413
  This demo showcases the performance and outputs of leading image generation models, including HiDream and Flux, on our accelerated inference platform.
414
+ """)
 
415
 
416
  with gr.Row():
417
  with gr.Column(scale=3):
 
429
  with gr.Column(scale=1):
430
  generate_btn = gr.Button("Generate", variant="primary")
431
 
432
+ example_dropdown.change(lambda ex: ex,
433
+ inputs=[example_dropdown],
434
+ outputs=[input_text])
435
 
436
  # Two status boxes - small (default) and big (during generation)
437
+ small_status_box = gr.Markdown("Ready to generate images",
438
+ elem_id="small-status")
439
 
440
  # Big status box in its own row with styling
441
  with gr.Row(elem_id="big-status-row"):
442
+ big_status_box = gr.Markdown("",
443
+ elem_id="big-status",
444
+ visible=False,
445
+ elem_classes="big-status-box")
446
 
447
  with gr.Row():
448
  with gr.Column():
 
455
  performance_plot = gr.Plot(label="Performance Metrics")
456
 
457
  with gr.Accordion("Recent Generations (last 16)", open=False):
458
+ recent_gallery = gr.Gallery(label="Prompt and Output",
459
+ columns=3,
460
+ interactive=False)
461
 
462
  def get_recent_gallery_items():
463
  gallery_items = []
464
  for r in reversed(recent_generations):
465
  gallery_items.append((r["flux-dev"], f"FLUX-dev: {r['prompt']}"))
466
+ gallery_items.append(
467
+ (r["hidream-dev"], f"HiDream-dev: {r['prompt']}"))
468
+ gallery_items.append(
469
+ (r["hidream-full"], f"HiDream-full: {r['prompt']}"))
470
  return gallery_items
471
 
472
  def update_recent_gallery(prompt, results):
473
+ recent_generations.append({
474
+ "prompt": prompt,
475
+ "flux-dev": results["flux-dev"],
476
+ "hidream-dev": results["hidream-dev"],
477
+ "hidream-full": results["hidream-full"],
478
+ })
 
 
479
  if len(recent_generations) > 16:
480
  recent_generations.pop(0)
481
  gallery_items = get_recent_gallery_items()
 
526
  gr.HTML(f"<style>{css}</style>")
527
 
528
  # Update the generation function to use session manager
529
+ async def generate_all_backends_with_status_boxes(prompt,
530
+ current_session_id,
531
+ request: gr.Request):
532
+ client_ip = request.client.host
533
+ print(f"Client IP: {client_ip}")
534
+ client_generation_manager = ClientManager.get_manager(client_ip)
535
+ client_generation_manager.update_activity()
536
+ if client_generation_manager.has_exceeded_limit(limit=100):
537
+ error_message = "❌ Your network has exceeded the limit of 100 requests per hour. Please try again later."
538
+ yield (
539
+ error_message,
540
+ error_message,
541
+ gr.update(visible=False),
542
+ gr.update(visible=True),
543
+ None,
544
+ None,
545
+ None,
546
+ None,
547
+ current_session_id, # Return the session ID
548
+ None,
549
+ )
550
+ return
551
+
552
+ client_generation_manager.add_request_timestamp()
553
  """Generate images with big status box during generation"""
554
  # Get or create a session manager
555
  session_id, manager = SessionManager.get_manager(current_session_id)
 
557
 
558
  # Check if the user has exceeded the request limit
559
  if manager.has_exceeded_limit(
560
+ limit=10): # Set the limit to 10 requests per hour
 
561
  error_message = "❌ You have exceeded the limit of 10 requests per hour. Please try again later."
562
  yield (
563
  error_message,
 
635
  poll_attempt = 0
636
 
637
  # Main polling loop
638
+ while len(completed_backends
639
+ ) < 3 and poll_attempt < max_poll_attempts:
640
  poll_attempt += 1
641
 
642
  # Poll each pending backend
 
648
  # Only do actual API calls every few attempts to reduce load
649
  if poll_attempt % 2 == 0 or backend == "flux-dev":
650
  # Use the session manager instead of global manager
651
+ result = await poll_once(manager, backend,
652
+ request_ids[backend])
 
653
  if result: # Backend completed
654
  results[backend] = result
655
  completed_backends.add(backend)
 
663
  results["flux-dev"],
664
  results["hidream-dev"],
665
  results["hidream-full"],
666
+ (manager.get_performance_plot()
667
+ if any(completed_backends) else None),
 
 
 
668
  session_id,
669
  None,
670
  )
 
675
  await asyncio.sleep(0.1)
676
 
677
  # Final status
678
+ final_status = ("✅ All generations completed!"
679
+ if len(completed_backends) == 3 else
680
+ "⚠️ Some generations timed out")
 
 
681
 
682
  gallery_update = update_recent_gallery(prompt, results)
683
 
 
714
  # Schedule periodic cleanup of old sessions
715
  def cleanup_task():
716
  SessionManager.cleanup_old_sessions()
717
+ ClientManager.cleanup_old_clients()
718
  # Schedule the next cleanup
719
  threading.Timer(3600, cleanup_task).start() # Run every hour
720