ZhiyuanthePony's picture
update
f876753
from typing import *
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
class Pipeline(DiffusionPipeline):
"""Base class for all pipelines."""
def __init__(self):
super().__init__()
def __call__(self, *args, **kwargs):
raise NotImplementedError
def enable_xformers_memory_efficient_attention(self):
pass
def enable_model_cpu_offload(self):
pass
@property
def device(self) -> torch.device:
for model in self.models.values():
if hasattr(model, 'device'):
return model.device
for model in self.models.values():
if hasattr(model, 'parameters'):
return next(model.parameters()).device
raise RuntimeError("No device found.")
def to(self, device: torch.device) -> None:
for model in self.models.values():
model.to(device)