import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from PIL import Image
from typing import Optional
import torchvision.transforms as transforms
import os
import json

# flash_attn_func is called in FlashAttention.forward but was never imported;
# it is assumed here to come from the flash-attn package.
from flash_attn import flash_attn_func


class InitialOnlyImageTagger(nn.Module):
    """
    A lightweight version of ImageTagger that only includes the backbone and initial classifier.
    This model uses significantly less VRAM than the full model.
    """
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 dropout=0.1, pretrained=True):
        super().__init__()

        self._flags = {
            'debug': False,
            'model_stats': False
        }

        self.dataset = dataset
        self.embedding_dim = 1280

        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using the same preprocessing as training."""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Resize so the longer side matches image_size while preserving aspect ratio
                width, height = img.size
                aspect_ratio = width / height
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Letterbox onto a black square canvas, centred
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise RuntimeError(f"Error processing {image_path}: {e}") from e

    def forward(self, x):
        """Forward pass with only the initial predictions."""
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # Return the same tensor twice so callers written against ImageTagger
        # (which returns initial and refined predictions) keep working.
        return initial_preds, initial_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Match the model's device and precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)

        with torch.no_grad():
            initial_preds, _ = self.forward(img_tensor)

        initial_probs = torch.sigmoid(initial_preds)

        if category_thresholds:
            initial_binary = torch.zeros_like(initial_probs)

            for category, cat_threshold in category_thresholds.items():
                # Boolean mask over the tag axis for this category
                category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)
                for tag_idx in range(initial_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except Exception:
                        continue

                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = initial_binary
        else:
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (initial_probs >= threshold_tensor).to(dtype)

        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': initial_probs,
            'predictions': predictions
        }
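
    # Illustrative call (hypothetical file name and values) showing the expected shape
    # of `category_thresholds`: a mapping from category name to probability cutoff.
    #
    #   tagger.predict("sample.jpg", category_thresholds={"character": 0.35, "general": 0.30})
    #
    # Tags whose category has no entry in the mapping stay at 0 in `predictions`.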

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        if predictions.dim() > 1:
            predictions = predictions[0]

        indices = torch.where(predictions > 0)[0].cpu().tolist()

        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result


class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)

        self.scale = self.head_dim ** -0.5
        self.debug = False

    def _debug_print(self, name, tensor):
        """Debug helper"""
        if self.debug:
            print(f"\n{name}:")
            print(f"Shape: {tensor.shape}")
            print(f"Device: {tensor.device}")
            print(f"Dtype: {tensor.dtype}")
            if tensor.is_floating_point():
                with torch.no_grad():
                    print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"Mean: {tensor.mean().item():.3f}")
                    print(f"Std: {tensor.std().item():.3f}")

    def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape input tensor for flash attention format"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.transpose(1, 2)  # -> (batch, num_heads, seq_len, head_dim)
        return x.contiguous()
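
    # Note (assumption about the flash-attn package): flash_attn_func expects tensors
    # laid out as (batch, seq_len, num_heads, head_dim), whereas the transpose above
    # produces (batch, num_heads, seq_len, head_dim). Verify the layout against the
    # attention kernel actually installed before relying on this module.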

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with flash attention"""
        if self.debug:
            print("\nFlashAttention Forward Pass")

        batch_size = query.size(0)

        # Default to self-attention when key/value are not provided
        key = query if key is None else key
        value = query if value is None else value

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        if self.debug:
            self._debug_print("Query before reshape", q)

        q = self._reshape_for_flash(q)
        k = self._reshape_for_flash(k)
        v = self._reshape_for_flash(v)

        if self.debug:
            self._debug_print("Query after reshape", q)

        if mask is not None:
            # Normalise the mask towards a (batch, 1, seq, seq) layout
            if mask.dim() == 2:
                mask = mask.view(batch_size, 1, -1, 1)
            elif mask.dim() == 3:
                mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
            elif mask.dim() == 5:
                mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))

            mask = mask.to(q.dtype)

            if self.debug:
                self._debug_print("Prepared mask", mask)
                print(f"q shape: {q.shape}, mask shape: {mask.shape}")

            # Pad or crop the mask so it matches the query sequence length
            seq_len = q.size(2)
            if mask.size(-1) != seq_len:
                new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
                                       device=mask.device, dtype=mask.dtype)
                min_len = min(seq_len, mask.size(-1))
                new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
                mask = new_mask

            # Zero out keys/values at fully masked positions
            key_padding_mask = mask.squeeze(1).sum(-1) > 0
            key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)

            k = k * key_padding_mask
            v = v * key_padding_mask

        if self.debug:
            self._debug_print("Query before attention", q)
            self._debug_print("Key before attention", k)
            self._debug_print("Value before attention", v)

        dropout_p = self.dropout if self.training else 0.0
        output = flash_attn_func(
            q, k, v,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
            causal=False
        )

        if self.debug:
            self._debug_print("Output after attention", output)

        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.dim)

        output = self.out_proj(output)

        if self.debug:
            self._debug_print("Final output", output)

        return output


class OptimizedTagEmbedding(nn.Module):
    def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
        super().__init__()

        self.embedding = nn.Embedding(num_tags, embedding_dim)
        self.attention = FlashAttention(embedding_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)

        self.context_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        self.importance_scale = nn.Parameter(torch.tensor(0.1))
        self.context_scale = nn.Parameter(torch.tensor(1.0))
        self.debug = False

    def _debug_print(self, name, tensor, extra_info=None):
        """Memory efficient debug printing with type handling"""
        if self.debug:
            print(f"\n{name}:")
            if isinstance(tensor, torch.Tensor):
                print(f"- Shape: {tensor.shape}")
                with torch.no_grad():
                    print(f"- Device: {tensor.device}")
                    print(f"- Dtype: {tensor.dtype}")

                    # Cast non-float tensors so the statistics below are defined
                    if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
                        calc_tensor = tensor.float()
                    else:
                        calc_tensor = tensor

                    try:
                        min_val = calc_tensor.min().item()
                        max_val = calc_tensor.max().item()
                        mean_val = calc_tensor.mean().item()
                        std_val = calc_tensor.std().item()
                        norm_val = torch.norm(calc_tensor).item()

                        print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
                        print(f"- Mean: {mean_val:.3f}")
                        print(f"- Std: {std_val:.3f}")
                        print(f"- L2 Norm: {norm_val:.3f}")

                        if extra_info:
                            print(f"- Additional info: {extra_info}")
                    except Exception as e:
                        print(f"- Could not compute statistics: {str(e)}")

    def _debug_tensor(self, name, tensor):
        """Debug helper with dtype-specific analysis"""
        if self.debug and isinstance(tensor, torch.Tensor):
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            print(f"- Device: {tensor.device}")
            print(f"- Dtype: {tensor.dtype}")
            with torch.no_grad():
                has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
                has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
                print(f"- Contains NaN: {has_nan}")
                print(f"- Contains Inf: {has_inf}")

                if tensor.is_floating_point():
                    print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"- Mean: {tensor.mean().item():.3f}")
                    print(f"- Std: {tensor.std().item():.3f}")
                else:
                    print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
                    print(f"- Unique values: {tensor.unique().numel()}")

    def _process_category(self, indices, masks):
        """Process a single category of tags"""
        embeddings = self.embedding(indices)

        if self.debug:
            self._debug_tensor("Category embeddings", embeddings)

        # Per-tag importance weighting, kept in a bounded range
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[indices].unsqueeze(-1)

        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        # Self-attention across the category's tags (skipped for single-tag sequences)
        if embeddings.size(1) > 1:
            if masks is not None:
                attention_mask = torch.einsum('bi,bj->bij', masks, masks)
                attended = self.attention(embeddings, mask=attention_mask)
            else:
                attended = self.attention(embeddings)
            embeddings = self.norm2(attended)

        # Masked mean pooling over the tag dimension
        if masks is not None:
            masked_embeddings = embeddings * masks.unsqueeze(-1)
            pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            pooled = embeddings.mean(dim=1)

        return pooled, embeddings

    def forward(self, tag_indices_dict, tag_masks_dict=None):
        """
        Process all tags in a unified embedding space

        Args:
            tag_indices_dict: dict of {category: tensor of indices}
            tag_masks_dict: dict of {category: tensor of masks}
        """
        if self.debug:
            print("\nOptimizedTagEmbedding Forward Pass")

        # Concatenate indices (and masks) from all categories along the tag dimension
        all_indices = []
        all_masks = []
        batch_size = None

        for category, indices in tag_indices_dict.items():
            if batch_size is None:
                batch_size = indices.size(0)
            all_indices.append(indices)
            if tag_masks_dict:
                all_masks.append(tag_masks_dict[category])

        combined_indices = torch.cat(all_indices, dim=1)
        if tag_masks_dict:
            combined_masks = torch.cat(all_masks, dim=1)

        if self.debug:
            self._debug_tensor("Combined indices", combined_indices)
            if tag_masks_dict:
                self._debug_tensor("Combined masks", combined_masks)

        embeddings = self.embedding(combined_indices)

        if self.debug:
            self._debug_tensor("Base embeddings", embeddings)

        # Importance weighting, shared with _process_category
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[combined_indices].unsqueeze(-1)

        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        if self.debug:
            self._debug_tensor("Weighted embeddings", embeddings)

        # Self-attention over the combined tag sequence
        if tag_masks_dict:
            attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
            attended = self.attention(embeddings, mask=attention_mask)
        else:
            attended = self.attention(embeddings)

        attended = self.norm2(attended)

        if self.debug:
            self._debug_tensor("Attended embeddings", attended)

        # Masked mean pooling to a single context vector per sample
        if tag_masks_dict:
            masked_embeddings = attended * combined_masks.unsqueeze(-1)
            tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            tag_context = attended.mean(dim=1)

        tag_context = self.context_proj(tag_context)
        context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
        tag_context = tag_context * context_scale

        if self.debug:
            self._debug_tensor("Final tag context", tag_context)

        return tag_context, attended


class TagDataset:
    """Lightweight dataset wrapper for inference only"""
    def __init__(self, total_tags, idx_to_tag, tag_to_category):
        self.total_tags = total_tags
        # idx_to_tag is typically loaded from JSON, where keys arrive as strings;
        # normalise them to int so lookups by integer index work.
        self.idx_to_tag = {int(k): v for k, v in idx_to_tag.items()}
        self.tag_to_category = tag_to_category

    def get_tag_info(self, idx):
        """Get tag name and category for a given index"""
        tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
        category = self.tag_to_category.get(tag_name, "general")
        return tag_name, category
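
# Illustrative (made-up) shapes of the mappings TagDataset expects, mirroring what
# load_model() reads from metadata.json; the tag names below are examples only:
#
#   idx_to_tag      = {"0": "1girl", "1": "solo"}            # index -> tag name
#   tag_to_category = {"1girl": "general", "solo": "general"}  # tag name -> category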


class ImageTagger(nn.Module):
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 num_heads=16, dropout=0.1, pretrained=True,
                 tag_context_size=256):
        super().__init__()

        self._flags = {
            'debug': False,
            'model_stats': False
        }

        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension

        # Image backbone with its classification head removed
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Stage 1: initial tag classifier on pooled image features
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Embeddings and self-attention over the top-K predicted tags
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Projection of image features used for cross-attention with tag embeddings
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )

        self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Stage 2: refined classifier on concatenated image + fused tag features
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def _get_selected_tags(self, logits):
        """Select top-K tags based on prediction confidence"""
        probs = torch.sigmoid(logits)

        # Indices and probabilities of the tag_context_size most confident tags
        topk_values, topk_indices = torch.topk(
            probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
        )

        return topk_indices, topk_values

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using the same preprocessing as training."""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Resize so the longer side matches image_size while preserving aspect ratio
                width, height = img.size
                aspect_ratio = width / height
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Letterbox onto a black square canvas, centred
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise RuntimeError(f"Error processing {image_path}: {e}") from e

    def forward(self, x):
        """Forward pass with simplified feature handling"""
        # Unused debug/statistics placeholders
        model_stats = {}
        debug_tensors = {} if self.debug else None

        # Pooled image features from the backbone
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # Stage 1: initial predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # Embed the top-K most confident tags and let them attend to each other
        pred_tag_indices, _ = self._get_selected_tags(initial_preds)
        tag_embeddings = self.tag_embedding(pred_tag_indices)

        attended_tags = self.tag_attention(tag_embeddings)
        attended_tags = self.tag_norm(attended_tags)

        # Cross-attention: projected image features attend over the tag embeddings
        features_proj = self.cross_proj(features)
        features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)

        cross_attended = self.cross_attention(features_expanded, attended_tags)
        cross_attended = self.cross_norm(cross_attended)

        # Stage 2: refined predictions from image features fused with tag context
        fused_features = cross_attended.mean(dim=1)
        combined_features = torch.cat([features, fused_features], dim=-1)

        refined_logits = self.refined_classifier(combined_features)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds
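
    # Shape sketch for a single 512x512 RGB input with the default settings
    # (embedding_dim=1280, tag_context_size=256); total_tags depends on the dataset:
    #   x:                 (1, 3, 512, 512)
    #   features:          (1, 1280)        after backbone + spatial pooling
    #   initial_preds:     (1, total_tags)
    #   pred_tag_indices:  (1, 256)         top-K tag indices
    #   attended_tags:     (1, 256, 1280)
    #   features_expanded: (1, 256, 1280)
    #   fused_features:    (1, 1280)
    #   combined_features: (1, 2560)
    #   refined_preds:     (1, total_tags)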

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Match the model's device and precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        img_tensor = img_tensor.to(device, dtype=dtype)

        with torch.no_grad():
            initial_preds, refined_preds = self.forward(img_tensor)

        initial_probs = torch.sigmoid(initial_preds)
        refined_probs = torch.sigmoid(refined_preds)

        if category_thresholds:
            refined_binary = torch.zeros_like(refined_probs)

            for category, cat_threshold in category_thresholds.items():
                # Boolean mask over the tag axis for this category
                category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)
                for tag_idx in range(refined_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except Exception:
                        continue

                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = refined_binary
        else:
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (refined_probs >= threshold_tensor).to(dtype)

        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': refined_probs,
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        if predictions.dim() > 1:
            predictions = predictions[0]

        indices = torch.where(predictions > 0)[0].cpu().tolist()

        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result


def load_model(model_dir, device='cuda'):
    """Load model with better error handling and warnings"""
    print(f"Loading model from {model_dir}")

    try:
        # Tag metadata: total_tags plus index-to-name and name-to-category mappings
        metadata_path = os.path.join(model_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        # Optional architecture/precision settings
        model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, 'r') as f:
                model_info = json.load(f)
        else:
            print("WARNING: Model info file not found, using default settings")
            model_info = {
                "tag_context_size": 256,
                "num_heads": 16,
                "precision": "float16"
            }

        dataset = TagDataset(
            total_tags=metadata['total_tags'],
            idx_to_tag=metadata['idx_to_tag'],
            tag_to_category=metadata['tag_to_category']
        )

        model = ImageTagger(
            total_tags=metadata['total_tags'],
            dataset=dataset,
            num_heads=model_info.get('num_heads', 16),
            tag_context_size=model_info.get('tag_context_size', 256),
            pretrained=False
        )

        state_dict_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")

        state_dict = torch.load(state_dict_path, map_location=device)

        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model state dict loaded with strict=True successfully")
        except Exception as e:
            print(f"! Strict loading failed: {str(e)}")
            print("Attempting non-strict loading...")

            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

            print("Non-strict loading completed with:")
            print(f"- {len(missing_keys)} missing keys")
            print(f"- {len(unexpected_keys)} unexpected keys")

            if len(missing_keys) > 0:
                print(f"Sample missing keys: {missing_keys[:5]}")
            if len(unexpected_keys) > 0:
                print(f"Sample unexpected keys: {unexpected_keys[:5]}")

        model = model.to(device)

        if model_info.get('precision') == 'float16':
            model = model.half()
            print("✓ Model converted to half precision")

        model.eval()
        print("✓ Model set to evaluation mode")

        param_dtype = next(model.parameters()).dtype
        print(f"✓ Model loaded with precision: {param_dtype}")

        return model, dataset

    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
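

# Expected layout of model_dir, as read by load_model() above (directory name is
# illustrative):
#
#   exported_model/
#     metadata.json                  # total_tags, idx_to_tag, tag_to_category
#     model_info_initial_only.json   # optional: tag_context_size, num_heads, precision
#     model.pt                       # state dict loaded with torch.load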


if __name__ == "__main__":
    import sys

    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"

    # load_model returns (model, dataset); no per-category threshold file is loaded
    # here, so predictions fall back to the default threshold unless thresholds are
    # wired in separately.
    model, dataset = load_model(model_dir)
    thresholds = None

    print("\nModel information:")
    print(f"  Total tags: {dataset.total_tags}")
    print(f"  Device: {next(model.parameters()).device}")
    print(f"  Precision: {next(model.parameters()).dtype}")

    if len(sys.argv) > 2:
        image_path = sys.argv[2]
        print(f"\nRunning inference on {image_path}")

        if thresholds and 'categories' in thresholds:
            category_thresholds = {cat: opt['balanced']['threshold']
                                   for cat, opt in thresholds['categories'].items()}
            results = model.predict(image_path, category_thresholds=category_thresholds)
        else:
            results = model.predict(image_path)

        tags = model.get_tags_from_predictions(results['predictions'])

        print("\nPredicted tags:")
        for category, category_tags in tags.items():
            print(f"\n{category.capitalize()}:")
            for tag, prob in category_tags:
                print(f"  {tag}: {prob:.3f}")