xinjie.wang commited on
Commit
4e2d5ef
·
1 Parent(s): 3150e46
Files changed (1) hide show
  1. asset3d_gen/models/sr_model.py +36 -33
asset3d_gen/models/sr_model.py CHANGED
@@ -59,55 +59,58 @@ class ImageStableSR:
59
 
60
  class ImageRealESRGAN:
61
  def __init__(self, outscale: int, model_path: str = None) -> None:
62
- # monkey_patch
63
  import torchvision
64
  from packaging import version
65
 
66
  if version.parse(torchvision.__version__) > version.parse("0.16"):
67
  import sys
68
  import types
69
-
70
  import torchvision.transforms.functional as TF
71
 
72
- functional_tensor = types.ModuleType(
73
- "torchvision.transforms.functional_tensor"
74
- )
75
  functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
76
- sys.modules["torchvision.transforms.functional_tensor"] = (
77
- functional_tensor
78
- )
79
-
80
- from basicsr.archs.rrdbnet_arch import RRDBNet
81
- from realesrgan import RealESRGANer
82
 
83
  self.outscale = outscale
84
- model = RRDBNet(
85
- num_in_ch=3,
86
- num_out_ch=3,
87
- num_feat=64,
88
- num_block=23,
89
- num_grow_ch=32,
90
- scale=4,
91
- )
92
- if model_path is None:
93
- suffix = "super_resolution"
94
- model_path = snapshot_download(
95
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
96
- )
97
- model_path = os.path.join(
98
- model_path, suffix, "RealESRGAN_x4plus.pth"
 
99
  )
100
 
101
- self.upsampler = RealESRGANer(
102
- scale=4,
103
- model_path=model_path,
104
- model=model,
105
- pre_pad=0,
106
- half=True,
107
- )
 
 
 
 
 
 
 
 
108
 
109
  @spaces.GPU
110
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
 
 
111
  if isinstance(image, Image.Image):
112
  image = np.array(image)
113
 
 
59
 
60
  class ImageRealESRGAN:
61
  def __init__(self, outscale: int, model_path: str = None) -> None:
62
+ # monkey patch to support torchvision>=0.16
63
  import torchvision
64
  from packaging import version
65
 
66
  if version.parse(torchvision.__version__) > version.parse("0.16"):
67
  import sys
68
  import types
 
69
  import torchvision.transforms.functional as TF
70
 
71
+ functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
 
 
72
  functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
73
+ sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
 
 
 
 
 
74
 
75
  self.outscale = outscale
76
+ self.model_path = model_path
77
+ self.upsampler = None
78
+
79
+ def _lazy_init(self):
80
+ if self.upsampler is None:
81
+ from basicsr.archs.rrdbnet_arch import RRDBNet
82
+ from realesrgan import RealESRGANer
83
+ from huggingface_hub import snapshot_download
84
+
85
+ model = RRDBNet(
86
+ num_in_ch=3,
87
+ num_out_ch=3,
88
+ num_feat=64,
89
+ num_block=23,
90
+ num_grow_ch=32,
91
+ scale=4,
92
  )
93
 
94
+ model_path = self.model_path
95
+ if model_path is None:
96
+ suffix = "super_resolution"
97
+ model_path = snapshot_download(
98
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
99
+ )
100
+ model_path = os.path.join(model_path, suffix, "RealESRGAN_x4plus.pth")
101
+
102
+ self.upsampler = RealESRGANer(
103
+ scale=4,
104
+ model_path=model_path,
105
+ model=model,
106
+ pre_pad=0,
107
+ half=True,
108
+ )
109
 
110
  @spaces.GPU
111
  def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
112
+ self._lazy_init()
113
+
114
  if isinstance(image, Image.Image):
115
  image = np.array(image)
116