lzyhha commited on
Commit
0d1683b
·
1 Parent(s): 40fb840
Files changed (1) hide show
  1. visualcloze.py +15 -14
visualcloze.py CHANGED
@@ -407,20 +407,21 @@ class VisualClozeModel:
407
  )[-1]
408
 
409
  # Get query row
410
- samples = samples[:1]
411
- row_samples = []
412
- start = 0
413
- for size in sliced_subimage:
414
- end = start + (size[0] * size[1] // 256)
415
- latent_h = size[0] // 8
416
- latent_w = size[1] // 8
417
- row_sample = samples[:, start:end, :]
418
- row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h//2, w=latent_w//2)
419
- row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
420
- row_sample = (row_sample + 1.0) / 2.0
421
- row_sample.clamp_(0.0, 1.0)
422
- row_samples.append(row_sample[0])
423
- start = end
 
424
 
425
  # Convert all samples to PIL images
426
  output_images = []
 
407
  )[-1]
408
 
409
  # Get query row
410
+ with torch.no_grad():
411
+ samples = samples[:1]
412
+ row_samples = []
413
+ start = 0
414
+ for size in sliced_subimage:
415
+ end = start + (size[0] * size[1] // 256)
416
+ latent_h = size[0] // 8
417
+ latent_w = size[1] // 8
418
+ row_sample = samples[:, start:end, :]
419
+ row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h//2, w=latent_w//2)
420
+ row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
421
+ row_sample = (row_sample + 1.0) / 2.0
422
+ row_sample.clamp_(0.0, 1.0)
423
+ row_samples.append(row_sample[0])
424
+ start = end
425
 
426
  # Convert all samples to PIL images
427
  output_images = []