Pringled committed on
Commit
75ff340
·
1 Parent(s): 1d331c4
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -20,24 +20,33 @@ default_threshold = 0.9
20
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
21
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
22
 
 
 
23
  from tqdm import tqdm as original_tqdm
 
24
  # Patch tqdm to use Gradio's progress bar
25
  def patch_tqdm_for_gradio(progress):
26
  class GradioTqdm(original_tqdm):
27
  def __init__(self, *args, **kwargs):
28
  super().__init__(*args, **kwargs)
29
  self.progress = progress
 
30
  self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
 
31
 
32
  def update(self, n=1):
33
  super().update(n)
34
- self.progress(self.n / self.total_batches)
 
 
35
 
36
  return GradioTqdm
 
 
 
37
  # Function to patch the original encode function with our Gradio tqdm
38
  def original_encode_with_tqdm(original_encode_func, patched_tqdm):
39
  def new_encode(*args, **kwargs):
40
- # Replace tqdm with our patched version
41
  original_tqdm_backup = original_tqdm
42
  try:
43
  # Patch the `tqdm` within encode
@@ -49,6 +58,7 @@ def original_encode_with_tqdm(original_encode_func, patched_tqdm):
49
 
50
  return new_encode
51
 
 
52
  def batch_iterable(iterable, batch_size):
53
  """Helper function to create batches from an iterable."""
54
  for i in range(0, len(iterable), batch_size):
 
20
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
21
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
22
 
23
+
24
+ # Keep a reference to the original tqdm class; GradioTqdm subclasses it
25
  from tqdm import tqdm as original_tqdm
26
+
27
  # Patch tqdm to use Gradio's progress bar
28
  def patch_tqdm_for_gradio(progress):
29
  class GradioTqdm(original_tqdm):
30
  def __init__(self, *args, **kwargs):
31
  super().__init__(*args, **kwargs)
32
  self.progress = progress
33
+ # Set smaller step sizes or update more frequently based on total items
34
  self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
35
+ self.update_interval = max(1, self.total_batches // 100) # Update every 1% of progress
36
 
37
  def update(self, n=1):
38
  super().update(n)
39
+ # Only update Gradio's progress every `update_interval` steps
40
+ if self.n % self.update_interval == 0 or self.n == self.total_batches:
41
+ self.progress(self.n / self.total_batches)
42
 
43
  return GradioTqdm
44
+
45
+
46
+
47
  # Function to patch the original encode function with our Gradio tqdm
48
  def original_encode_with_tqdm(original_encode_func, patched_tqdm):
49
  def new_encode(*args, **kwargs):
 
50
  original_tqdm_backup = original_tqdm
51
  try:
52
  # Patch the `tqdm` within encode
 
58
 
59
  return new_encode
60
 
61
+
62
  def batch_iterable(iterable, batch_size):
63
  """Helper function to create batches from an iterable."""
64
  for i in range(0, len(iterable), batch_size):