CyberWaifu committed (verified)
Commit 19b8151
1 Parent(s): 5bc4dea

Try to add the image preprocessing adapted from the original code.

Files changed (1)
  1. app.py +55 -20
app.py CHANGED

@@ -4,12 +4,12 @@ import numpy as np
 from PIL import Image
 import json
 from huggingface_hub import hf_hub_download
+import torchvision.transforms as transforms
 
 # Constants
 MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
 MODEL_FILE = "camie_tagger_initial.onnx"
 META_FILE = "metadata.json"
-IMAGE_SIZE = (512, 512)
 DEFAULT_THRESHOLD = 0.35 # Default value if slider is not used
 
 # Download model and metadata from Hugging Face Hub
@@ -26,16 +26,51 @@ def escape_tag(tag: str) -> str:
     return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
 
 def preprocess_image(pil_image: Image.Image) -> np.ndarray:
-    """Convert image to RGB, resize, normalize, and rearrange dimensions."""
-    img = pil_image.convert("RGB").resize(IMAGE_SIZE)
-    arr = np.array(img).astype(np.float32) / 255.0
-    arr = np.transpose(arr, (2, 0, 1))
-    return np.expand_dims(arr, 0)
+    """Process an image for inference using same preprocessing as training"""
+
+    image_size = 512
+
+    # Initialize the same transform used during training
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+    ])
+
+    img = pil_image # Use the PIL image directly
+
+    # Convert RGBA or Palette images to RGB
+    if img.mode in ('RGBA', 'P'):
+        img = img.convert('RGB')
+
+    # Get original dimensions
+    width, height = img.size
+    aspect_ratio = width / height
+
+    # Calculate new dimensions to maintain aspect ratio
+    if aspect_ratio > 1:
+        new_width = image_size
+        new_height = int(new_width / aspect_ratio)
+    else:
+        new_height = image_size
+        new_width = int(new_height * aspect_ratio)
+
+    # Resize with LANCZOS filter
+    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+    # Create new image with padding
+    new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
+    paste_x = (image_size - new_width) // 2
+    paste_y = (image_size - new_height) // 2
+    new_image.paste(img, (paste_x, paste_y))
+
+    # Apply transforms (without normalization)
+    img_tensor = transform(new_image)
+    return img_tensor.numpy() # Convert the PyTorch tensor to NumPy array
+
 
 def run_inference(pil_image: Image.Image) -> np.ndarray:
     """
     Preprocess the image and run the ONNX model inference.
-
+
     Returns the refined logits as a numpy array.
     """
     input_tensor = preprocess_image(pil_image)
@@ -47,7 +82,7 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
 def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
     """
     Compute probabilities from logits and collect tag predictions.
-
+
     Returns:
         results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
         prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
@@ -79,7 +114,7 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
     """
     Format the tags for prompt-style output.
     Only the top artist tag is shown (regardless of threshold), and all character and general tags are shown.
-
+
     Returns a comma-separated string of escaped tags.
     """
     # Always select the best artist tag from all_artist_tags, regardless of threshold.
@@ -87,26 +122,26 @@ def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
     if all_artist_tags:
         best_artist = max(all_artist_tags, key=lambda item: item[1])
         best_artist_tag = escape_tag(best_artist[0])
-
+
     # Sort character and general tags by probability (descending)
     for cat in prompt_tags_by_cat:
         prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
-
+
     character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
     general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
-
+
     prompt_tags = []
     if best_artist_tag:
         prompt_tags.append(best_artist_tag)
     prompt_tags.extend(character_tags)
     prompt_tags.extend(general_tags)
-
+
     return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
 
 def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
     """
     Format the tags for detailed output.
-
+
     Returns a Markdown-formatted string listing tags by category.
     """
     if not results_by_cat:
@@ -116,7 +151,7 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
     if "artist" not in results_by_cat and all_artist_tags:
         best_artist_tag, best_artist_prob = max(all_artist_tags, key=lambda item: item[1])
         results_by_cat["artist"] = [(best_artist_tag, best_artist_prob)]
-
+
     lines = ["**Predicted Tags by Category:** \n"]
     for cat, tag_list in results_by_cat.items():
         tag_list.sort(key=lambda x: x[1], reverse=True)
@@ -129,15 +164,15 @@ def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
 def tag_image(pil_image: Image.Image, output_format: str, threshold: float) -> str:
     """
     Run inference on the image and return formatted tags based on the chosen output format.
-
+
     The slider value (threshold) overrides the default threshold for tag selection.
     """
     if pil_image is None:
         return "Please upload an image."
-
+
     refined_logits = run_inference(pil_image)
     results_by_cat, prompt_tags_by_cat, all_artist_tags = get_tags(refined_logits, metadata, default_threshold=threshold)
-
+
     if output_format == "Prompt-style Tags":
         return format_prompt_tags(prompt_tags_by_cat, all_artist_tags)
     else:
@@ -177,7 +212,7 @@ with demo:
 
     # Pass the threshold_slider value into the tag_image function
     tag_button.click(fn=tag_image, inputs=[image_in, format_choice, threshold_slider], outputs=output_box)
-
+
    gr.Markdown(
         "----\n"
         "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
@@ -187,4 +222,4 @@ with demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
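
For context, torchvision is pulled in here only for transforms.ToTensor(), which scales an RGB image to float32 values in [0, 1] and reorders it to CHW. A minimal sketch of the same preprocessing using only PIL and NumPy could look like the following; preprocess_image_np is a hypothetical helper name (not part of app.py), and it assumes the ONNX model expects exactly that float32 CHW layout:

# Hypothetical PIL + NumPy equivalent of the new preprocess_image (no torchvision needed).
# Assumes the model expects float32, CHW layout, values in [0, 1] -- which is what
# transforms.ToTensor() produces for an RGB PIL image.
import numpy as np
from PIL import Image

def preprocess_image_np(pil_image: Image.Image, image_size: int = 512) -> np.ndarray:
    img = pil_image.convert("RGB")  # also covers RGBA / palette inputs

    # Resize while preserving aspect ratio, longest side -> image_size
    width, height = img.size
    aspect_ratio = width / height
    if aspect_ratio > 1:
        new_width, new_height = image_size, int(image_size / aspect_ratio)
    else:
        new_width, new_height = int(image_size * aspect_ratio), image_size
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Paste onto a black square canvas, centred
    canvas = Image.new("RGB", (image_size, image_size), (0, 0, 0))
    canvas.paste(img, ((image_size - new_width) // 2, (image_size - new_height) // 2))

    # HWC uint8 -> CHW float32 in [0, 1], mirroring ToTensor
    arr = np.asarray(canvas, dtype=np.float32) / 255.0
    return np.transpose(arr, (2, 0, 1))

Note that the previous preprocess_image added a batch dimension with np.expand_dims(arr, 0); if the ONNX input still expects a leading batch axis, the array returned above (or by the new preprocess_image) would need the same expansion before being passed to the session.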