Update image_transformation.py

#2
Files changed (1)
  1. image_transformation.py +10 -25
image_transformation.py CHANGED
@@ -17,9 +17,6 @@ if is_vision_available():
17
  if is_diffusers_available():
18
  from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
19
 
20
- if is_opencv_available():
21
- import cv2
22
-
23
 
24
  IMAGE_TRANSFORMATION_DESCRIPTION = (
25
  "This is a tool that transforms an image according to a prompt. It takes two inputs: `image`, which should be "
@@ -30,7 +27,7 @@ IMAGE_TRANSFORMATION_DESCRIPTION = (
30
 
31
  class ImageTransformationTool(Tool):
32
  default_stable_diffusion_checkpoint = "runwayml/stable-diffusion-v1-5"
33
- default_controlnet_checkpoint = "lllyasviel/sd-controlnet-canny"
34
  description = IMAGE_TRANSFORMATION_DESCRIPTION
35
  inputs = ['image', 'text']
36
  outputs = ['image']
@@ -67,32 +64,20 @@ class ImageTransformationTool(Tool):
67
  self.stable_diffusion_checkpoint, controlnet=self.controlnet
68
  )
69
  self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
70
- self.pipeline.enable_model_cpu_offload()
 
 
 
71
 
72
  self.is_initialized = True
73
 
74
- def __call__(self, image, prompt):
75
  if not self.is_initialized:
76
  self.setup()
77
 
78
- initial_prompt = "super-hero character, best quality, extremely detailed"
79
- prompt = initial_prompt + prompt
80
-
81
- low_threshold = 100
82
- high_threshold = 200
83
-
84
- image = np.array(image)
85
- image = cv2.Canny(image, low_threshold, high_threshold)
86
- image = image[:, :, None]
87
- image = np.concatenate([image, image, image], axis=2)
88
- canny_image = Image.fromarray(image)
89
-
90
- generator = torch.Generator(device="cpu").manual_seed(2)
91
-
92
  return self.pipeline(
93
- prompt,
94
- canny_image,
95
- negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
96
- num_inference_steps=20,
97
- generator=generator,
98
  ).images[0]
 
17
  if is_diffusers_available():
18
  from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
19
 
 
 
 
20
 
21
  IMAGE_TRANSFORMATION_DESCRIPTION = (
22
  "This is a tool that transforms an image according to a prompt. It takes two inputs: `image`, which should be "
 
27
 
28
  class ImageTransformationTool(Tool):
29
  default_stable_diffusion_checkpoint = "runwayml/stable-diffusion-v1-5"
30
+ default_controlnet_checkpoint = "lllyasviel/control_v11e_sd15_ip2p"
31
  description = IMAGE_TRANSFORMATION_DESCRIPTION
32
  inputs = ['image', 'text']
33
  outputs = ['image']
 
64
  self.stable_diffusion_checkpoint, controlnet=self.controlnet
65
  )
66
  self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
67
+
68
+ self.pipeline.to(self.device)
69
+ if self.device.type == "cuda":
70
+ self.pipeline.to(torch_dtype=torch.float16)
71
 
72
  self.is_initialized = True
73
 
74
+ def __call__(self, image, prompt, added_prompt=", high quality, high resolution, beautiful, aesthetic, sharp"):
75
  if not self.is_initialized:
76
  self.setup()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return self.pipeline(
79
+ prompt + added_prompt,
80
+ image,
81
+ negative_prompt="monochrome, lowres, worst quality, low quality",
82
+ num_inference_steps=25,
 
83
  ).images[0]