cella110n committed
Commit 97cfcda · verified · 1 parent: 3a4781c

Upload app.py

Files changed (1)
  app.py (+549 −549)
app.py CHANGED
@@ -1,549 +1,549 @@
import gradio as gr
import numpy as np
from PIL import Image  # Keep PIL for now; it may be needed implicitly by the helpers
# from PIL import Image, ImageDraw, ImageFont  # No drawing yet
import json
import os
import io
import requests
import matplotlib.pyplot as plt  # For visualization
import matplotlib  # For backend setting
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
import time
import spaces  # Required for @spaces.GPU
import onnxruntime as ort  # Use ONNX Runtime

import torch  # Keep torch for the device check in Tagger
import timm  # Restore timm
from safetensors.torch import load_file as safe_load_file  # Restore safetensors loading

# Set the Matplotlib backend to Agg (keep commented out for now)
# matplotlib.use('Agg')

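# Dependency note (an assumption; nothing is pinned in this file): the imports above
# expect roughly `pip install gradio numpy pillow requests matplotlib huggingface_hub
# onnxruntime-gpu torch timm safetensors` plus the `spaces` package provided by HF Spaces.
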
# --- Data Classes and Helper Functions ---
@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    artist: list[np.int64]
    character: list[np.int64]
    copyright: list[np.int64]
    meta: list[np.int64]
    quality: list[np.int64]

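# NOTE: load_tag_mapping() later in this file actually fills the index fields with
# np.ndarray values (np.int64 indices), not plain lists, e.g. (illustrative values only):
#   LabelData(names=["general", "1girl"], rating=np.array([0]), general=np.array([1]), ...)
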
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    if image.mode == "RGBA":
        # Composite onto a white background, using the alpha channel as the paste mask
        background = Image.new("RGB", image.size, (255, 255, 255))
        background.paste(image, mask=image.split()[3])
        image = background
    return image

def pil_pad_square(image: Image.Image) -> Image.Image:
    width, height = image.size
    if width == height:
        return image
    new_size = max(width, height)
    new_image = Image.new(image.mode, (new_size, new_size), (255, 255, 255))  # Use image.mode
    paste_position = ((new_size - width) // 2, (new_size - height) // 2)
    new_image.paste(image, paste_position)
    return new_image

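# Example: pil_pad_square() on a 300x200 image returns a 300x300 canvas with the
# original pasted at (0, 50), i.e. centered vertically on a white background.
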
def load_tag_mapping(mapping_path):
    # Use the implementation from the original app.py, as it was confirmed working
    with open(mapping_path, 'r', encoding='utf-8') as f:
        tag_mapping_data = json.load(f)
    # Check format compatibility (can be a dict of dicts, or a dict with idx_to_tag/tag_to_category)
    if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
        idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
        tag_to_category = tag_mapping_data["tag_to_category"]
    elif isinstance(tag_mapping_data, dict):
        # Assume the dict-of-dicts format from previous tests
        try:
            tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
            idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
            tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
        except (KeyError, ValueError) as e:
            raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
    else:
        raise ValueError("Unsupported tag mapping format: Expected a dictionary.")

    names = [None] * (max(idx_to_tag.keys()) + 1)
    rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
    for idx, tag in idx_to_tag.items():
        if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
        names[idx] = tag
        category = tag_to_category.get(tag, 'Unknown')  # Handle a missing category mapping gracefully
        idx_int = int(idx)
        if category == 'Rating': rating.append(idx_int)
        elif category == 'General': general.append(idx_int)
        elif category == 'Artist': artist.append(idx_int)
        elif category == 'Character': character.append(idx_int)
        elif category == 'Copyright': copyright.append(idx_int)
        elif category == 'Meta': meta.append(idx_int)
        elif category == 'Quality': quality.append(idx_int)

    return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64),
                     artist=np.array(artist, dtype=np.int64), character=np.array(character, dtype=np.int64),
                     copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64),
                     quality=np.array(quality, dtype=np.int64)), idx_to_tag, tag_to_category

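# Sketch of the two tag_mapping.json layouts accepted above (keys and tags are
# illustrative, not taken from the actual repo file):
#   {"idx_to_tag": {"0": "general", "1": "1girl"}, "tag_to_category": {"general": "Rating", "1girl": "General"}}
#   {"0": {"tag": "general", "category": "Rating"}, "1": {"tag": "1girl", "category": "General"}}
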
def preprocess_image(image: Image.Image, target_size=(448, 448)):
    # Adapted from onnx_predict.py's version
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)
    image_resized = image.resize(target_size, Image.BICUBIC)
    img_array = np.array(image_resized, dtype=np.float32) / 255.0
    img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
    img_array = img_array[::-1, :, :]  # RGB -> BGR channel flip (enabled based on user feedback)
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
    img_array = (img_array - mean) / std
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return image, img_array

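# Usage sketch (hypothetical local file; the shapes follow from the code above):
#   padded, tensor = preprocess_image(Image.open("sample.jpg"))
#   assert tensor.shape == (1, 3, 448, 448) and tensor.dtype == np.float32
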
# Add get_tags function (from onnx_predict.py)
def get_tags(probs, labels: LabelData, gen_threshold, char_threshold):
    result = {
        "rating": [],
        "general": [],
        "character": [],
        "copyright": [],
        "artist": [],
        "meta": [],
        "quality": []
    }
    # Rating (select max)
    if len(labels.rating) > 0:
        # Ensure indices are within bounds
        valid_indices = labels.rating[labels.rating < len(probs)]
        if len(valid_indices) > 0:
            rating_probs = probs[valid_indices]
            if len(rating_probs) > 0:
                rating_idx_local = np.argmax(rating_probs)
                rating_idx_global = valid_indices[rating_idx_local]
                # Check that the global index is valid for the names list
                if rating_idx_global < len(labels.names) and labels.names[rating_idx_global] is not None:
                    rating_name = labels.names[rating_idx_global]
                    rating_conf = float(rating_probs[rating_idx_local])
                    result["rating"].append((rating_name, rating_conf))
                else:
                    print(f"Warning: Invalid global index {rating_idx_global} for rating tag.")
            else:
                print("Warning: rating_probs became empty after filtering.")
        else:
            print("Warning: No valid indices found for rating tags within probs length.")

    # Quality (select max)
    if len(labels.quality) > 0:
        valid_indices = labels.quality[labels.quality < len(probs)]
        if len(valid_indices) > 0:
            quality_probs = probs[valid_indices]
            if len(quality_probs) > 0:
                quality_idx_local = np.argmax(quality_probs)
                quality_idx_global = valid_indices[quality_idx_local]
                if quality_idx_global < len(labels.names) and labels.names[quality_idx_global] is not None:
                    quality_name = labels.names[quality_idx_global]
                    quality_conf = float(quality_probs[quality_idx_local])
                    result["quality"].append((quality_name, quality_conf))
                else:
                    print(f"Warning: Invalid global index {quality_idx_global} for quality tag.")
            else:
                print("Warning: quality_probs became empty after filtering.")
        else:
            print("Warning: No valid indices found for quality tags within probs length.")

    # Threshold-based categories
    category_map = {
        "general": (labels.general, gen_threshold),
        "character": (labels.character, char_threshold),
        "copyright": (labels.copyright, char_threshold),
        "artist": (labels.artist, char_threshold),
        "meta": (labels.meta, gen_threshold)  # Use gen_threshold for meta, as per the original code
    }
    for category, (indices, threshold) in category_map.items():
        if len(indices) > 0:
            valid_indices = indices[indices < len(probs)]  # Check index bounds first
            if len(valid_indices) > 0:
                category_probs = probs[valid_indices]
                mask = category_probs >= threshold
                selected_indices_local = np.where(mask)[0]
                if len(selected_indices_local) > 0:
                    selected_indices_global = valid_indices[selected_indices_local]
                    selected_probs = category_probs[selected_indices_local]
                    for idx_global, prob_val in zip(selected_indices_global, selected_probs):
                        # Check that the global index is valid for the names list
                        if idx_global < len(labels.names) and labels.names[idx_global] is not None:
                            result[category].append((labels.names[idx_global], float(prob_val)))
                        else:
                            print(f"Warning: Invalid global index {idx_global} for {category} tag.")
                # else: print(f"No tags found for category '{category}' above threshold {threshold}")
            # else: print(f"No valid indices found for category '{category}' within probs length.")
        # else: print(f"No indices defined for category '{category}'")

    # Sort each category by probability (descending)
    for k in result:
        result[k] = sorted(result[k], key=lambda x: x[1], reverse=True)
    return result

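# Shape of the dict returned by get_tags (illustrative tags/probabilities only):
#   {"rating": [("general", 0.97)], "quality": [("best quality", 0.81)],
#    "general": [("1girl", 0.93), ("solo", 0.88)], "character": [], "copyright": [],
#    "artist": [], "meta": []}
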
# Add visualize_predictions function (adapted from onnx_predict.py and previous versions)
def visualize_predictions(image: Image.Image, predictions: Dict, threshold: float):
    # Filter out unwanted meta tags (e.g., id, commentary, request, mismatch).
    # Note: this mutates the caller's predictions dict in place.
    filtered_meta = []
    excluded_meta_patterns = ['id', 'commentary', 'request', 'mismatch']
    for tag, prob in predictions.get("meta", []):
        if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
            filtered_meta.append((tag, prob))
    predictions["meta"] = filtered_meta  # Use the filtered list for visualization

    # --- Plotting Setup ---
    plt.rcParams['font.family'] = 'DejaVu Sans'
    fig = plt.figure(figsize=(8, 12), dpi=100)
    ax_tags = fig.add_subplot(1, 1, 1)

    all_tags, all_probs, all_colors = [], [], []
    color_map = {
        'rating': 'red', 'character': 'blue', 'copyright': 'purple',
        'artist': 'orange', 'general': 'green', 'meta': 'gray', 'quality': 'yellow'
    }

    # Aggregate tags from the predictions dictionary
    for cat, prefix, color in [
        ('rating', 'R', color_map['rating']), ('quality', 'Q', color_map['quality']),
        ('character', 'C', color_map['character']), ('copyright', '©', color_map['copyright']),
        ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
        ('meta', 'M', color_map['meta'])
    ]:
        sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
        for tag, prob in sorted_tags:
            all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}")
            all_probs.append(prob)
            all_colors.append(color)

    if not all_tags:
        ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
        ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")
        ax_tags.axis('off')
    else:
        # Sort ascending so the highest-probability bars end up at the top of the chart
        sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i])
        all_tags = [all_tags[i] for i in sorted_indices]
        all_probs = [all_probs[i] for i in sorted_indices]
        all_colors = [all_colors[i] for i in sorted_indices]

        num_tags = len(all_tags)
        bar_height = min(0.8, max(0.1, 0.8 * (30 / num_tags))) if num_tags > 30 else 0.8
        y_positions = np.arange(num_tags)

        bars = ax_tags.barh(y_positions, all_probs, height=bar_height, color=all_colors)
        ax_tags.set_yticks(y_positions)
        ax_tags.set_yticklabels(all_tags)

        fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
        for lbl in ax_tags.get_yticklabels():
            lbl.set_fontsize(fontsize)

        for i, (bar, prob) in enumerate(zip(bars, all_probs)):
            text_x = min(prob + 0.02, 0.98)
            ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)

        ax_tags.set_xlim(0, 1)
        ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")

        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=color, label=cat.capitalize())
            for cat, color in color_map.items()
            if any(t.startswith(f"[{cat[0].upper() if cat != 'copyright' else '©'}]") for t in all_tags)
        ]
        if legend_elements:
            ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

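# Usage sketch: render the chart for a get_tags() result and save it (hypothetical path):
#   chart = visualize_predictions(padded, predictions, 0.55)
#   chart.save("tags.png")
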
# --- Constants ---
REPO_ID = "celstk/wd-eva02-lora-onnx"
# Model options
MODEL_OPTIONS = {
    "cl_eva02_tagger_v1_250426": "cl_eva02_tagger_v1_250426/model.onnx",
    "cl_eva02_tagger_v1_250427": "cl_eva02_tagger_v1_250427/model.onnx",
    "cl_eva02_tagger_v1_250430": "cl_eva02_tagger_v1_250430/model.onnx",
-    "cl_eva02_tagger_v1_250501": "cl_eva02_tagger_v1_250501/model.onnx",
-    "cl_eva02_tagger_v1_250502": "cl_eva02_tagger_v1_250502/model.onnx"
+    "cl_eva02_tagger_v1_250502": "cl_eva02_tagger_v1_250503/model.onnx",
+    "cl_eva02_tagger_v1_250504": "cl_eva02_tagger_v1_250504/model.onnx"
}
-DEFAULT_MODEL = "cl_eva02_tagger_v1_250502"
+DEFAULT_MODEL = "cl_eva02_tagger_v1_250504"
CACHE_DIR = "./model_cache"

# --- Global variables for paths (initialized at startup) ---
g_onnx_model_path = None
g_tag_mapping_path = None
g_labels_data = None
g_idx_to_tag = None
g_tag_to_category = None
g_current_model = None

# --- Initialization Function ---
def initialize_onnx_paths(model_choice=DEFAULT_MODEL):
    global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category, g_current_model

    if model_choice not in MODEL_OPTIONS:
        print(f"Invalid model choice: {model_choice}, falling back to default: {DEFAULT_MODEL}")
        model_choice = DEFAULT_MODEL

    g_current_model = model_choice
    model_dir = model_choice
    onnx_filename = MODEL_OPTIONS[model_choice]
    tag_mapping_filename = f"{model_dir}/tag_mapping.json"

    print(f"Initializing ONNX paths and labels for model: {model_choice}...")
    hf_token = os.environ.get("HF_TOKEN")
    try:
        print(f"Attempting to download ONNX model: {onnx_filename}")
        g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
        print(f"ONNX model path: {g_onnx_model_path}")

        print(f"Attempting to download tag mapping: {tag_mapping_filename}")
        g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=tag_mapping_filename, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
        print(f"Tag mapping path: {g_tag_mapping_path}")

        print("Loading labels from mapping...")
        g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
        print(f"Labels loaded. Count: {len(g_labels_data.names)}")

        return True

    except Exception as e:
        print(f"Error during initialization: {e}")
        import traceback; traceback.print_exc()
        # Reset globals to force reinitialization
        g_onnx_model_path = None
        g_tag_mapping_path = None
        g_labels_data = None
        g_idx_to_tag = None
        g_tag_to_category = None
        g_current_model = None
        # Raise a Gradio error to make the failure visible in the UI
        raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")

# Function to handle model change
def change_model(model_choice):
    try:
        success = initialize_onnx_paths(model_choice)
        if success:
            return f"Model changed to: {model_choice}"
        else:
            return "Failed to change model. See logs for details."
    except Exception as e:
        return f"Error changing model: {str(e)}"

# --- Main Prediction Function (ONNX) ---
@spaces.GPU()
def predict_onnx(image_input, model_choice, gen_threshold, char_threshold, output_mode):
    print(f"--- predict_onnx function started (GPU worker) with model {model_choice} ---")

    # Ensure the current model matches the selected model
    global g_current_model
    if g_current_model != model_choice:
        print(f"Model mismatch! Current: {g_current_model}, Selected: {model_choice}. Reinitializing...")
        try:
            initialize_onnx_paths(model_choice)
        except Exception as e:
            return f"Error initializing model '{model_choice}': {str(e)}", None

    # --- 1. Ensure paths and labels are loaded ---
    if g_onnx_model_path is None or g_labels_data is None:
        message = "Error: Paths or labels not initialized. Check startup logs."
        print(message)
        # Return the error message and None for the image output
        return message, None

    # --- 2. Load ONNX Session (inside the worker) ---
    session = None
    try:
        print(f"Loading ONNX session from: {g_onnx_model_path}")
        available_providers = ort.get_available_providers()
        providers = []
        if 'CUDAExecutionProvider' in available_providers:
            providers.append('CUDAExecutionProvider')
        providers.append('CPUExecutionProvider')
        print(f"Attempting to load session with providers: {providers}")
        session = ort.InferenceSession(g_onnx_model_path, providers=providers)
        print(f"ONNX session loaded using: {session.get_providers()[0]}")
    except Exception as e:
        message = f"Error loading ONNX session in worker: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None

    # --- 3. Process Input Image ---
    if image_input is None:
        return "Please upload an image.", None

    print(f"Processing image with thresholds: gen={gen_threshold}, char={char_threshold}")
    try:
        # Handle different input types (PIL, numpy, URL, file path)
        if isinstance(image_input, str):
            if image_input.startswith("http"):  # URL
                response = requests.get(image_input, timeout=10)
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            elif os.path.exists(image_input):  # File path
                image = Image.open(image_input)
            else:
                raise ValueError(f"Invalid image input string: {image_input}")
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input)
        elif isinstance(image_input, Image.Image):
            image = image_input  # Already a PIL image
        else:
            raise TypeError(f"Unsupported image input type: {type(image_input)}")

        # Preprocess the PIL image
        original_pil_image, input_tensor = preprocess_image(image)

        # Ensure the input tensor is float32, as expected by most ONNX models
        # (even if the model internally uses float16)
        input_tensor = input_tensor.astype(np.float32)

    except Exception as e:
        message = f"Error processing input image: {e}"
        print(message)
        return message, None

    # --- 4. Run Inference ---
    try:
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        print(f"Running inference with input '{input_name}', output '{output_name}'")
        start_time = time.time()
        outputs = session.run([output_name], {input_name: input_tensor})[0]
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.3f} seconds")

        # Check for NaN/Inf in the outputs
        if np.isnan(outputs).any() or np.isinf(outputs).any():
            print("Warning: NaN or Inf detected in model output. Clamping...")
            outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)  # Replace NaN/Inf before the sigmoid

        # Apply sigmoid (the outputs are likely logits).
        # Use a numerically stable sigmoid implementation.
        def stable_sigmoid(x):
            return 1 / (1 + np.exp(-np.clip(x, -30, 30)))  # Clip to avoid overflow
        probs = stable_sigmoid(outputs[0])  # Assuming batch size 1

    except Exception as e:
        message = f"Error during ONNX inference: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None
    finally:
        # Clean up the session (may reduce memory usage between clicks)
        del session

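    # stable_sigmoid sanity check (sketch): extreme logits saturate instead of
    # overflowing, e.g. stable_sigmoid(np.array([-40.0, 0.0, 40.0])) ≈ [0.0, 0.5, 1.0].
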
    # --- 5. Post-process and Format Output ---
    try:
        print("Post-processing results...")
        # Use the correct global variable for labels
        predictions = get_tags(probs, g_labels_data, gen_threshold, char_threshold)

        # Format the output text string
        output_tags = []
        if predictions.get("rating"):
            output_tags.append(predictions["rating"][0][0].replace("_", " "))
        if predictions.get("quality"):
            output_tags.append(predictions["quality"][0][0].replace("_", " "))
        # Add the other categories, respecting order and filtering meta tags if needed
        for category in ["artist", "character", "copyright", "general", "meta"]:
            tags_in_category = predictions.get(category, [])
            for tag, prob in tags_in_category:
                # Basic meta tag filtering for the text output
                if category == "meta" and any(p in tag.lower() for p in ['id', 'commentary', 'request', 'mismatch']):
                    continue
                output_tags.append(tag.replace("_", " "))
        output_text = ", ".join(output_tags)

        # Generate the visualization if requested
        viz_image = None
        if output_mode == "Tags + Visualization":
            print("Generating visualization...")
            # Pass the appropriate threshold for the display title (both could be passed if needed);
            # for simplicity, gen_threshold is used as a representative value
            viz_image = visualize_predictions(original_pil_image, predictions, gen_threshold)
            print("Visualization generated.")
        else:
            print("Visualization skipped.")

        print("Prediction complete.")
        return output_text, viz_image

    except Exception as e:
        message = f"Error during post-processing: {e}"
        print(message)
        import traceback; traceback.print_exc()
        return message, None

# --- Gradio Interface Definition (Full ONNX Version) ---
css = """
.gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
footer { display: none !important; }
.gr-prose { max-width: 100% !important; }
"""
# js = """ /* Keep existing JS */ """  # No JS needed currently

with gr.Blocks(css=css) as demo:
    gr.Markdown("# CL EVA02 ONNX Tagger")
    gr.Markdown("Upload an image or paste an image URL to predict tags using the CL EVA02 Tagger model (ONNX), fine-tuned from [SmilingWolf/wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3).")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image", elem_id="input-image")
            model_choice = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value=DEFAULT_MODEL,
                label="Model Version",
                interactive=True
            )
            gen_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.55, label="General/Meta Tag Threshold")
            char_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.60, label="Character/Copyright/Artist Tag Threshold")
            output_mode = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", label="Output Mode")
            predict_button = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            output_tags = gr.Textbox(label="Predicted Tags", lines=10, interactive=False)
            output_visualization = gr.Image(type="pil", label="Prediction Visualization", interactive=False)

    # Handle model change
    model_status = gr.Textbox(label="Model Status", interactive=False, visible=False)
    model_choice.change(
        fn=change_model,
        inputs=[model_choice],
        outputs=[model_status]
    )

    gr.Examples(
        examples=[
            ["https://pbs.twimg.com/media/GXBXsRvbQAAg1kp.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GjlX0gibcAA4EJ4.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags Only"],
            ["https://pbs.twimg.com/media/Gj4nQbjbEAATeoH.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"],
            ["https://pbs.twimg.com/media/GkbtX0GaoAMlUZt.jpg", DEFAULT_MODEL, 0.55, 0.70, "Tags + Visualization"]
        ],
        inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization],
        fn=predict_onnx,  # Use the ONNX prediction function
        cache_examples=False  # Disable caching for examples during testing
    )
    predict_button.click(
        fn=predict_onnx,  # Use the ONNX prediction function
        inputs=[image_input, model_choice, gen_threshold, char_threshold, output_mode],
        outputs=[output_tags, output_visualization]
    )

# --- Main Block ---
if __name__ == "__main__":
    if not os.environ.get("HF_TOKEN"):
        print("Warning: HF_TOKEN environment variable not set.")
    # Initialize paths and labels at startup (with the default model)
    initialize_onnx_paths(DEFAULT_MODEL)
    # Launch the Gradio app
    demo.launch(share=True)
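
# Programmatic usage sketch (bypasses the UI; the image URL is a placeholder):
#   initialize_onnx_paths(DEFAULT_MODEL)
#   tags, chart = predict_onnx("https://example.com/image.jpg", DEFAULT_MODEL, 0.55, 0.60, "Tags Only")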
 