""" "
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import PIL
import numpy as np
import torch
import torch.utils.data
import torchvision
from typing import List, Dict

torchvision.disable_beta_transforms_warning()

__all__ = ["show_sample", "save_samples"]


def save_samples(samples: torch.Tensor, targets: List[Dict], output_dir: str, split: str, normalized: bool, box_fmt: str):
    '''
    Save a batch of samples with their ground-truth boxes drawn, for visual inspection.

    normalized: whether the boxes are normalized to [0, 1]
    box_fmt: 'xyxy', 'xywh' or 'cxcywh'; D-FINE uses 'cxcywh' for training and 'xyxy' for validation
    '''
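    # Example call (hypothetical values, not from the upstream file; assumes a
    # D-FINE training batch whose boxes are normalized cxcywh):
    #   save_samples(images, targets, output_dir="./vis_out", split="train",
    #                normalized=True, box_fmt="cxcywh")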
    from torchvision.transforms.functional import to_pil_image
    from torchvision.ops import box_convert
    from pathlib import Path
    from PIL import ImageDraw, ImageFont
    import os

    os.makedirs(Path(output_dir) / Path(f"{split}_samples"), exist_ok=True)

    # Predefined colors (standard color names recognized by PIL)
    BOX_COLORS = [
        "red", "blue", "green", "orange", "purple",
        "cyan", "magenta", "yellow", "lime", "pink",
        "teal", "lavender", "brown", "beige", "maroon",
        "navy", "olive", "coral", "turquoise", "gold",
    ]
    LABEL_TEXT_COLOR = "white"
    # The default bitmap font is not scalable; Pillow >= 10.1 accepts a size argument.
    try:
        font = ImageFont.load_default(size=32)
    except TypeError:  # older Pillow: fall back to the fixed-size default font
        font = ImageFont.load_default()
    for i, (sample, target) in enumerate(zip(samples, targets)):
        sample_visualization = sample.clone().cpu()
        target_boxes = target["boxes"].clone().cpu()
        target_labels = target["labels"].clone().cpu()
        target_image_id = target["image_id"].item()
        target_image_path = target["image_path"]
        target_image_path_stem = Path(target_image_path).stem

        sample_visualization = to_pil_image(sample_visualization)
        sample_visualization_w, sample_visualization_h = sample_visualization.size

        # normalized [0, 1] coordinates -> pixel coordinates
        if normalized:
            target_boxes[:, 0] = target_boxes[:, 0] * sample_visualization_w
            target_boxes[:, 2] = target_boxes[:, 2] * sample_visualization_w
            target_boxes[:, 1] = target_boxes[:, 1] * sample_visualization_h
            target_boxes[:, 3] = target_boxes[:, 3] * sample_visualization_h

        # any box format -> xyxy
        target_boxes = box_convert(target_boxes, in_fmt=box_fmt, out_fmt="xyxy")

        # clip to image size
        target_boxes[:, 0] = torch.clamp(target_boxes[:, 0], 0, sample_visualization_w)
        target_boxes[:, 1] = torch.clamp(target_boxes[:, 1], 0, sample_visualization_h)
        target_boxes[:, 2] = torch.clamp(target_boxes[:, 2], 0, sample_visualization_w)
        target_boxes[:, 3] = torch.clamp(target_boxes[:, 3], 0, sample_visualization_h)

        target_boxes = target_boxes.numpy().astype(np.int32)
        target_labels = target_labels.numpy().astype(np.int32)
        draw = ImageDraw.Draw(sample_visualization)

        # draw target boxes
        for box, label in zip(target_boxes, target_labels):
            x1, y1, x2, y2 = box

            # Select color based on class ID
            box_color = BOX_COLORS[int(label) % len(BOX_COLORS)]

            # Draw box (thick)
            draw.rectangle([x1, y1, x2, y2], outline=box_color, width=3)

            label_text = f"{label}"

            # Measure text size
            text_width, text_height = draw.textbbox((0, 0), label_text, font=font)[2:4]

            # Draw text background
            padding = 2
            draw.rectangle(
                [x1, y1 - text_height - padding * 2, x1 + text_width + padding * 2, y1],
                fill=box_color,
            )
            # Draw text (LABEL_TEXT_COLOR)
            draw.text(
                (x1 + padding, y1 - text_height - padding),
                label_text,
                fill=LABEL_TEXT_COLOR,
                font=font,
            )

        save_path = Path(output_dir) / f"{split}_samples" / f"{target_image_id}_{target_image_path_stem}.webp"
        sample_visualization.save(save_path)


def show_sample(sample):
    """for coco dataset/dataloader"""
    import matplotlib.pyplot as plt
    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
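    # Note: to_image_tensor / convert_dtype below are the torchvision 0.15 beta names;
    # in torchvision >= 0.16 the v2 equivalents are F.to_image and F.to_dtype.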
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)
    image = F.convert_dtype(image, torch.uint8)

    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    fig.show()
    plt.show()
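

if __name__ == "__main__":
    # Minimal smoke test (not part of the upstream RT-DETR/D-FINE file): builds one
    # dummy image and target to exercise save_samples. All values and paths here are
    # illustrative assumptions, chosen to match the training convention documented
    # above (normalized cxcywh boxes).
    dummy_images = torch.rand(1, 3, 480, 640)
    dummy_targets = [
        {
            "boxes": torch.tensor([[0.5, 0.5, 0.25, 0.25]]),  # one centered box, normalized cxcywh
            "labels": torch.tensor([1]),
            "image_id": torch.tensor(0),
            "image_path": "dummy.jpg",
        }
    ]
    save_samples(
        dummy_images,
        dummy_targets,
        output_dir=".",
        split="debug",
        normalized=True,
        box_fmt="cxcywh",
    )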