Spaces:
Running
Running
import logging | |
import sys | |
import threading | |
import torch | |
from torchvision import transforms | |
from typing import * | |
from diffusers import EulerAncestralDiscreteScheduler | |
import diffusers.schedulers.scheduling_euler_ancestral_discrete | |
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
def fire_in_thread(f, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` on a freshly started thread.

    The call returns as soon as the thread has been started. The thread
    object is not handed back, so the caller can neither join it nor read
    the return value of ``f``.

    Args:
        f: Callable to execute on the new thread.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.
    """
    worker = threading.Thread(target=f, args=args, kwargs=kwargs)
    worker.start()
def add_logging_arguments(parser):
    """Register the console-logging command line options on ``parser``.

    Three options are added:

    * ``--console_log_level``: one of DEBUG/INFO/WARNING/ERROR/CRITICAL,
      defaulting to ``None``.
    * ``--console_log_file``: path of a log file, defaulting to ``None``.
    * ``--console_log_simple``: boolean flag, ``False`` unless given.

    Args:
        parser: Object exposing an ``add_argument`` method compatible with
            ``argparse.ArgumentParser.add_argument``.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        (
            "--console_log_level",
            {
                "type": str,
                "default": None,
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                "help": "Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
            },
        ),
        (
            "--console_log_file",
            {
                "type": str,
                "default": None,
                "help": "Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
            },
        ),
        (
            "--console_log_simple",
            {
                "action": "store_true",
                "help": "Simple log output / シンプルなログ出力",
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
def setup_logging(args=None, log_level=None, reset=False):
    """Install a single handler on the root logger.

    If the root logger already has handlers, nothing happens unless
    ``reset`` is true, in which case the existing handlers are removed and
    closed before the new one is installed.

    Handler selection:

    * ``args.console_log_file`` set: a ``logging.FileHandler`` opened with
      mode ``"w"``.
    * otherwise, unless ``args.console_log_simple`` is set: a
      ``rich.logging.RichHandler`` writing to stderr, when ``rich`` can be
      imported.
    * otherwise: a ``logging.StreamHandler`` on ``sys.stdout``.

    Args:
        args: Optional namespace providing ``console_log_level``,
            ``console_log_file`` and ``console_log_simple`` (the options
            registered by ``add_logging_arguments``).
        log_level: Level name (case-insensitive, e.g. ``"INFO"``) or a
            numeric logging level. Takes priority over
            ``args.console_log_level``. Defaults to INFO when neither is set.
        reset: When true, replace handlers that are already installed.

    Raises:
        ValueError: If ``log_level`` is a string that does not name a
            logging level. The check happens before any existing handler is
            removed, so a bad level leaves the current setup untouched.
    """
    if logging.root.handlers and not reset:
        return

    # The caller's log_level has priority over args; fall back to INFO.
    if log_level is None and args is not None:
        log_level = args.console_log_level
    if log_level is None:
        log_level = "INFO"
    if isinstance(log_level, str):
        level_name = log_level
        log_level = getattr(logging, level_name.upper(), None)
        if not isinstance(log_level, int):
            raise ValueError(f"Unknown logging level: {level_name!r}")

    # Close the handlers we drop so that e.g. a previous FileHandler releases
    # its file (same behavior as logging.basicConfig(force=True)).
    for stale_handler in logging.root.handlers[:]:
        logging.root.removeHandler(stale_handler)
        stale_handler.close()

    msg_init = None
    handler = None
    if args is not None and args.console_log_file:
        handler = logging.FileHandler(args.console_log_file, mode="w")
    elif not args or not args.console_log_simple:
        try:
            from rich.console import Console
            from rich.logging import RichHandler

            handler = RichHandler(console=Console(stderr=True))
        except ImportError:
            # Reported through the logger once it has been configured below.
            msg_init = "rich is not installed, using basic logging"

    if handler is None:
        handler = logging.StreamHandler(sys.stdout)  # same destination as print

    formatter = logging.Formatter(
        fmt="%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)

    if msg_init is not None:
        logger = logging.getLogger(__name__)
        logger.info(msg_init)
def pil_resize(image, size, interpolation=Image.LANCZOS):
    """Resize an OpenCV-layout image array using Pillow.

    The array is converted to a PIL image, resized, and converted back so
    the result has the same channel layout as the input.

    Args:
        image: Array with a ``shape`` attribute. Three-channel input is
            treated as BGR and four-channel input as BGRA. Single-channel
            input, either ``(H, W)`` or ``(H, W, 1)``, is resized without any
            color conversion.
            NOTE(review): presumably uint8, as required by
            ``Image.fromarray`` for color data; confirm with callers.
        size: Target size, passed unchanged to ``PIL.Image.Image.resize``
            (Pillow expects ``(width, height)``).
        interpolation: Pillow resampling filter.

    Returns:
        The resized image as a NumPy array with the same number of
        dimensions and channel order as ``image``.
    """
    ndim = len(image.shape)

    # Single-channel images have no channel order to swap, and cv2.cvtColor
    # with a BGR conversion code rejects them, so resize them directly.
    if ndim == 2:
        return np.array(Image.fromarray(image).resize(size, interpolation))
    if ndim == 3 and image.shape[2] == 1:
        resized_plane = Image.fromarray(image[:, :, 0]).resize(size, interpolation)
        return np.array(resized_plane)[:, :, np.newaxis]

    has_alpha = ndim == 3 and image.shape[2] == 4
    if has_alpha:
        to_pil_code, to_cv2_code = cv2.COLOR_BGRA2RGBA, cv2.COLOR_RGBA2BGRA
    else:
        to_pil_code, to_cv2_code = cv2.COLOR_BGR2RGB, cv2.COLOR_RGB2BGR

    pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_code))
    resized_pil = pil_image.resize(size, interpolation)

    # Convert back to the cv2 channel order.
    return cv2.cvtColor(np.array(resized_pil), to_cv2_code)
# TODO make inf_utils.py | |
# region Gradual Latent hires fix | |
class GradualLatent:
    """Settings and tensor helpers for the "Gradual Latent" hires fix.

    The constructor stores every argument unchanged as an attribute of the
    same name. The ``gaussian_blur_*`` values drive the unsharp mask applied
    after resizing; a ``gaussian_blur_ksize`` of ``None`` disables it.
    """

    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        # Resize schedule parameters.
        self.ratio = ratio
        self.ratio_step = ratio_step
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        # Noise multiplier.
        self.s_noise = s_noise
        # Unsharp mask parameters.
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        """Return ``GradualLatent(name=value, ...)`` listing every setting."""
        field_names = (
            "ratio",
            "start_timesteps",
            "every_n_steps",
            "ratio_step",
            "s_noise",
            "gaussian_blur_ksize",
            "gaussian_blur_sigma",
            "gaussian_blur_strength",
            "unsharp_target_x",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
        return f"GradualLatent({rendered})"

    def apply_unshark_mask(self, x: torch.Tensor):
        """Sharpen ``x`` with an unsharp mask.

        The result is ``x + (x - blur(x)) * gaussian_blur_strength``. When
        ``gaussian_blur_ksize`` is ``None`` the input is returned as is.
        (The method name keeps its original spelling for compatibility.)
        """
        ksize = self.gaussian_blur_ksize
        if ksize is None:
            return x

        blurred = transforms.functional.gaussian_blur(x, ksize, self.gaussian_blur_sigma)
        return x + (x - blurred) * self.gaussian_blur_strength

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
        """Resize ``x`` to ``resized_size`` with bicubic interpolation.

        bfloat16 input is converted to float32 for the interpolation and the
        result is cast back to the input dtype. The unsharp mask is applied
        afterwards when ``unsharp`` is true and a blur kernel size is set.
        """
        input_dtype = x.dtype
        source = x.float() if input_dtype == torch.bfloat16 else x

        resized = torch.nn.functional.interpolate(source, size=resized_size, mode="bicubic", align_corners=False)
        resized = resized.to(dtype=input_dtype)

        # Apply the unsharp mask.
        if unsharp and self.gaussian_blur_ksize:
            resized = self.apply_unshark_mask(resized)
        return resized
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    """Euler ancestral scheduler whose ``step`` can also resize the latent.

    Until ``set_gradual_latent_params`` supplies a target size, ``step``
    performs the plain Euler ancestral update. Once a size is set, the
    updated sample is resized to it through the ``GradualLatent`` helper and
    the added noise is drawn at the new spatial size.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No resizing until set_gradual_latent_params() is called.
        self.gradual_latent = None
        self.resized_size = None

    def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
        """Store the target size and the ``GradualLatent`` used by ``step``.

        ``size`` is indexed as ``size[0]`` and ``size[1]`` for the last two
        dimensions of the noise tensor. Passing ``None`` restores the plain
        update.
        """
        self.gradual_latent = gradual_latent
        self.resized_size = size

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one ancestral Euler step, optionally resizing it.

        Args:
            model_output (`torch.FloatTensor`):
                Output of the diffusion model for the current step.
            timestep (`float`):
                Current timestep; must be one of `scheduler.timesteps`, not an integer index.
            sample (`torch.FloatTensor`):
                Current sample of the diffusion process.
            generator (`torch.Generator`, *optional*):
                Random number generator used for the added noise.
            return_dict (`bool`):
                Whether to return an `EulerAncestralDiscreteSchedulerOutput` instead of a tuple.

        Returns:
            `EulerAncestralDiscreteSchedulerOutput` or `tuple`:
                With `return_dict=True` an output object holding `prev_sample` and `pred_original_sample`;
                otherwise a one-element tuple containing `prev_sample`.

        Raises:
            ValueError: If `timestep` is an integer or integer tensor, or the configured prediction type is unknown.
            NotImplementedError: If the configured prediction type is `sample`.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if not self.is_scale_input_called:
            # Emitted with print rather than through a logger.
            print(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. Predicted original sample (x_0) from the sigma-scaled model output.
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif prediction_type == "v_prediction":
            # model_output * c_out + sample * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        # Split the step into a deterministic part (sigma_down) and a noise part (sigma_up).
        sigma_from = sigma
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. ODE derivative and step width.
        derivative = (sample - pred_original_sample) / sigma
        dt = sigma_down - sigma

        device = model_output.device
        target_size = self.resized_size
        if target_size is None:
            # Plain update: noise has the same shape as the model output.
            s_noise = 1.0
            noise_shape = model_output.shape
            prev_sample = sample + derivative * dt
        else:
            print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
            helper = self.gradual_latent
            s_noise = helper.s_noise

            if helper.unsharp_target_x:
                # Take the Euler step first, then resize (and sharpen) the result.
                prev_sample = sample + derivative * dt
                prev_sample = helper.interpolate(prev_sample, target_size)
            else:
                # Resize the inputs first; only the sample gets the unsharp mask.
                sample = helper.interpolate(sample, target_size)
                derivative = helper.interpolate(derivative, target_size, unsharp=False)
                prev_sample = sample + derivative * dt

            noise_shape = (model_output.shape[0], model_output.shape[1], target_size[0], target_size[1])

        noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
            noise_shape, dtype=model_output.dtype, device=device, generator=generator
        )
        prev_sample = prev_sample + noise * sigma_up * s_noise

        # Move on to the next step index.
        self._step_index += 1

        if return_dict:
            return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
        return (prev_sample,)
# endregion | |