YiftachEde commited on
Commit
1908f03
·
verified ·
1 Parent(s): b149af8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import os
 
2
  # this is a HF Spaces specific hack for ZeroGPU
3
  import spaces
4
 
5
  import sys
6
  import torch
 
7
 
8
  import torch
9
  import torch.nn as nn
@@ -319,12 +321,14 @@ class ShapERenderer:
319
  guidance_scale = float(guidance_scale)
320
 
321
  with torch.amp.autocast('cuda'): # Use automatic mixed precision
322
- latents = sample_latents(
323
- batch_size=batch_size,
324
- model=self.model,
325
- diffusion=self.diffusion,
326
- guidance_scale=guidance_scale,
327
- model_kwargs=dict(texts=[prompt] * batch_size),
 
 
328
  progress=True,
329
  clip_denoised=True,
330
  use_fp16=True,
@@ -534,7 +538,6 @@ def create_demo():
534
  )
535
 
536
  # Set up event handlers
537
- @spaces.GPU(duration=20) # Reduced duration to 20 seconds
538
  def generate(prompt, guidance_scale, num_steps):
539
  try:
540
  torch.cuda.empty_cache() # Clear GPU memory before starting
 
1
  import os
2
+ from typing import Union
3
  # this is a HF Spaces specific hack for ZeroGPU
4
  import spaces
5
 
6
  import sys
7
  import torch
8
+ from shap_e.models.transmitter.base import Transmitter, VectorDecoder
9
 
10
  import torch
11
  import torch.nn as nn
 
321
  guidance_scale = float(guidance_scale)
322
 
323
  with torch.amp.autocast('cuda'): # Use automatic mixed precision
324
+ # spaces duration is 20 seconds, so we need to be careful here
325
+ with spaces.GPU(duration=20):
326
+ latents = sample_latents(
327
+ batch_size=batch_size,
328
+ model=self.model,
329
+ diffusion=self.diffusion,
330
+ guidance_scale=guidance_scale,
331
+ model_kwargs=dict(texts=[prompt] * batch_size),
332
  progress=True,
333
  clip_denoised=True,
334
  use_fp16=True,
 
538
  )
539
 
540
  # Set up event handlers
 
541
  def generate(prompt, guidance_scale, num_steps):
542
  try:
543
  torch.cuda.empty_cache() # Clear GPU memory before starting