Sagar Bharadwaj commited on
Commit
b4734e1
·
1 Parent(s): e3a0f13

Made simplify image faster

Browse files
Files changed (1) hide show
  1. colorbynumber/simplify_image.py +12 -22
colorbynumber/simplify_image.py CHANGED
@@ -1,18 +1,5 @@
1
  import numpy as np
2
 
3
- def _closest_color(pixel, color_list):
4
- """
5
- Finds the closest color in the list to the given pixel (RGB values)
6
-
7
- Args:
8
- pixel: A tuple representing the RGB values of a pixel (R, G, B).
9
- color_list: A list of tuples representing RGB values of allowed colors.
10
-
11
- Returns:
12
- A tuple representing the RGB values of the closest color in the list.
13
- """
14
- distances = np.array([np.linalg.norm(np.array(pixel) - np.array(color)) for color in color_list])
15
- return color_list[np.argmin(distances)]
16
 
17
  def simplify_image(image, color_list):
18
  """
@@ -26,12 +13,15 @@ def simplify_image(image, color_list):
26
  A copy of the image with all colors replaced with the closest color in the list.
27
  """
28
 
29
- # Replace each pixel with closest color from the list
30
- converted_image = image.copy() # Operate on a copy
31
- height, width, channels = image.shape
32
- for y in range(height):
33
- for x in range(width):
34
- pixel = image[y, x]
35
- closest = _closest_color(pixel, color_list)
36
- converted_image[y, x] = closest
37
- return converted_image
 
 
 
 
1
  import numpy as np
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def simplify_image(image, color_list):
5
  """
 
13
  A copy of the image with all colors replaced with the closest color in the list.
14
  """
15
 
16
+ width, height, channels = image.shape
17
+ image_copy = image.reshape((width, height, 1, channels)).copy()
18
+
19
+ color_list = np.array(color_list)
20
+ num_colors = color_list.shape[0]
21
+ color_list_broadcastable = color_list.reshape((1, 1, num_colors, 3))
22
+
23
+ norm_diff = ((image_copy - color_list_broadcastable)**2).sum(axis = -1)
24
+ indices_color_choices = norm_diff.argmin(axis = -1)
25
+ simplified_image = color_list[indices_color_choices.flatten(), :].reshape(image.shape)
26
+
27
+ return simplified_image