# Hosting-page banner captured together with the source (not part of the program):
#   Spaces: Configuration error / Configuration error
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline | |
from diffusers import DPMSolverMultistepScheduler | |
# DepthAnything | |
import cv2 | |
import numpy as np | |
import os | |
from PIL import Image | |
import torch.nn.functional as F | |
from torchvision.transforms import Compose | |
import tempfile | |
from gradio_imageslider import ImageSlider | |
from depth_anything.depth_anything.dpt import DepthAnything | |
from depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet | |
# Default number of diffusion steps (the "Steps" field in the UI starts at the same value).
NUM_INFERENCE_STEPS = 50

# Select the compute device and a matching tensor precision.
# Half precision is only used on CUDA; MPS and the CPU fallback run in
# float32 because float16 kernels are incomplete/slow off-GPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    dtype = torch.float16
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    dtype = torch.float32
else:
    DEVICE = "cpu"
    dtype = torch.float32
# Depth estimator. The encoder tag is interpolated into the checkpoint name,
# so switching model size only requires changing this one value.
encoder = 'vitl'
model = DepthAnything.from_pretrained(f"LiheYoung/depth_anything_{encoder}14")
model = model.to(DEVICE).eval()

# SDXL base pipeline (differential img2img variant).
base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
)

# SDXL refiner pipeline; it reuses the base pipeline's second text encoder
# and VAE instead of loading its own copies.
refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
)

# Swap both pipelines over to the DPM-Solver multistep scheduler.
# Note that the refiner's scheduler is also built from the *base* config
# (read after the base scheduler has already been replaced).
base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
# DepthAnything
# Pre-processing applied to every image before depth estimation:
# resize -> normalise -> convert to the network's input layout.
_depth_resize = Resize(
    width=518,
    height=518,
    resize_target=False,
    keep_aspect_ratio=True,
    ensure_multiple_of=14,
    resize_method='lower_bound',
    image_interpolation_method=cv2.INTER_CUBIC,
)
_depth_normalize = NormalizeImage(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = Compose([_depth_resize, _depth_normalize, PrepareForNet()])
def predict_depth(model, image):
    """Run the depth model on a prepared input batch without tracking gradients.

    Inference does not need autograd: disabling it saves memory and makes the
    result safe to convert with ``.numpy()`` (a tensor that requires grad
    cannot be converted directly).

    Args:
        model: Callable network, invoked as ``model(image)``.
        image: Input tensor passed to the model unchanged.

    Returns:
        Whatever ``model(image)`` returns, detached from the autograd graph.
    """
    with torch.no_grad():
        return model(image)
def depthify(image):
    """Estimate a depth map for an image and package it in several forms.

    Args:
        image: Image as a NumPy array whose first two axes are height and
            width (presumably the array delivered by the Gradio image input).

    Returns:
        A three-item list:
        ``[(original_image, colored_depth), png_path, raw_depth]`` where
        ``original_image`` is an untouched copy of the input,
        ``colored_depth`` is the normalised depth rendered with the INFERNO
        colour map, ``png_path`` is the path of a PNG file holding
        ``raw_depth``, and ``raw_depth`` is the un-normalised depth as a PIL
        image (uint8).
    """
    original_image = image.copy()
    h, w = image.shape[:2]

    # NOTE(review): Gradio presumably supplies RGB arrays, in which case this
    # swap reverses the channels; kept unchanged because altering it would
    # change the model input - confirm intent.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

    # No gradients are needed here; without this guard the tensors below could
    # require grad and the .numpy() conversions would raise.
    with torch.no_grad():
        depth = predict_depth(model, image)
        # Bring the prediction back to the input resolution.
        depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

    # NOTE(review): the raw depth is cast straight to uint8, so values outside
    # 0-255 are not preserved - confirm the model's output range.
    raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint8'))

    # Only the file name is needed; close the handle immediately so it is not
    # leaked (the file itself is kept because delete=False).
    tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    tmp.close()
    raw_depth.save(tmp.name)

    # Stretch to 0-255 for display; a constant depth map would otherwise
    # divide by zero and produce NaNs.
    depth_range = depth.max() - depth.min()
    if depth_range > 0:
        depth = (depth - depth.min()) / depth_range * 255.0
    else:
        depth = torch.zeros_like(depth)
    depth = depth.cpu().numpy().astype(np.uint8)

    # applyColorMap returns BGR; reverse the last axis to get RGB.
    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return [(original_image, colored_depth), tmp.name, raw_depth]
# DifferentialDiffusion
def preprocess_image(image_array):
    """Turn an image array into a batched tensor on ``DEVICE``.

    The image is converted to RGB, centre-cropped so that both sides are
    multiples of 64, converted to a tensor, remapped with ``x * 2 - 1`` and
    given a leading batch axis.
    """
    rgb = Image.fromarray(image_array).convert("RGB")
    width, height = rgb.size
    # CenterCrop takes (height, width); round each side down to a multiple of 64.
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    tensor = transforms.ToTensor()(crop(rgb))
    batch = (tensor * 2 - 1).unsqueeze(0)
    return batch.to(DEVICE)
def preprocess_map(map):
    """Turn a PIL change map into a single-channel tensor on ``DEVICE``.

    The map is converted to greyscale (mode "L"), centre-cropped so that both
    sides are multiples of 64, and converted to a tensor.
    """
    grayscale = map.convert("L")
    width, height = grayscale.size
    # CenterCrop takes (height, width); round each side down to a multiple of 64.
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    # convert to tensor
    return transforms.ToTensor()(crop(grayscale)).to(DEVICE)
def inference(
    image,
    map,
    guidance_scale,
    prompt,
    negative_prompt,
    steps,
    denoising_start,
    denoising_end
):
    """Run the base pipeline followed by the refiner on an image and change map.

    The base pipeline is run up to ``denoising_end`` and asked for latents;
    those latents are then handed to the refiner, which starts at
    ``denoising_start`` and produces the final image.

    Args:
        image: Source image as a NumPy array (fed to ``preprocess_image``).
        map: Change map as a PIL image (fed to ``preprocess_map``).
        guidance_scale: Guidance scale forwarded to both pipelines.
        prompt: Text prompt forwarded to both pipelines.
        negative_prompt: Negative prompt forwarded to both pipelines.
        steps: Number of inference steps; coerced to ``int`` because the UI
            number field may deliver a float.
        denoising_start: ``denoising_start`` passed to the refiner.
        denoising_end: ``denoising_end`` passed to the base pipeline.

    Returns:
        The first image produced by the refiner pipeline.

    Raises:
        gr.Error: If the image, the map or the step count is missing.
    """
    validate_inputs(image, map)
    # The Steps field is emptied by the Clear button; fail with a readable
    # message instead of an obscure error inside the scheduler.
    if steps is None:
        raise gr.Error("Missing steps")
    steps = int(steps)

    image = preprocess_image(image)
    map = preprocess_map(map)

    # Stage 1: base model, stopping early and returning latents.
    base_device = base.to(DEVICE)
    edited_images = base_device(
        prompt=prompt,
        original_image=image,
        image=image,
        strength=1,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        negative_prompt=negative_prompt,
        map=map,
        num_inference_steps=steps,
        denoising_end=denoising_end,
        output_type="latent"
    ).images

    # Stage 2: refiner, continuing from the base model's latents.
    refiner_device = refiner.to(DEVICE)
    edited_images = refiner_device(
        prompt=prompt,
        original_image=image,
        image=edited_images,
        strength=1,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        negative_prompt=negative_prompt,
        map=map,
        num_inference_steps=steps,
        denoising_start=denoising_start
    ).images[0]
    return edited_images
def validate_inputs(image, map):
    """Raise a Gradio error if the image or the change map is ``None``.

    The image is checked first, so when both are missing the error reports
    the image.
    """
    if image is None:
        problem = "Missing image"
    elif map is None:
        problem = "Missing map"
    else:
        return
    raise gr.Error(problem)
def run(image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end):
    """Gradio click handler: derive a depth map from the image, then edit it.

    The depth map computed by ``depthify`` is used as the change map for
    ``inference``.

    Args:
        image: Uploaded image array, or ``None`` if nothing was uploaded.
        gs: Guidance scale.
        prompt: Text prompt.
        neg_prompt: Negative prompt.
        steps: Number of inference steps.
        denoising_start: Passed through to ``inference``.
        denoising_end: Passed through to ``inference``.

    Returns:
        A ``(raw_depth, edited_image)`` pair matching the two UI outputs.

    Raises:
        gr.Error: If no image was provided.
    """
    # depthify() dereferences the image right away, so check here; otherwise
    # an empty upload fails with AttributeError before inference() can
    # report the problem properly.
    if image is None:
        raise gr.Error("Missing image")

    (original_image, colored_depth), name, raw_depth = depthify(image)
    print(f"original_image={original_image} colored_depth={colored_depth}, name={name}, raw_depth={raw_depth}")
    return raw_depth, inference(original_image, raw_depth, gs, prompt, neg_prompt, steps, denoising_start, denoising_end)
# User interface: inputs and controls in the left column, results in the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label="Input Image")
            gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
            # precision=0 makes the field deliver an integer step count.
            steps = gr.Number(value=NUM_INFERENCE_STEPS, precision=0, label="Steps")
            denoising_start = gr.Slider(0, 1, value=0.8, label="Denoising Start")
            denoising_end = gr.Slider(0, 1, value=0.8, label="Denoising End")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                clr_btn = gr.ClearButton(
                    components=[input_image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end]
                )
                run_btn = gr.Button("Run", variant="primary")
        with gr.Column():
            output = gr.Image(label="Output Image")
            change_map = gr.Image(label="Change Map")

    # run() returns (change map, edited image), in that order.
    run_btn.click(
        run,
        inputs=[input_image, gs, prompt, neg_prompt, steps, denoising_start, denoising_end],
        outputs=[change_map, output]
    )
    # The outputs are created after the button, so register them afterwards;
    # both result images are cleared so no stale change map is left behind.
    clr_btn.add([output, change_map])

if __name__ == "__main__":
    demo.launch()