Spaces:
Paused
Paused
File size: 5,147 Bytes
4f44506 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from PIL import Image
import time
import io
from threading import Thread
import torch.nn.functional as F
import torch
import latent_preview
import server
serv = server.PromptServer.instance
from .utils import hook
# Default preview rates (frames per second), keyed by the class name of the
# latent format. get_latent_video_previewer looks the running format up here
# and falls back to 8 when the name is missing.
# NOTE(review): each rate is written as numerator // divisor; presumably the
# model's native frame rate over its temporal compression factor - confirm.
rates_table = {
    format_name: numerator // divisor
    for format_name, numerator, divisor in (
        ('Mochi', 24, 6),
        ('LTXV', 24, 8),
        ('HunyuanVideo', 24, 4),
        ('Cosmos1CV8x8x8', 24, 8),
        ('Wan21', 16, 4),
    )
}
class WrappedPreviewer(latent_preview.LatentPreviewer):
    """Latent previewer that streams multi-frame (animated) previews.

    Reuses the decoding state of an existing previewer (either its ``taesd``
    decoder or its ``latent_rgb_factors`` projection) and pushes a
    rate-limited, rolling window of decoded frames to the client through the
    prompt server instead of returning a single preview image.
    """

    def __init__(self, previewer, rate=8):
        """Copy the decoding state from *previewer*.

        Args:
            previewer: Object exposing either a ``taesd`` attribute or both
                ``latent_rgb_factors`` and ``latent_rgb_factors_bias``.
            rate: Number of preview frames to emit per second of wall time.

        Raises:
            Exception: If *previewer* exposes neither supported attribute.
        """
        self.first_preview = True
        # Wall-clock time up to which frames have already been accounted for.
        self.last_time = 0
        # Index of the next frame (within the flattened batch) to send.
        self.c_index = 0
        self.rate = rate
        if hasattr(previewer, 'taesd'):
            self.taesd = previewer.taesd
        elif hasattr(previewer, 'latent_rgb_factors'):
            self.latent_rgb_factors = previewer.latent_rgb_factors
            self.latent_rgb_factors_bias = previewer.latent_rgb_factors_bias
        else:
            raise Exception('Unsupported preview type for VHS animated previews')

    def decode_latent_to_preview_image(self, preview_format, x0):
        """Send the frames that have become due since the previous call.

        Args:
            preview_format: Unused by this implementation.
            x0: Latent tensor. A 5-D input has dims 1 and 2 swapped and is
                then flattened so that dim 0 enumerates individual frames.

        Returns:
            Always ``None``; frames are delivered through ``serv.send_sync``
            rather than as a return value.
        """
        if x0.ndim == 5:
            # Keep batch major
            x0 = x0.movedim(2, 1)
            x0 = x0.reshape((-1,) + x0.shape[-3:])
        num_images = x0.size(0)
        if num_images == 0:
            # Nothing to preview. Without this guard the index update at the
            # end of the method would divide by zero.
            return None
        new_time = time.time()
        # Whole frames owed for the time elapsed since last_time.
        num_previews = int((new_time - self.last_time) * self.rate)
        # Advance only by the frames counted so the fractional remainder
        # carries over to the next call.
        self.last_time = self.last_time + num_previews / self.rate
        if num_previews > num_images:
            num_previews = num_images
        elif num_previews <= 0:
            return None
        if self.first_preview:
            self.first_preview = False
            serv.send_sync('VHS_latentpreview', {'length': num_images, 'rate': self.rate})
            # last_time started at 0, so re-anchor it to the current time.
            self.last_time = new_time + 1 / self.rate
        if self.c_index + num_previews > num_images:
            # The window wraps past the end: rotate so it becomes contiguous.
            x0 = x0.roll(-self.c_index, 0)[:num_previews]
        else:
            x0 = x0[self.c_index:self.c_index + num_previews]
        # This used to be Thread(target=...).run(), which executes the target
        # on the calling thread; the direct call has the same behaviour
        # without suggesting that the work happens in the background.
        self.process_previews(x0, self.c_index, num_images)
        self.c_index = (self.c_index + num_previews) % num_images
        return None

    def process_previews(self, image_tensor, ind, leng):
        """Decode latent frames, encode them as JPEG and send them.

        Args:
            image_tensor: Latent frames to decode, one per entry of dim 0.
            ind: Frame index attached to the first frame; it is incremented
                modulo *leng* for each following frame.
            leng: Total number of frames in the animation.
        """
        image_tensor = self.decode_latent_to_preview(image_tensor)
        if image_tensor.size(1) > 512 or image_tensor.size(2) > 512:
            # Move the last dim to the front so that the two dims being
            # resized are the trailing ones, as F.interpolate expects.
            image_tensor = image_tensor.movedim(-1, 0)
            # Scale the longer of the two trailing dims to 512 and keep the
            # aspect ratio for the other one.
            if image_tensor.size(2) < image_tensor.size(3):
                height = (512 * image_tensor.size(2)) // image_tensor.size(3)
                image_tensor = F.interpolate(image_tensor, (height, 512), mode='bilinear')
            else:
                width = (512 * image_tensor.size(3)) // image_tensor.size(2)
                image_tensor = F.interpolate(image_tensor, (512, width), mode='bilinear')
            image_tensor = image_tensor.movedim(0, -1)
        previews_ubyte = (((image_tensor + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                          .mul(0xFF)  # to 0..255
                          ).to(device="cpu", dtype=torch.uint8)
        for preview in previews_ubyte:
            i = Image.fromarray(preview.numpy())
            message = io.BytesIO()
            # Header: two big-endian 32-bit integers with value 1, then the
            # frame index.
            # NOTE(review): the two leading 1s are presumably an event type
            # and an image format id - confirm against the client code.
            message.write((1).to_bytes(length=4, byteorder='big') * 2)
            message.write(ind.to_bytes(length=4, byteorder='big'))
            # NOTE(review): compress_level looks like a PNG-only option and
            # is probably ignored by the JPEG writer - confirm before removing.
            i.save(message, format="JPEG", quality=95, compress_level=1)
            # NOTE: send sync already uses call_soon_threadsafe
            serv.send_sync(server.BinaryEventTypes.PREVIEW_IMAGE,
                           message.getvalue(), serv.client_id)
            ind = (ind + 1) % leng

    def decode_latent_to_preview(self, x0):
        """Turn latent frames into image tensors with dim 1 moved to the end.

        Uses the ``taesd`` decoder when one was provided, otherwise applies
        the linear ``latent_rgb_factors`` projection (with optional bias).

        Args:
            x0: Latent frames; dim 1 is the one that gets decoded/projected.

        Returns:
            The decoded tensor, with the decoded dim in last position.
        """
        if hasattr(self, 'taesd'):
            x_sample = self.taesd.decode(x0).movedim(1, 3)
            return x_sample
        # Keep the cached factors on the dtype/device of the incoming latent
        # so later calls do not have to convert them again.
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
        latent_image = F.linear(x0.movedim(1, -1), self.latent_rgb_factors,
                                bias=self.latent_rgb_factors_bias)
        return latent_image
@hook(latent_preview, 'get_previewer')
def get_latent_video_previewer(device, latent_format, *args, **kwargs):
    """Return an animated previewer when the running workflow asks for one.

    Calls the original ``get_previewer`` first, then reads the workflow's
    ``extra`` settings from the currently running prompt. If the
    ``VHS_latentpreview`` setting is enabled and the original previewer has a
    ``decode_latent_to_preview`` method, it is wrapped in WrappedPreviewer;
    otherwise the original previewer is returned unchanged.

    Args:
        device: Passed through to the original ``get_previewer``.
        latent_format: Passed through; its class name also selects the
            default preview rate from ``rates_table``.
        *args: Passed through to the original ``get_previewer``.
        **kwargs: Passed through to the original ``get_previewer``.

    Returns:
        The original previewer, or a WrappedPreviewer built around it.
    """
    # NOTE(review): relies on the hook decorator exposing the replaced
    # function as __wrapped__ - confirm in .utils.
    previewer = get_latent_video_previewer.__wrapped__(device, latent_format, *args, **kwargs)
    prev_setting = False
    rate_setting = None
    try:
        running = serv.prompt_queue.currently_running
        extra_info = next(iter(running.values()))[3]['extra_pnginfo']['workflow']['extra']
        prev_setting = extra_info.get('VHS_latentpreview', False)
        if extra_info.get('VHS_latentpreviewrate', 0) != 0:
            rate_setting = extra_info['VHS_latentpreviewrate']
        else:
            rate_setting = rates_table.get(latent_format.__class__.__name__, 8)
    except Exception:
        # For safety since there's lots of keys, any of which can fail.
        # Only Exception is caught so that KeyboardInterrupt and SystemExit
        # still propagate.
        prev_setting = False
    if not prev_setting or not hasattr(previewer, "decode_latent_to_preview"):
        return previewer
    return WrappedPreviewer(previewer, rate_setting)
|