File size: 5,335 Bytes
bd199cf
 
 
 
 
 
d7aa376
 
 
 
 
 
 
 
 
89a1445
 
d7aa376
bd199cf
8a7960d
6ed83ba
 
 
 
8a7960d
6ed83ba
 
 
bd199cf
d7aa376
 
 
bd199cf
8a7960d
b1ae048
bd199cf
 
 
 
 
8a7960d
bd199cf
 
b1ae048
bd199cf
 
 
 
 
d7aa376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd199cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e18ee
bd199cf
 
 
b1ae048
 
bd199cf
 
 
 
b1ae048
 
 
bd199cf
 
 
 
b1ae048
bd199cf
 
 
 
 
 
 
 
 
 
d7aa376
 
 
 
 
 
83e18ee
 
 
 
c4f7e78
 
83e18ee
 
 
 
 
 
 
c4f7e78
d7aa376
42077b5
 
d7aa376
 
 
83e18ee
bd199cf
c4f7e78
83e18ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import torch
from torchvision import transforms
from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler

# DepthAnything
import cv2
import numpy as np
import os
from PIL import Image
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider
from .depth_anything.depth_anything.dpt import DepthAnything
from .depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

# Number of denoising steps used for both the base and the refiner pass.
NUM_INFERENCE_STEPS = 50

# Pick the best available accelerator. Half precision is only used on CUDA;
# on MPS and CPU the pipelines are loaded in float32 because float16
# pipelines are not usable on a CPU device.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
elif torch.backends.mps.is_available():
    device = "mps"
    dtype = torch.float32
else:
    device = "cpu"
    dtype = torch.float32

# Upper-case alias: the DepthAnything code in this file refers to ``DEVICE``.
DEVICE = device

# DepthAnything model and its input pre-processing pipeline.
encoder = 'vitl'  # can also be 'vits' or 'vitb'
model = DepthAnything.from_pretrained(f"LiheYoung/depth_anything_{encoder}14").to(device).eval()

# Pre-processing applied to an RGB float image in [0, 1] before it is fed to
# the depth model (settings follow the upstream Depth Anything demo).
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

# SDXL base pipeline (differential img2img variant).
base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype, variant="fp16", use_safetensors=True
)

# SDXL refiner; shares the second text encoder and the VAE with the base
# pipeline so those weights are only held once.
refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16",
)

# Both pipelines use a DPM-Solver multistep scheduler built from the base
# pipeline's scheduler configuration.
base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)














# DepthAnything
def predict_depth(model, image):
    """Call ``model`` on ``image`` with gradient tracking disabled.

    Args:
        model: Any callable taking a single argument.
        image: The value forwarded to ``model``.

    Returns:
        Whatever ``model(image)`` returns.
    """
    with torch.no_grad():
        prediction = model(image)
    return prediction

def depthify(image):
    """Estimate a depth map for ``image`` and return it in several forms.

    Args:
        image: A PIL image (what the Gradio UI in this file supplies) or an
            ``H x W x C`` NumPy array. Arrays keep the previous behaviour and
            are channel-swapped with ``cv2.COLOR_BGR2RGB``; PIL images are
            converted to RGB directly.

    Returns:
        A list ``[(original_image, colored_depth), path, depth_map]`` where
        ``original_image`` is a copy of the input (same type as the input),
        ``colored_depth`` is the INFERNO colour-mapped depth with the channel
        axis reversed, ``path`` is the name of a PNG temp file holding the
        depth map (created with ``delete=False``, so it is not removed
        automatically), and ``depth_map`` is an 8-bit PIL image of the depth
        min-max normalised to 0..255.
    """
    original_image = image.copy()

    if isinstance(image, Image.Image):
        # PIL input has no ``.shape`` and is already RGB after conversion.
        rgb = np.asarray(image.convert("RGB"))
    else:
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = rgb.shape[:2]

    net_input = transform({'image': rgb / 255.0})['image']
    net_input = torch.from_numpy(net_input).unsqueeze(0).to(device)

    depth = predict_depth(model, net_input)
    # Resize the prediction back to the input resolution.
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

    # Min-max normalise to 0..255. A constant prediction would otherwise
    # divide by zero and yield NaNs, so it maps to an all-zero map instead.
    lowest = depth.min()
    span = depth.max() - lowest
    if span > 0:
        depth = (depth - lowest) / span * 255.0
    else:
        depth = torch.zeros_like(depth)
    depth = depth.cpu().numpy().astype(np.uint8)

    # The map is built from the normalised depth: casting the unnormalised
    # float prediction straight to uint8 does not give a usable 0..255 image.
    depth_map = Image.fromarray(depth)

    # Write through the open handle and close it, so no descriptor is leaked.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        depth_map.save(tmp, format='PNG')

    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return [(original_image, colored_depth), tmp.name, depth_map]






# DifferentialDiffusion

def preprocess_image(image):
    """Turn a PIL image into a batched tensor on ``device``.

    The image is converted to RGB, centre-cropped so both sides are
    multiples of 64, converted with ``ToTensor``, rescaled with ``x * 2 - 1``
    and given a leading batch dimension.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    cropper = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    tensor = transforms.ToTensor()(cropper(rgb))
    rescaled = tensor * 2 - 1
    return rescaled.unsqueeze(0).to(device)


def preprocess_map(map):
    """Turn a PIL change map into a single-channel tensor on ``device``.

    The map is converted to mode ``"L"``, centre-cropped so both sides are
    multiples of 64 and converted with ``ToTensor`` (no batch dimension is
    added).
    """
    gray = map.convert("L")
    width, height = gray.size
    cropper = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    return transforms.ToTensor()(cropper(gray)).to(device)


def inference(image, map, gs, prompt, negative_prompt):
    """Run the base pipeline and then the refiner on ``image`` guided by ``map``.

    Args:
        image: Input PIL image.
        map: PIL change map.
        gs: Guidance scale for the base pass (the refiner pass uses 7.5).
        prompt: Text prompt.
        negative_prompt: Negative text prompt.

    Returns:
        The first image produced by the refiner pipeline.

    Raises:
        gr.Error: Via ``validate_inputs`` when ``image`` or ``map`` is None.
    """
    validate_inputs(image, map)
    image_tensor = preprocess_image(image)
    change_map = preprocess_map(map)

    # Arguments that are identical for the base and the refiner call.
    shared = dict(
        prompt=prompt,
        original_image=image_tensor,
        strength=1,
        num_images_per_prompt=1,
        negative_prompt=negative_prompt,
        map=change_map,
        num_inference_steps=NUM_INFERENCE_STEPS,
    )

    # Base pass: stops at 80% of the schedule and hands over latents.
    base_pipe = base.to(device)
    latents = base_pipe(
        image=image_tensor,
        guidance_scale=gs,
        denoising_end=0.8,
        output_type="latent",
        **shared,
    ).images

    # Refiner pass: resumes from the same 80% point.
    refiner_pipe = refiner.to(device)
    refined = refiner_pipe(
        image=latents,
        guidance_scale=7.5,
        denoising_start=0.8,
        **shared,
    ).images
    return refined[0]


def validate_inputs(image, map):
    """Raise ``gr.Error`` if the image or the map is ``None``.

    The image is checked first, so a missing image is reported before a
    missing map.
    """
    required = ((image, "Missing image"), (map, "Missing map"))
    for value, message in required:
        if value is None:
            raise gr.Error(message)


def run(image, gs, prompt, neg_prompt):
    """Derive a depth map from ``image`` and use it as the change map.

    Args:
        image: Input image as supplied by the UI.
        gs: Guidance scale forwarded to ``inference``.
        prompt: Text prompt.
        neg_prompt: Negative text prompt.

    Returns:
        The edited image returned by ``inference``.

    Raises:
        gr.Error: If ``image`` is None.
    """
    # Fail with a user-facing message before depth estimation, which would
    # otherwise crash on a missing image with an AttributeError.
    if image is None:
        raise gr.Error("Missing image")

    (original_image, _colored_depth), name, raw_depth = depthify(image)
    # Log only the temp-file path instead of dumping whole image objects.
    print(f"depth map saved to {name}")
    return inference(original_image, raw_depth, gs, prompt, neg_prompt)

def _run_from_ui(image, change_map, gs, prompt, neg_prompt):
    """Click handler whose signature matches the five UI inputs.

    ``run`` takes four arguments, so wiring it directly to five inputs fails
    on every click. When a change map was supplied it is used as-is via
    ``inference``; otherwise ``run`` derives the map from the image's depth.
    """
    if image is None:
        raise gr.Error("Missing image")
    if change_map is None:
        return run(image, gs, prompt, neg_prompt)
    return inference(image, change_map, gs, prompt, neg_prompt)


# UI layout: inputs and buttons in the left column, the result on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label="Input Image", type="pil")
                change_map = gr.Image(label="Change Map", type="pil")
            gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                clr_btn = gr.ClearButton(components=[input_image, change_map, gs, prompt, neg_prompt])
                run_btn = gr.Button("Run", variant="primary")

        output = gr.Image(label="Output Image")
    run_btn.click(
        _run_from_ui,
        inputs=[input_image, change_map, gs, prompt, neg_prompt],
        outputs=output,
    )
    # The output is created after the clear button, so it is registered here.
    clr_btn.add(output)

if __name__ == "__main__":
    demo.launch()