kaupane commited on
Commit
cc3f951
·
verified ·
1 Parent(s): 23f7c2b

Update app.py

Browse files

Fix dtype inconsistency bug to enable half-precision inference

Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -19,7 +19,7 @@ latent_scale_factor = 0.18215 # Same as in DiTTrainer
19
  global_progress = 0
20
 
21
  # Enable half precision inference
22
- USE_HALF_PRECISION = False
23
 
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
@@ -99,7 +99,9 @@ class DiffusionSampler:
99
  if torch.cuda.is_available():
100
  torch.cuda.manual_seed(seed)
101
  torch.cuda.manual_seed_all(seed)
102
-
 
 
103
  model.to(self.device)
104
  model.eval()
105
 
@@ -185,9 +187,9 @@ def generate_random_seed():
185
  return random.randint(0, 2**32 - 1)
186
 
187
  MODEL_SAMPLE_LIMITS = {
188
- "S": {"min":1, "max": 18, "default": 4},
189
- "B": {"min":1, "max": 9, "default": 4},
190
- "L": {"min":1, "max": 3, "default": 1}
191
  }
192
 
193
  def update_sample_slider(dit_size):
 
19
  global_progress = 0
20
 
21
  # Enable half precision inference
22
+ USE_HALF_PRECISION = True
23
 
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
 
99
  if torch.cuda.is_available():
100
  torch.cuda.manual_seed(seed)
101
  torch.cuda.manual_seed_all(seed)
102
+
103
+ if self.use_half:
104
+ model.half()
105
  model.to(self.device)
106
  model.eval()
107
 
 
187
  return random.randint(0, 2**32 - 1)
188
 
189
  MODEL_SAMPLE_LIMITS = {
190
+ "S": {"min":1, "max": 16, "default": 4},
191
+ "B": {"min":1, "max": 12, "default": 3},
192
+ "L": {"min":1, "max": 4, "default": 1}
193
  }
194
 
195
  def update_sample_slider(dit_size):