import base64
from io import BytesIO
from types import SimpleNamespace

import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler  # base class and the scheduler used below
from PIL import Image

# Import the custom modules
from custom_pipeline import WanTransformer3DModel, AutoencoderKLWan
# ๋ชจ๋ธ ID
model_id = "grnr9730/Wan2.1-I2V-14B-720P-Diffusers"
# Custom pipeline definition
class WanImageToVideoPipeline(DiffusionPipeline):
    def __init__(self, transformer, vae, scheduler):
        super().__init__()
        # register_modules() also assigns each component as an attribute
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

    def __call__(self, prompt, **kwargs):
        # Add the real inference logic here (using WanTransformer3DModel and
        # AutoencoderKLWan); the body below is only a simplified placeholder
        # and ignores the prompt.
        self.scheduler.set_timesteps(10)  # placeholder step count
        latents = self.vae.encode(torch.randn(1, 3, 224, 224)).latent_dist.sample()
        for _ in self.scheduler.timesteps:
            latents = self.transformer(latents)
        video_frames = self.vae.decode(latents).sample
        # Convert each (C, H, W) float frame in [0, 1] to a PIL image
        frames = [
            Image.fromarray((frame.clamp(0, 1) * 255).byte().permute(1, 2, 0).cpu().numpy())
            for frame in video_frames
        ]
        return SimpleNamespace(frames=frames)
# ๋ชจ๋ธ ๋กœ๋“œ
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
transformer=WanTransformer3DModel.from_pretrained(model_id),
vae=AutoencoderKLWan.from_pretrained(model_id),
scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_id),
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
# Inference function
def infer(data):
    prompt = data.get("prompt", "A futuristic cityscape")
    # Generate the video (returns a list of frames)
    video = pipe(prompt).frames
    # Keep the first frame as the output image
    first_frame = video[0]
    # Encode the image as Base64
    buffered = BytesIO()
    first_frame.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {"image": img_str}