cocktailpeanut commited on
Commit
d473aed
·
1 Parent(s): 89a1445
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -13,18 +13,18 @@ import torch.nn.functional as F
13
  from torchvision.transforms import Compose
14
  import tempfile
15
  from gradio_imageslider import ImageSlider
16
- from .depth_anything.depth_anything.dpt import DepthAnything
17
- from .depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
18
 
19
  NUM_INFERENCE_STEPS = 50
20
  dtype = torch.float16
21
  if torch.cuda.is_available():
22
- device = "cuda"
23
  elif torch.backends.mps.is_available():
24
- device = "mps"
25
  dtype = torch.float32
26
  else:
27
- device = "cpu"
28
  #device = "cuda"
29
 
30
  encoder = 'vitl' # can also be 'vits' or 'vitb'
@@ -92,7 +92,7 @@ def preprocess_image(image):
92
  image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
93
  image = transforms.ToTensor()(image)
94
  image = image * 2 - 1
95
- image = image.unsqueeze(0).to(device)
96
  return image
97
 
98
 
@@ -101,7 +101,7 @@ def preprocess_map(map):
101
  map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
102
  # convert to tensor
103
  map = transforms.ToTensor()(map)
104
- map = map.to(device)
105
  return map
106
 
107
 
@@ -109,14 +109,14 @@ def inference(image, map, gs, prompt, negative_prompt):
109
  validate_inputs(image, map)
110
  image = preprocess_image(image)
111
  map = preprocess_map(map)
112
- base_cuda = base.to(device)
113
  edited_images = base_cuda(prompt=prompt, original_image=image, image=image, strength=1, guidance_scale=gs,
114
  num_images_per_prompt=1,
115
  negative_prompt=negative_prompt,
116
  map=map,
117
  num_inference_steps=NUM_INFERENCE_STEPS, denoising_end=0.8, output_type="latent").images
118
  base_cuda=None
119
- refiner_cuda = refiner.to(device)
120
  edited_images = refiner_cuda(prompt=prompt, original_image=image, image=edited_images, strength=1, guidance_scale=7.5,
121
  num_images_per_prompt=1,
122
  negative_prompt=negative_prompt,
@@ -144,20 +144,21 @@ with gr.Blocks() as demo:
144
  with gr.Column():
145
  with gr.Row():
146
  input_image = gr.Image(label="Input Image", type="pil")
147
- change_map = gr.Image(label="Change Map", type="pil")
148
  gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
149
  prompt = gr.Textbox(label="Prompt")
150
  neg_prompt = gr.Textbox(label="Negative Prompt")
151
  with gr.Row():
152
- clr_btn=gr.ClearButton(components=[input_image, change_map, gs, prompt, neg_prompt])
 
153
  run_btn = gr.Button("Run",variant="primary")
154
 
155
  output = gr.Image(label="Output Image")
156
  run_btn.click(
157
  run,
158
  #inference,
159
- inputs=[input_image, change_map, gs, prompt, neg_prompt],
160
- outputs=output
161
  )
162
  clr_btn.add(output)
163
  if __name__ == "__main__":
 
13
  from torchvision.transforms import Compose
14
  import tempfile
15
  from gradio_imageslider import ImageSlider
16
+ from depth_anything.depth_anything.dpt import DepthAnything
17
+ from depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
18
 
19
  NUM_INFERENCE_STEPS = 50
20
  dtype = torch.float16
21
  if torch.cuda.is_available():
22
+ DEVICE = "cuda"
23
  elif torch.backends.mps.is_available():
24
+ DEVICE = "mps"
25
  dtype = torch.float32
26
  else:
27
+ DEVICE = "cpu"
28
  #device = "cuda"
29
 
30
  encoder = 'vitl' # can also be 'vits' or 'vitb'
 
92
  image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
93
  image = transforms.ToTensor()(image)
94
  image = image * 2 - 1
95
+ image = image.unsqueeze(0).to(DEVICE)
96
  return image
97
 
98
 
 
101
  map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
102
  # convert to tensor
103
  map = transforms.ToTensor()(map)
104
+ map = map.to(DEVICE)
105
  return map
106
 
107
 
 
109
  validate_inputs(image, map)
110
  image = preprocess_image(image)
111
  map = preprocess_map(map)
112
+ base_cuda = base.to(DEVICE)
113
  edited_images = base_cuda(prompt=prompt, original_image=image, image=image, strength=1, guidance_scale=gs,
114
  num_images_per_prompt=1,
115
  negative_prompt=negative_prompt,
116
  map=map,
117
  num_inference_steps=NUM_INFERENCE_STEPS, denoising_end=0.8, output_type="latent").images
118
  base_cuda=None
119
+ refiner_cuda = refiner.to(DEVICE)
120
  edited_images = refiner_cuda(prompt=prompt, original_image=image, image=edited_images, strength=1, guidance_scale=7.5,
121
  num_images_per_prompt=1,
122
  negative_prompt=negative_prompt,
 
144
  with gr.Column():
145
  with gr.Row():
146
  input_image = gr.Image(label="Input Image", type="pil")
147
+ # change_map = gr.Image(label="Change Map", type="pil")
148
  gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
149
  prompt = gr.Textbox(label="Prompt")
150
  neg_prompt = gr.Textbox(label="Negative Prompt")
151
  with gr.Row():
152
+ # clr_btn=gr.ClearButton(components=[input_image, change_map, gs, prompt, neg_prompt])
153
+ clr_btn=gr.ClearButton(components=[input_image, gs, prompt, neg_prompt])
154
  run_btn = gr.Button("Run",variant="primary")
155
 
156
  output = gr.Image(label="Output Image")
157
  run_btn.click(
158
  run,
159
  #inference,
160
+ inputs=[input_image, gs, prompt, neg_prompt],
161
+ outputs=[change_map, output]
162
  )
163
  clr_btn.add(output)
164
  if __name__ == "__main__":