TDN-M commited on
Commit
142ce26
·
verified ·
1 Parent(s): 935ed8d

Update app_test.py

Browse files
Files changed (1) hide show
  1. app_test.py +178 -10
app_test.py CHANGED
@@ -1,4 +1,4 @@
1
- import gradio as gr
2
  import huggingface_hub
3
 
4
  huggingface_hub.snapshot_download(
@@ -11,18 +11,186 @@ huggingface_hub.snapshot_download(
11
  local_dir_use_symlinks=False,
12
  )
13
 
 
 
 
 
 
 
 
 
 
 
14
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # specify the directory path
17
- dir_path = 'sdxl_models'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # list all files in the directory
20
- files = os.listdir(dir_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # print the list of files
23
- print(files)
 
 
 
 
24
 
25
- def infer (text):
26
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- gr.Interface(fn=infer, inputs=[gr.Textbox()], outputs=[gr.Textbox()]).launch()
 
1
+ import spaces
2
  import huggingface_hub
3
 
4
  huggingface_hub.snapshot_download(
 
11
  local_dir_use_symlinks=False,
12
  )
13
 
14
+ import gradio as gr
15
+ from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
16
+ from rembg import remove
17
+ from PIL import Image
18
+ import torch
19
+ from ip_adapter import IPAdapterXL
20
+ from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
21
+ from PIL import Image, ImageChops, ImageEnhance
22
+ import numpy as np
23
+
24
  import os
25
+ import glob
26
+ import torch
27
+ import cv2
28
+ import argparse
29
+
30
+ import DPT.util.io
31
+
32
+ from torchvision.transforms import Compose
33
+
34
+ from DPT.dpt.models import DPTDepthModel
35
+ from DPT.dpt.midas_net import MidasNet_large
36
+ from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet
37
+
38
+ """
39
+ Get ZeST Ready
40
+ """
41
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
42
+ image_encoder_path = "models/image_encoder"
43
+ ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
44
+ controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
45
+ device = "cuda"
46
+ torch.cuda.empty_cache()
47
+
48
+ # load SDXL pipeline
49
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device)
50
+ pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
51
+ base_model_path,
52
+ controlnet=controlnet,
53
+ use_safetensors=True,
54
+ torch_dtype=torch.float16,
55
+ add_watermarker=False,
56
+ ).to(device)
57
+ pipe.unet = register_cross_attention_hook(pipe.unet)
58
+
59
+ ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)
60
+
61
+
62
+ """
63
+ Get Depth Model Ready
64
+ """
65
+ model_path = "DPT/weights/dpt_hybrid-midas-501f0c75.pt"
66
+ net_w = net_h = 384
67
+ model = DPTDepthModel(
68
+ path=model_path,
69
+ backbone="vitb_rn50_384",
70
+ non_negative=True,
71
+ enable_attention_hooks=False,
72
+ )
73
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
74
+
75
+ transform = Compose(
76
+ [
77
+ Resize(
78
+ net_w,
79
+ net_h,
80
+ resize_target=None,
81
+ keep_aspect_ratio=True,
82
+ ensure_multiple_of=32,
83
+ resize_method="minimal",
84
+ image_interpolation_method=cv2.INTER_CUBIC,
85
+ ),
86
+ normalization,
87
+ PrepareForNet(),
88
+ ]
89
+ )
90
+
91
+ model.eval()
92
+
93
+ @spaces.GPU()
94
+ def greet(input_image, material_exemplar):
95
+
96
+ """
97
+ Compute depth map from input_image
98
+ """
99
+
100
+ img = np.array(input_image)
101
+
102
+ img_input = transform({"image": img})["image"]
103
+
104
+ # compute
105
+ with torch.no_grad():
106
+ sample = torch.from_numpy(img_input).unsqueeze(0)
107
+
108
+ # if optimize == True and device == torch.device("cuda"):
109
+ # sample = sample.to(memory_format=torch.channels_last)
110
+ # sample = sample.half()
111
 
112
+ prediction = model.forward(sample)
113
+ prediction = (
114
+ torch.nn.functional.interpolate(
115
+ prediction.unsqueeze(1),
116
+ size=img.shape[:2],
117
+ mode="bicubic",
118
+ align_corners=False,
119
+ )
120
+ .squeeze()
121
+ .cpu()
122
+ .numpy()
123
+ )
124
+
125
+ depth_min = prediction.min()
126
+ depth_max = prediction.max()
127
+ bits = 2
128
+ max_val = (2 ** (8 * bits)) - 1
129
 
130
+ if depth_max - depth_min > np.finfo("float").eps:
131
+ out = max_val * (prediction - depth_min) / (depth_max - depth_min)
132
+ else:
133
+ out = np.zeros(prediction.shape, dtype=depth.dtype)
134
+
135
+ out = (out / 256).astype('uint8')
136
+ depth_map = Image.fromarray(out).resize((1024, 1024))
137
+
138
+
139
+ """
140
+ Process foreground decolored image
141
+ """
142
+ rm_bg = remove(input_image)
143
+ target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')
144
+ mask_target_img = ImageChops.lighter(input_image, target_mask)
145
+ invert_target_mask = ImageChops.invert(target_mask)
146
+ gray_target_image = input_image.convert('L').convert('RGB')
147
+ gray_target_image = ImageEnhance.Brightness(gray_target_image)
148
+ factor = 1.0 # Try adjusting this to get the desired brightness
149
+ gray_target_image = gray_target_image.enhance(factor)
150
+ grayscale_img = ImageChops.darker(gray_target_image, target_mask)
151
+ img_black_mask = ImageChops.darker(input_image, invert_target_mask)
152
+ grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)
153
+ init_img = grayscale_init_img
154
+
155
+ """
156
+ Process material exemplar and resize all images
157
+ """
158
+ ip_image = material_exemplar.resize((1024, 1024))
159
+ init_img = init_img.resize((1024,1024))
160
+ mask = target_mask.resize((1024, 1024))
161
+
162
+
163
+ num_samples = 1
164
+ images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42)
165
+
166
+ return images[0]
167
 
168
+ css = """
169
+ #col-container{
170
+ margin: 0 auto;
171
+ max-width: 960px;
172
+ }
173
+ """
174
 
175
+ with gr.Blocks(css=css) as demo:
176
+ with gr.Column(elem_id="col-container"):
177
+ gr.Markdown("""
178
+ # ZeST: Zero-Shot Material Transfer from a Single Image
179
+ <p>Upload two images -- input image and material exemplar. (both 1024*1024 for better results) <br />
180
+ ZeST extracts the material from the exemplar and cast it onto the input image following the original lighting cues.</p>
181
+ """)
182
+ with gr.Row():
183
+ with gr.Column():
184
+ with gr.Row():
185
+ input_image = gr.Image(type="pil", label="input image")
186
+ input_image2 = gr.Image(type="pil", label = "material examplar")
187
+ submit_btn = gr.Button("Submit")
188
+ gr.Examples(
189
+ examples = [["demo_assets/input_imgs/pumpkin.png", "demo_assets/material_exemplars/cup_glaze.png"]],
190
+ inputs = [input_image, input_image2]
191
+ )
192
+ with gr.Column():
193
+ output_image = gr.Image(label="transfer result")
194
+ submit_btn.click(fn=greet, inputs=[input_image, input_image2], outputs=[output_image])
195
 
196
+ demo.queue().launch()