root commited on
Commit
a743121
·
1 Parent(s): 45fe9c1
Files changed (1) hide show
  1. wan_pipeline.py +3 -1
wan_pipeline.py CHANGED
@@ -71,7 +71,7 @@ EXAMPLE_DOC_STRING = """
71
  >>> export_to_video(output, "output.mp4", fps=16)
72
  ```
73
  """
74
-
75
  def optimized_scale(positive_flat, negative_flat):
76
 
77
  # Calculate dot production
@@ -561,6 +561,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
561
  alpha = optimized_scale(positive_flat,negative_flat)
562
  alpha = alpha.view(batch_size, 1, 1, 1)
563
 
 
564
  if (i <= zero_steps) and use_zero_init:
565
  noise_pred = noise_pred_text*0.
566
  else:
@@ -614,3 +615,4 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
614
  return (video,)
615
 
616
  return WanPipelineOutput(frames=video)
 
 
71
  >>> export_to_video(output, "output.mp4", fps=16)
72
  ```
73
  """
74
+ @torch.cuda.amp.autocast(dtype=torch.float32)
75
  def optimized_scale(positive_flat, negative_flat):
76
 
77
  # Calculate dot production
 
561
  alpha = optimized_scale(positive_flat,negative_flat)
562
  alpha = alpha.view(batch_size, 1, 1, 1)
563
 
564
+
565
  if (i <= zero_steps) and use_zero_init:
566
  noise_pred = noise_pred_text*0.
567
  else:
 
615
  return (video,)
616
 
617
  return WanPipelineOutput(frames=video)
618
+