# cocktailpeanut's picture
# update
# 0354639
import gradio as gr
import torch
from torchvision import transforms
from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler
# DepthAnything
import cv2
import numpy as np
import os
from PIL import Image
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider
from depth_anything.depth_anything.dpt import DepthAnything
from depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
# Default sampling step count. Not referenced elsewhere in the visible part of
# this file; kept because other code may import it.
NUM_INFERENCE_STEPS = 50

# Select the compute device and the matching tensor precision.
# Half precision is used on CUDA only: MPS already ran in float32, and the CPU
# fallback now does the same because float16 inference on CPU is poorly
# supported by many kernels.
if torch.cuda.is_available():
    DEVICE = "cuda"
    dtype = torch.float16
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    dtype = torch.float32
else:
    DEVICE = "cpu"
    dtype = torch.float32
# Depth estimator. The encoder tag is interpolated into the checkpoint name
# handed to from_pretrained ('vitl' here; the earlier note also lists 'vitb').
encoder = 'vitl'
model = (
    DepthAnything
    .from_pretrained(f"LiheYoung/depth_anything_{encoder}14")
    .to(DEVICE)
    .eval()
)

# SDXL base pipeline (differential img2img variant), loaded with the
# precision chosen during device selection.
base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_safetensors=True,
    variant="fp16",
    torch_dtype=dtype,
)

# SDXL refiner pipeline. It is handed the base pipeline's second text encoder
# and VAE rather than being left to load its own.
refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=base.vae,
    text_encoder_2=base.text_encoder_2,
    use_safetensors=True,
    variant="fp16",
    torch_dtype=dtype,
)

# Swap both schedulers for DPM-Solver multistep. The refiner's scheduler is
# built from the base scheduler's config as it stands after the replacement
# on the line above.
base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
# DepthAnything
# Preprocessing applied by depthify to an {'image': array} dict before the
# array is turned into a tensor for the depth model.
transform = Compose([
    Resize(
        height=518,
        width=518,
        keep_aspect_ratio=True,
        resize_target=False,
        resize_method='lower_bound',
        ensure_multiple_of=14,
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    # Per-channel normalisation constants (presumably ImageNet statistics).
    NormalizeImage(std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406]),
    PrepareForNet(),
])
def predict_depth(model, image):
    """Call ``model`` on ``image`` with gradient tracking disabled.

    Args:
        model: Any callable taking the prepared input.
        image: The input forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(image)`` returns.
    """
    with torch.no_grad():
        prediction = model(image)
    return prediction
def depthify(image):
    """Estimate a depth map for ``image`` and return it in several forms.

    Args:
        image: Image as a NumPy array whose first two axes are height and
            width (assumed 3-channel, as produced by the Gradio image
            input — TODO confirm).

    Returns:
        A three-element list ``[(original_image, colored_depth), path, raw_depth]``:
        ``original_image`` is a copy of the input, ``colored_depth`` is the
        min-max normalised depth rendered with the INFERNO colormap (channel
        order reversed after ``applyColorMap``), ``path`` is the name of a
        PNG file holding ``raw_depth``, and ``raw_depth`` is a PIL image of
        the unnormalised depth cast to uint8.
    """
    original_image = image.copy()
    h, w = image.shape[:2]

    # Scale to [0, 1] and run the shared preprocessing pipeline.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

    # Predict, then resample the prediction back to the input resolution.
    depth = predict_depth(model, image)
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

    # NOTE(review): the raw prediction is cast to uint8 without rescaling, so
    # values outside 0-255 would not survive — confirm the model's output range.
    raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint8'))

    # Reserve a persistent file name, closing the handle immediately so it is
    # not leaked; the file itself is kept (delete=False) and written by path.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        tmp_path = tmp.name
    raw_depth.save(tmp_path)

    # Min-max normalise to 0-255. A constant depth map has zero range, which
    # would otherwise divide by zero and produce NaNs; map it to all zeros.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range * 255.0
    else:
        depth = torch.zeros_like(depth)
    depth = depth.cpu().numpy().astype(np.uint8)

    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return [(original_image, colored_depth), tmp_path, raw_depth]
# DifferentialDiffusion
def preprocess_image(image_array):
    """Turn an image array into a batched tensor on ``DEVICE``.

    The array is wrapped in a PIL image, forced to RGB, centre-cropped so
    that both sides are multiples of 64, converted to a tensor, shifted
    with ``x * 2 - 1`` and given a leading batch axis.

    Args:
        image_array: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        The prepared tensor, moved to ``DEVICE``.
    """
    rgb = Image.fromarray(image_array).convert("RGB")

    # PIL reports size as (width, height); CenterCrop wants (height, width).
    width, height = rgb.size
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))

    tensor = transforms.ToTensor()(crop(rgb))
    # ToTensor is expected to yield values in [0, 1]; this remaps to [-1, 1].
    tensor = tensor * 2 - 1
    return tensor.unsqueeze(0).to(DEVICE)
def preprocess_map(map):
    """Turn a PIL change map into a single-channel tensor on ``DEVICE``.

    The map is converted to 8-bit grayscale ("L"), centre-cropped so that
    both sides are multiples of 64, and converted to a tensor. No batch
    axis is added.

    Args:
        map: PIL image used as the change map.

    Returns:
        The prepared tensor, moved to ``DEVICE``.
    """
    grayscale = map.convert("L")

    # PIL reports size as (width, height); CenterCrop wants (height, width).
    width, height = grayscale.size
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))

    return transforms.ToTensor()(crop(grayscale)).to(DEVICE)
def inference(
    image,
    map,
    guidance_scale,
    prompt,
    negative_prompt,
    steps,
    denoising_start,
    denoising_end
):
    """Run the base pipeline and then the refiner on ``image`` guided by ``map``.

    Args:
        image: Input image array (see ``preprocess_image``).
        map: PIL change map (see ``preprocess_map``).
        guidance_scale: Passed to both pipelines as ``guidance_scale``.
        prompt: Text prompt passed to both pipelines.
        negative_prompt: Negative prompt passed to both pipelines.
        steps: Number of inference steps. May arrive as a float from the UI
            number field, so it is coerced to ``int`` here.
        denoising_start: Passed to the refiner as ``denoising_start``.
        denoising_end: Passed to the base pipeline as ``denoising_end``.

    Returns:
        The first image produced by the refiner pipeline.

    Raises:
        gr.Error: If ``image`` or ``map`` is missing, or ``steps`` is missing
            or smaller than 1.
    """
    validate_inputs(image, map)

    # A cleared or fractional "Steps" field would otherwise fail deep inside
    # the scheduler; surface it as a user-facing error and pass a real int.
    if steps is None or int(steps) < 1:
        raise gr.Error("Steps must be a positive integer")
    steps = int(steps)

    image = preprocess_image(image)
    map = preprocess_map(map)

    # Stage 1: base pipeline, stopped at denoising_end and kept in latent form
    # so the refiner can continue from it.
    base_pipe = base.to(DEVICE)
    latents = base_pipe(
        prompt=prompt,
        original_image=image,
        image=image,
        strength=1,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        negative_prompt=negative_prompt,
        map=map,
        num_inference_steps=steps,
        denoising_end=denoising_end,
        output_type="latent"
    ).images

    # Stage 2: refiner picks up from the base latents at denoising_start.
    refiner_pipe = refiner.to(DEVICE)
    return refiner_pipe(
        prompt=prompt,
        original_image=image,
        image=latents,
        strength=1,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        negative_prompt=negative_prompt,
        map=map,
        num_inference_steps=steps,
        denoising_start=denoising_start
    ).images[0]
def validate_inputs(image, map):
    """Raise a user-facing error when a required input is absent.

    Args:
        image: The input image; must not be ``None``.
        map: The change map; must not be ``None``.

    Raises:
        gr.Error: "Missing image" if ``image`` is ``None``, otherwise
            "Missing map" if ``map`` is ``None``.
    """
    # Checked in order, so a missing image is reported before a missing map.
    required = ((image, "Missing image"), (map, "Missing map"))
    for value, message in required:
        if value is None:
            raise gr.Error(message)
def run(image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end):
    """Build a depth map for ``image`` and use it as the change map for inference.

    Args:
        image: Input image array from the UI.
        gs: Guidance scale.
        prompt: Text prompt.
        neg_prompt: Negative prompt.
        steps: Number of inference steps.
        denoising_start: Forwarded to ``inference``.
        denoising_end: Forwarded to ``inference``.

    Returns:
        A tuple ``(raw_depth, edited_image)``: the depth map used as the
        change map and the result of ``inference``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Validate before depth estimation: depthify calls image.copy() and would
    # fail with an AttributeError long before inference's own validation runs.
    if image is None:
        raise gr.Error("Missing image")

    (original_image, colored_depth), name, raw_depth = depthify(image)
    print(f"original_image={original_image} colored_depth={colored_depth}, name={name}, raw_depth={raw_depth}")

    edited_image = inference(
        original_image,
        raw_depth,
        gs,
        prompt,
        neg_prompt,
        steps,
        denoising_start,
        denoising_end,
    )
    return raw_depth, edited_image
# Gradio UI: inputs and controls in the left column, results in the right.
# NOTE(review): container nesting below is inferred from statement order —
# confirm against the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label="Input Image")
            gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
            # precision=0 makes the component deliver an int, which is what
            # the pipelines need for their step count.
            steps = gr.Number(value=50, precision=0, label="Steps")
            denoising_start = gr.Slider(0, 1, value=0.8, label="Denoising Start")
            denoising_end = gr.Slider(0, 1, value=0.8, label="Denoising End")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                clr_btn = gr.ClearButton(
                    components=[input_image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end]
                )
                run_btn = gr.Button("Run", variant="primary")
        with gr.Column():
            output = gr.Image(label="Output Image")
            change_map = gr.Image(label="Change Map")

    # run returns (depth map, edited image), matching the outputs order.
    run_btn.click(
        run,
        inputs=[input_image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end],
        outputs=[change_map, output]
    )
    # Both result images are created after the clear button, so register them
    # afterwards; clearing only one would leave a stale change map on screen.
    clr_btn.add([output, change_map])

if __name__ == "__main__":
    demo.launch()