Open-Source AI Cookbook documentation
在自定义数据集上微调语义分割模型并通过推理 API 使用
在自定义数据集上微调语义分割模型并通过推理 API 使用
在本 Notebook 中,我们将介绍如何在自定义数据集上微调一个 语义分割 模型。我们将使用的模型是预训练的 Segformer,这是一个强大且灵活的基于 Transformer 的分割架构,适用于各种分割任务。
对于我们的数据集,我们将使用 segments/sidewalk-semantic,它包含了标注好的人行道图像,十分适合用于城市环境中的应用。
示例应用场景:
该模型可以部署在送餐机器人中,使其能够自动在城市人行道上导航,将披萨直接送到您家门口 🍕。
一旦我们微调完成该模型,接下来会演示如何使用 Serverless 推理 API 将模型部署,使其可以通过一个简单的 API 端点进行访问。
1. 安装依赖
首先,我们需要安装微调语义分割模型所需的必要库。以下是安装依赖的步骤:
!pip install -q datasets transformers evaluate wandb
# Tested with datasets==3.0.0, transformers==4.44.2, evaluate==0.4.3, wandb==0.18.1
2. 加载数据集 📁
我们将使用 sidewalk-semantic 数据集,该数据集包含了 2021 年夏季在比利时收集的关于人行道的图像。
该数据集包括:
- 1,000 张图像及其对应的语义分割掩膜 🖼
- 34 个不同的类别 📦
由于该数据集是受限访问的,你需要登录并接受许可才能访问。同时,我们还需要身份验证来上传微调后的模型到 Hugging Face Hub。
from huggingface_hub import notebook_login
notebook_login()
sidewalk_dataset_identifier = "segments/sidewalk-semantic"
from datasets import load_dataset
dataset = load_dataset(sidewalk_dataset_identifier)
审查数据集的内部结构以熟悉它!
dataset
由于数据集仅包含训练集,我们将手动将其分为 训练集
和 测试集
。我们将 80% 的数据分配用于训练,剩余的 20% 用于评估和测试。 ➗
dataset = dataset.shuffle(seed=42)
dataset = dataset["train"].train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]
让我们检查一个示例中存在的对象类型。我们可以看到,pixels_values
包含了 RGB 图像,而 label
包含了 ground truth mask。ground truth mask 是一个单通道图像,其中每个像素表示对应 RGB 图像中像素的类别。
image = train_ds[0]
image
3. 可视化示例! 👀
现在我们已经加载了数据集,让我们可视化一些示例以及它们的掩膜,以便更好地理解数据集的结构。
数据集包含一个 JSON 文件,该文件包含了 id2label
映射。我们将打开这个文件,读取与每个 ID 关联的类别标签。
>>> import json
>>> from huggingface_hub import hf_hub_download
>>> filename = "id2label.json"
>>> id2label = json.load(
... open(hf_hub_download(repo_id=sidewalk_dataset_identifier, filename=filename, repo_type="dataset"), "r")
... )
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)
>>> print("Id2label:", id2label)
Id2label: {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}
让我们为每个类别分配颜色 🎨。这将帮助我们更有效地可视化分割结果,并使得在图像中解释不同类别变得更加容易。
sidewalk_palette = [
[0, 0, 0], # unlabeled
[216, 82, 24], # flat-road
[255, 255, 0], # flat-sidewalk
[125, 46, 141], # flat-crosswalk
[118, 171, 47], # flat-cyclinglane
[161, 19, 46], # flat-parkingdriveway
[255, 0, 0], # flat-railtrack
[0, 128, 128], # flat-curb
[190, 190, 0], # human-person
[0, 255, 0], # human-rider
[0, 0, 255], # vehicle-car
[170, 0, 255], # vehicle-truck
[84, 84, 0], # vehicle-bus
[84, 170, 0], # vehicle-tramtrain
[84, 255, 0], # vehicle-motorcycle
[170, 84, 0], # vehicle-bicycle
[170, 170, 0], # vehicle-caravan
[170, 255, 0], # vehicle-cartrailer
[255, 84, 0], # construction-building
[255, 170, 0], # construction-door
[255, 255, 0], # construction-wall
[33, 138, 200], # construction-fenceguardrail
[0, 170, 127], # construction-bridge
[0, 255, 127], # construction-tunnel
[84, 0, 127], # construction-stairs
[84, 84, 127], # object-pole
[84, 170, 127], # object-trafficsign
[84, 255, 127], # object-trafficlight
[170, 0, 127], # nature-vegetation
[170, 84, 127], # nature-terrain
[170, 170, 127], # sky
[170, 255, 127], # void-ground
[255, 0, 127], # void-dynamic
[255, 84, 127], # void-static
[255, 170, 127], # void-unclear
]
我们可以可视化数据集中的一些示例,包括 RGB 图像、对应的掩膜以及掩膜覆盖在图像上的效果。这将帮助我们更好地理解数据集以及掩膜如何与图像对应。📸
>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from PIL import Image
>>> import matplotlib.patches as patches
>>> # Create and show the legend separately
>>> fig, ax = plt.subplots(figsize=(18, 2))
>>> legend_patches = [
... patches.Patch(color=np.array(color) / 255, label=label)
... for label, color in zip(id2label.values(), sidewalk_palette)
... ]
>>> ax.legend(handles=legend_patches, loc="center", bbox_to_anchor=(0.5, 0.5), ncol=5, fontsize=8)
>>> ax.axis("off")
>>> plt.show()
>>> for i in range(5):
... image = train_ds[i]
... fig, ax = plt.subplots(1, 3, figsize=(18, 6))
... # Show the original image
... ax[0].imshow(image["pixel_values"])
... ax[0].set_title("Original Image")
... ax[0].axis("off")
... mask_np = np.array(image["label"])
... # Create a new empty RGB image
... colored_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)
... # Assign colors to each value in the mask
... for label_id, color in enumerate(sidewalk_palette):
... colored_mask[mask_np == label_id] = color
... colored_mask_img = Image.fromarray(colored_mask, "RGB")
... # Show the segmentation mask
... ax[1].imshow(colored_mask_img)
... ax[1].set_title("Segmentation Mask")
... ax[1].axis("off")
... # Convert the original image to RGBA to support transparency
... image_rgba = image["pixel_values"].convert("RGBA")
... colored_mask_rgba = colored_mask_img.convert("RGBA")
... # Adjust transparency of the mask
... alpha = 128 # Transparency level (0 fully transparent, 255 fully opaque)
... image_2_with_alpha = Image.new("RGBA", colored_mask_rgba.size)
... for x in range(colored_mask_rgba.width):
... for y in range(colored_mask_rgba.height):
... r, g, b, a = colored_mask_rgba.getpixel((x, y))
... image_2_with_alpha.putpixel((x, y), (r, g, b, alpha))
... superposed = Image.alpha_composite(image_rgba, image_2_with_alpha)
... # Show the mask overlay
... ax[2].imshow(superposed)
... ax[2].set_title("Mask Overlay")
... ax[2].axis("off")
... plt.show()
4. 可视化类别分布 📊
为了更深入地了解数据集,让我们绘制每个类别的出现次数。这将帮助我们理解各类别的分布情况,并识别数据集中的潜在偏差或不平衡问题。
import matplotlib.pyplot as plt
import numpy as np
class_counts = np.zeros(len(id2label))
for example in train_ds:
mask_np = np.array(example["label"])
unique, counts = np.unique(mask_np, return_counts=True)
for u, c in zip(unique, counts):
class_counts[u] += c
>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from matplotlib import patches
>>> labels = list(id2label.values())
>>> # Normalize colors to be in the range [0, 1]
>>> normalized_palette = [tuple(c / 255 for c in color) for color in sidewalk_palette]
>>> # Visualization
>>> fig, ax = plt.subplots(figsize=(12, 8))
>>> bars = ax.bar(range(len(labels)), class_counts, color=[normalized_palette[i] for i in range(len(labels))])
>>> ax.set_xticks(range(len(labels)))
>>> ax.set_xticklabels(labels, rotation=90, ha="right")
>>> ax.set_xlabel("Categories", fontsize=14)
>>> ax.set_ylabel("Number of Occurrences", fontsize=14)
>>> ax.set_title("Number of Occurrences by Category", fontsize=16)
>>> ax.grid(axis="y", linestyle="--", alpha=0.7)
>>> # Adjust the y-axis limit
>>> y_max = max(class_counts)
>>> ax.set_ylim(0, y_max * 1.25)
>>> for bar in bars:
... height = int(bar.get_height())
... offset = 10 # Adjust the text location
... ax.text(
... bar.get_x() + bar.get_width() / 2.0,
... height + offset,
... f"{height}",
... ha="center",
... va="bottom",
... rotation=90,
... fontsize=10,
... color="black",
... )
>>> fig.legend(
... handles=legend_patches, loc="center left", bbox_to_anchor=(1, 0.5), ncol=1, fontsize=8
... ) # Adjust ncol as needed
>>> plt.tight_layout()
>>> plt.show()
5. 初始化图像处理器并使用 Albumentations 添加数据增强 📸
我们将首先初始化图像处理器,然后使用 Albumentations 应用数据增强 🪄。这将有助于增强我们的数据集,并提高语义分割模型的性能。
import albumentations as A
from transformers import SegformerImageProcessor
image_processor = SegformerImageProcessor()
albumentations_transform = A.Compose(
[
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.7),
A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.5),
A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.5),
A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=25, val_shift_limit=20, p=0.5),
A.GaussianBlur(blur_limit=(3, 5), p=0.3),
A.GaussNoise(var_limit=(10, 50), p=0.4),
]
)
def train_transforms(example_batch):
augmented_images = [albumentations_transform(image=np.array(x))["image"] for x in example_batch["pixel_values"]]
labels = [x for x in example_batch["label"]]
inputs = image_processor(augmented_images, labels)
return inputs
def val_transforms(example_batch):
images = [x for x in example_batch["pixel_values"]]
labels = [x for x in example_batch["label"]]
inputs = image_processor(images, labels)
return inputs
# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)
6. 从检查点初始化模型
我们将使用来自检查点的预训练 Segformer 模型:nvidia/mit-b0。该架构在论文 SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 中有详细描述,并且已在 ImageNet-1k 数据集上进行训练。
from transformers import SegformerForSemanticSegmentation
pretrained_model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model_name, id2label=id2label, label2id=label2id)
7. 设置训练参数并连接到 Weights & Biases (W&B) 📉
接下来,我们将配置训练参数,并连接到 Weights & Biases (W&B)。W&B 将帮助我们跟踪实验、可视化指标,并管理模型训练工作流,在整个过程中提供有价值的洞察。
from transformers import TrainingArguments
output_dir = "segformer-b0-segments-sidewalk-finetuned"
training_args = TrainingArguments(
output_dir=output_dir,
learning_rate=6e-5,
num_train_epochs=20,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
save_total_limit=2,
evaluation_strategy="steps",
save_strategy="steps",
save_steps=20,
eval_steps=20,
logging_steps=1,
eval_accumulation_steps=5,
load_best_model_at_end=True,
push_to_hub=True,
report_to="wandb",
)
import wandb
wandb.init(
project="segformer-b0-segments-sidewalk-finetuned", # change this
name="segformer-b0-segments-sidewalk-finetuned", # change this
config=training_args,
)
8. 设置自定义 compute_metrics 方法以增强使用 evaluate 的日志记录
我们将使用 平均交并比 (mean IoU) 作为评估模型性能的主要指标。这将使我们能够详细跟踪每个类别的性能。
此外,我们将调整评估模块的日志记录级别,以减少输出中的警告。如果图像中没有检测到某个类别,你可能会看到类似以下的警告:
RuntimeWarning: invalid value encountered in divide iou = total_area_intersect / total_area_union
如果你希望跳过这些警告并继续执行后续步骤,可以跳过此单元。
import evaluate
evaluate.logging.set_verbosity_error()
import torch
from torch import nn
import multiprocessing
metric = evaluate.load("mean_iou")
def compute_metrics(eval_pred):
with torch.no_grad():
logits, labels = eval_pred
logits_tensor = torch.from_numpy(logits)
# scale the logits to the size of the label
logits_tensor = nn.functional.interpolate(
logits_tensor,
size=labels.shape[-2:],
mode="bilinear",
align_corners=False,
).argmax(dim=1)
# currently using _compute instead of compute: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
pred_labels = logits_tensor.detach().cpu().numpy()
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
metrics = metric._compute(
predictions=pred_labels,
references=labels,
num_labels=len(id2label),
ignore_index=0,
reduce_labels=image_processor.do_reduce_labels,
)
# add per category metrics as individual key-value pairs
per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
per_category_iou = metrics.pop("per_category_iou").tolist()
metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
return metrics
9. 在我们的数据集上训练模型 🏋
现在是时候在我们的自定义数据集上训练模型了。我们将使用已经准备好的训练参数,并通过连接的 Weights & Biases 集成来监控训练过程,并根据需要进行调整。让我们开始训练,并观察模型如何提高其性能!
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=test_ds,
tokenizer=image_processor,
compute_metrics=compute_metrics,
)
trainer.train()
10. 在新图像上评估模型性能 📸
训练完成后,我们将评估模型在新图像上的表现。我们将使用一张测试图像,并利用 pipeline 来评估模型在未见过的数据上的表现。
import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw
url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
image = Image.open(requests.get(url, stream=True).raw)
image_segmentator = pipeline(
"image-segmentation", model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned" # Change with your model name
)
results = image_segmentator(image)
>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()
模型已经生成了一些掩码,我们可以可视化这些掩码以评估和理解其性能。这将帮助我们查看模型在图像分割上的表现如何,并识别需要改进的地方。
>>> image_array = np.array(image)
>>> segmentation_map = np.zeros_like(image_array)
>>> for result in results:
... mask = np.array(result["mask"])
... label = result["label"]
... label_index = list(id2label.values()).index(label)
... color = sidewalk_palette[label_index]
... for c in range(3):
... segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])
>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()
11. 在测试集上评估性能 📊
metrics = trainer.evaluate(test_ds)
print(metrics)
12. 使用推理 API 访问模型并可视化结果 🔌
Hugging Face 🤗 提供了一个 无服务器推理 API,允许你通过 API 端点直接测试模型,且可以免费使用。有关如何使用该 API 的详细指导,请参阅这个 指南。
我们将利用这个 API 来探索其功能,并看看它如何帮助我们测试模型。
重要提示
在使用无服务器推理 API 之前,你需要通过创建模型卡来设置模型任务。在为你的微调模型创建模型卡时,确保正确指定任务。
一旦模型任务设置完成,我们可以下载一张图片并使用 InferenceClient 来测试模型。这个客户端将允许我们通过 API 将图片发送到模型,并获取结果以进行评估。
>>> url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()
我们将使用 InferenceClient 的 image_segmentation 方法。该方法接受模型和图片作为输入,并返回预测的掩码。这将帮助我们测试模型在新图片上的表现。
from huggingface_hub import InferenceClient
client = InferenceClient()
response = client.image_segmentation(
model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned", # Change with your model name
image="https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)
print(response)
通过预测的掩码,我们可以展示结果。
>>> image_array = np.array(image)
>>> segmentation_map = np.zeros_like(image_array)
>>> for result in response:
... mask = np.array(result["mask"])
... label = result["label"]
... label_index = list(id2label.values()).index(label)
... color = sidewalk_palette[label_index]
... for c in range(3):
... segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])
>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()
也可以使用 JavaScript 版 Inference API。下面是如何使用 JavaScript 调用该 API 的示例:
import { HfInference } from "@huggingface/inference";
const inference = new HfInference(HF_TOKEN);
await inference.imageSegmentation({
data: await (await fetch("https://picsum.photos/300/300")).blob(),
model: "sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
});
这个示例展示了如何用 JavaScript 获取图片并将其传递给指定的模型进行图像分割任务。
额外加分
你还可以通过 Hugging Face Space 部署微调后的模型。例如,我创建了一个自定义 Space 来展示这个过程:Semantic Segmentation with SegFormer Fine-Tuned on Segments/Sidewalk。

from IPython.display import IFrame
IFrame(src="https://sergiopaniego-segformer-b0-segments-sidewalk-finetuned.hf.space", width=1000, height=800)
总结
在本指南中,我们成功地在自定义数据集上微调了一个语义分割模型,并利用无服务器推理 API 进行了测试。这展示了如何轻松地将模型集成到各种应用中,并利用 Hugging Face 工具进行部署。
希望本指南能够为你提供所需的工具和知识,使你能够自信地微调和部署自己的模型!🚀
< > Update on GitHub