File size: 1,988 Bytes
2f90561
09a8907
2f90561
 
 
 
09a8907
f710dd1
708d533
a2a2581
884b321
a2a2581
09a8907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2a2581
09a8907
 
 
 
 
 
 
a2a2581
2f90561
 
 
 
 
a2a2581
 
 
 
 
 
2f90561
 
a2a2581
2f90561
a2a2581
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline  # base pipeline class
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image

# ์ปค์Šคํ…€ ๋ชจ๋“ˆ ์ž„ํฌํŠธ
from custom_pipeline import WanTransformer3DModel, AutoencoderKLWan

# ๋ชจ๋ธ ID
model_id = "grnr9730/Wan2.1-I2V-14B-720P-Diffusers"

# ์ปค์Šคํ…€ ํŒŒ์ดํ”„๋ผ์ธ ์ •์˜
class WanImageToVideoPipeline(DiffusionPipeline):
    def __init__(self, transformer, vae, scheduler):
        super().__init__()
        self.transformer = transformer
        self.vae = vae
        self.scheduler = scheduler
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

    def __call__(self, prompt, **kwargs):
        # ์—ฌ๊ธฐ์— ์‹ค์ œ ์ถ”๋ก  ๋กœ์ง์„ ์ถ”๊ฐ€ (WanTransformer3DModel๊ณผ AutoencoderKLWan ์‚ฌ์šฉ)
        # ์˜ˆ์‹œ๋กœ ๊ฐ„๋‹จํžˆ ์ž‘์„ฑ
        latents = self.vae.encode(torch.randn(1, 3, 224, 224)).latent_dist.sample()
        for _ in self.scheduler.timesteps:
            latents = self.transformer(latents)
        video_frames = self.vae.decode(latents).sample
        return type('Result', (), {'frames': [Image.fromarray((frame * 255).byte().cpu().numpy()) for frame in video_frames]})

# ๋ชจ๋ธ ๋กœ๋“œ
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=WanTransformer3DModel.from_pretrained(model_id),
    vae=AutoencoderKLWan.from_pretrained(model_id),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_id),
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# Inference ํ•จ์ˆ˜
def infer(data):
    prompt = data.get("prompt", "A futuristic cityscape")
    
    # ๋น„๋””์˜ค ์ƒ์„ฑ
    video = pipe(prompt).frames  # ํ”„๋ ˆ์ž„ ๋ฆฌ์ŠคํŠธ ๋ฐ˜ํ™˜

    # ์ฒซ ๋ฒˆ์งธ ํ”„๋ ˆ์ž„์„ ์ด๋ฏธ์ง€๋กœ ์ €์žฅ
    first_frame = video[0]

    # ์ด๋ฏธ์ง€ Base64๋กœ ๋ณ€ํ™˜
    buffered = BytesIO()
    first_frame.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {"image": img_str}