Maria
hw6
02c021a
import functools
import numpy as np
import torch
import cv2 as cv
import random
import os
import spaces
import gradio as gr
from rembg import remove
from PIL import Image
from transformers import pipeline
from controlnet_aux import MLSDdetector, HEDdetector, NormalBaeDetector, LineartDetector
from peft import PeftModel, LoraConfig
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    ControlNetModel,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import load_image, make_image_grid
# Hardware selection: half precision on a CUDA GPU when one is present,
# otherwise full precision on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Upper bound (inclusive) for randomly drawn seeds: the largest int32 value.
MAX_SEED = np.iinfo(np.int32).max

# Base checkpoint; also used as the foundation for the custom LoRA model id.
default_model = 'CompVis/stable-diffusion-v1-4'
# Directory holding the PEFT adapter sub-folders ("unet", "text_encoder").
LoRA_path = 'new_model'

# UI label -> Hugging Face Hub id of the matching ControlNet v1.1 checkpoint.
# Insertion order is preserved (it is the order the modes were declared in).
CONTROLNET_MODE = {
    "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
    "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
    "HED edge detection (soft edge)": "lllyasviel/control_v11p_sd15_softedge",
    "Midas depth estimation": "lllyasviel/control_v11f1p_sd15_depth",
    "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
    "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
    "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
}
def get_pipe(
    model_id,
    use_controlnet,
    controlnet_mode,
    use_ip_adapter
):
    """Load and configure a Stable Diffusion pipeline for one request.

    Args:
        model_id: Hub id (or local path) of the base model, or the special
            value ``'Maria_Lashina_LoRA'``. That value loads ``default_model``
            and then wraps its UNet and text encoder with the PEFT adapters
            stored under ``LoRA_path``.
        use_controlnet: When truthy, build a
            ``StableDiffusionControlNetPipeline`` with the ControlNet
            checkpoint that ``CONTROLNET_MODE`` maps ``controlnet_mode`` to.
        controlnet_mode: Key into ``CONTROLNET_MODE``. Only read when
            ``use_controlnet`` is truthy.
        use_ip_adapter: When truthy, load the ``ip-adapter-plus_sd15.bin``
            weights from ``h94/IP-Adapter`` into the pipeline.

    Returns:
        The pipeline, moved to ``device``, with the safety checker disabled.

    Raises:
        KeyError: if ``use_controlnet`` is truthy and ``controlnet_mode`` is
            not a key of ``CONTROLNET_MODE``.
    """
    # NOTE(review): every call reloads all weights from disk/cache; consider
    # caching pipelines if memory allows.
    is_lora = model_id == 'Maria_Lashina_LoRA'
    # The LoRA id is not a Hub model: its base weights come from default_model.
    base_model = default_model if is_lora else model_id

    if use_controlnet and use_ip_adapter:
        print('Pipe with ControlNet and IPAdapter')
    elif use_controlnet:
        print('Pipe with ControlNet')
    elif use_ip_adapter:
        print('Pipe with IpAdapter')
    else:
        print('Pipe with only SD')

    pipe_kwargs = dict(torch_dtype=torch_dtype, safety_checker=None)
    if use_controlnet:
        # Components passed into from_pretrained are used as-is, and the
        # pipeline-level torch_dtype is not applied to them. Load the
        # ControlNet in the same dtype as the UNet; otherwise fp16 latents hit
        # fp32 ControlNet weights on CUDA.
        pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
            CONTROLNET_MODE[controlnet_mode],
            torch_dtype=torch_dtype,
            cache_dir="./models_cache",
        )
        pipeline_cls = StableDiffusionControlNetPipeline
    else:
        pipeline_cls = StableDiffusionPipeline

    pipe = pipeline_cls.from_pretrained(base_model, **pipe_kwargs).to(device)

    if use_ip_adapter:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )

    if is_lora:
        # Wrap the already-loaded modules with the saved PEFT adapters. This
        # happens after IP-Adapter loading, matching the original order.
        adapter_name = 'cartoonish mouse'
        unet_sub_dir = os.path.join(LoRA_path, "unet")
        text_encoder_sub_dir = os.path.join(LoRA_path, "text_encoder")
        pipe.unet = PeftModel.from_pretrained(
            pipe.unet, unet_sub_dir, adapter_name=adapter_name
        )
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
        )
    return pipe
def prepare_controlnet_image(controlnet_image, mode):
if mode == "Canny Edge Detection":
image = cv.Canny(controlnet_image, 80, 160)
image = np.repeat(image[:, :, None], 3, axis=2)
image = Image.fromarray(image)
elif mode == "Pixel to Pixel":
image = Image.fromarray(controlnet_image).convert('RGB')
elif mode == "HED edge detection (soft edge)":
processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
image = processor(controlnet_image)
elif mode == "Midas depth estimation":
depth_estimator = pipeline('depth-estimation')
image = depth_estimator(Image.fromarray(controlnet_image))['depth']
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
elif mode == "Surface Normal Estimation":
processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
image = processor(controlnet_image)
elif mode == "Scribble-Based Generation":
processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
image = processor(controlnet_image, scribble=True)
elif mode == "Line Art Generation":
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
image = processor(controlnet_image)
else:
image = controlnet_image
return image
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    lora_scale,
    num_inference_steps,
    use_controlnet,
    control_strength,
    controlnet_mode,
    controlnet_image,
    use_ip_adapter,
    ip_adapter_scale,
    ip_adapter_image,
    delete_background,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from the UI settings and return it with its seed.

    The pipeline is built by ``get_pipe`` for the requested combination of
    ControlNet and IP-Adapter. The optional conditioning inputs are passed
    only when their feature is enabled.

    Args:
        model_id: Forwarded to ``get_pipe``.
        prompt, negative_prompt: Text prompts for the pipeline.
        seed: Seed for the CPU ``torch.Generator``. Ignored when
            ``randomize_seed`` is truthy.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width, height, guidance_scale, num_inference_steps: Forwarded to
            the pipeline call.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": ...}``.
        use_controlnet, control_strength, controlnet_mode, controlnet_image:
            ControlNet switch, conditioning scale, mode label, and raw image.
        use_ip_adapter, ip_adapter_scale, ip_adapter_image: IP-Adapter
            switch, scale, and reference image.
        delete_background: Run ``rembg.remove`` on the result.
        progress: Not read in the body. Its default enables Gradio tqdm tracking.

    Returns:
        ``(image, seed)``: the first generated image (background removed if
        requested) and the seed actually used.

    Raises:
        gr.Error: if an enabled feature is missing its input image.
    """
    # Validate uploads before loading any model, so a missing image fails
    # fast with a readable message instead of after a full pipeline load.
    if use_controlnet and controlnet_image is None:
        raise gr.Error("ControlNet is enabled but no control image was provided.")
    if use_ip_adapter and ip_adapter_image is None:
        raise gr.Error("IP-Adapter is enabled but no IP-Adapter image was provided.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed requires an int; UI widgets may deliver a float.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        cross_attention_kwargs={"scale": lora_scale},
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    if use_controlnet:
        # Preprocess before loading the pipeline (same order as before).
        call_kwargs["image"] = prepare_controlnet_image(controlnet_image, controlnet_mode)
        call_kwargs["controlnet_conditioning_scale"] = control_strength

    pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)
    if use_ip_adapter:
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        call_kwargs["ip_adapter_image"] = ip_adapter_image

    image = pipe(**call_kwargs).images[0]
    if delete_background:
        image = remove(image)
    return image, seed