[AI] [Python] Computer Vision Image Classification Experiment (Swin-T)

Using the PyTorch framework and the Swin Transformer model, I carried out an image classification experiment on the Caltech101 dataset. To address the limitations of ViT, the model relies on shifted-window self-attention and hierarchical feature maps, and training combines a frozen-backbone stage with full fine-tuning. After comparative tuning of the batch size and learning rate, the best hyperparameter combination was identified; the final model reached 99.02% accuracy on the validation set, achieving high-precision image classification and inference.

Experiment Objectives

1. Learn how to train an image classification model.
2. Learn how to use a model trained with a deep learning framework to classify images.

Experiment Principles

Traditional convolutional neural networks (CNNs) extract image features by stacking convolutional layers and have been enormously successful in image recognition. However, convolution is a local operation, which makes it hard for CNNs to capture global dependencies in an image; this becomes a bottleneck on tasks with complex contextual information. To address this, researchers brought the Transformer, originally developed for natural language processing, into computer vision, with the Vision Transformer (ViT) as a representative example. ViT splits an image into fixed-size patches and feeds them as a sequence into a standard Transformer encoder, using self-attention to learn global features.

Although ViT excels at global modeling, it has drawbacks. First, it lacks the inductive biases built into CNNs (such as locality and translation invariance), so it performs poorly when training data is limited. Second, the cost of global self-attention over the whole image grows quadratically with image size, which makes it hard to apply to high-resolution images.
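
The quadratic-cost claim can be made precise. For an h×w feature map with C channels and window size M, the Swin Transformer paper gives the complexities of global multi-head self-attention (MSA) and window-based self-attention (W-MSA) as:

```latex
\Omega(\mathrm{MSA})            = 4hwC^2 + 2(hw)^2 C
\Omega(\mathrm{W\text{-}MSA})   = 4hwC^2 + 2M^2 hw C
```

The first is quadratic in the number of tokens hw, while the second is linear in hw for a fixed window size (M = 7 by default), which is what makes high-resolution inputs tractable.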

The Swin Transformer model used in this experiment was proposed to solve exactly these problems. It cleverly combines the strengths of CNNs and Transformers; its core innovations are hierarchical feature maps and shifted-window self-attention (Shifted Windows Self-Attention, W-MSA & SW-MSA).


Experiment Environment

Basic system configuration:

CPU: 12th Gen Intel(R) Core(TM) i7-12650H
GPU: NVIDIA GeForce RTX 4060 Laptop
Memory: 32 GB
Network bandwidth: 100 Mbps
OS: Windows 11

Basic software environment:

CUDA 12.9
Python 3.9
PyTorch 2.8

Experiment Content and Steps

Dataset

Dataset download link for this experiment: caltech101数据集16类_数据集-飞桨AI Studio星河社区. The full Caltech101 dataset contains 101 categories (e.g., ak47, calculator, cannon), each with roughly 40-800 images; this experiment uses the 16-class subset from the link above.

Model Introduction

In this experiment I use the Swin Transformer for the classification task. The paper was published at ICCV 2021 (International Conference on Computer Vision); the authors' affiliations include Microsoft Research Asia, the University of Science and Technology of China, Xi'an Jiaotong University, and Tsinghua University.


Although existing Vision Transformer (ViT) models perform well on image classification, they are inefficient, and often less effective, on high-resolution images and dense vision tasks. The goal of the Swin Transformer paper is to design a new Transformer architecture that can handle high-resolution inputs while performing well across a wide range of vision tasks.

The paper's core innovations are the shifted-window self-attention mechanism and hierarchical feature representation. Shifted-window self-attention moves the window partition between consecutive self-attention layers; within each window, all query patches share the same key set, which improves computational efficiency, while the shift prevents the information loss that would result from always computing attention independently inside isolated windows, as illustrated in the figure below.

[Figure 3]

Hierarchical feature representation splits the input image into non-overlapping patches, applies a linear embedding to each patch, transforms the features through multiple Transformer blocks, and progressively merges patches to reduce the number of tokens, ultimately yielding a hierarchical feature representation.

The overall architecture of the Swin Transformer is shown in the figure below, using the Swin-T variant employed in this experiment as an example. The model first partitions the input image into non-overlapping 4×4 patches; each patch is flattened into a 48-dimensional vector (4×4×3) and then mapped to dimension C by a linear embedding layer, producing a sequence of patch tokens.

[Figure 4]

Next, the model stacks several modified Transformer blocks (the Swin Transformer Blocks in the figure) on these tokens while keeping their number unchanged; this forms Stage 1. To build hierarchical features, a patch merging layer progressively reduces the resolution between stages: the features of each 2×2 group of neighboring patches are concatenated and passed through a linear transform, halving the resolution and doubling the channel count.

This process yields four stages in sequence, with feature maps at spatial resolutions H/4, H/8, H/16, and H/32, corresponding to Stage 1 through Stage 4. The model thus obtains a hierarchical representation similar to typical convolutional networks (e.g., VGG, ResNet) and can serve as a drop-in, general-purpose backbone for many vision tasks.
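
For the 224×224 inputs used here, these shapes can be verified with a few lines of arithmetic (a quick sanity check, not part of the training code):

```python
img_size, patch_size, embed_dim = 224, 4, 96  # Swin-T defaults

patch_dim = patch_size * patch_size * 3   # each 4x4 RGB patch flattens to 48 values
tokens_side = img_size // patch_size      # 56 tokens per side after patch embedding (H/4)

# (resolution, channels) per stage: resolution halves and channels double at each patch merge
stages = [(tokens_side // 2 ** i, embed_dim * 2 ** i) for i in range(4)]
print(patch_dim)  # 48
print(stages)     # [(56, 96), (28, 192), (14, 384), (7, 768)]
```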

In this experiment, Swin Transformer's two core innovations bring clear advantages. Shifted-window self-attention balances global modeling ability against computational cost by computing attention within local windows and connecting them with window shifts, which suits high-resolution images on small-to-medium datasets. Hierarchical feature representation builds multi-scale features through layer-by-layer patch merging, letting the model capture fine-grained local information while retaining global semantic structure, and therefore discriminate better between classes with complex backgrounds and large appearance variation.

Model Construction

Using the PyTorch framework and the timm library, the Swin Transformer-Tiny (Swin-T) model can be created with pretrained weights. The Swin Transformer implementation is shown in the code below:

import math
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_

_int_or_tuple_2_t = Union[int, Tuple[int, int]]

def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
    # split (B, H, W, C) feature maps into non-overlapping windows: (num_windows * B, win_h, win_w, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows

def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
    # inverse of window_partition: merge windows back into (B, H, W, C) feature maps
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
    # index into the relative position bias table for every pair of token positions in a window
    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing='ij'))
    coords_flatten = torch.flatten(coords, 1)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += win_h - 1
    relative_coords[:, :, 1] += win_w - 1
    relative_coords[:, :, 0] *= 2 * win_w - 1
    return relative_coords.sum(-1)

class WindowAttention(nn.Module):
    # window-based multi-head self-attention (W-MSA) with learned relative position bias
    def __init__(self, dim: int, num_heads: int, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, qkv_bias: bool = True, attn_drop: float = 0., proj_drop: float = 0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_2tuple(window_size)
        win_h, win_w = self.window_size
        self.window_area = win_h * win_w
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def _get_rel_pos_bias(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn + self._get_rel_pos_bias()
        if mask is not None:
            num_win = mask.shape[0]
            attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock(nn.Module):
    # Swin block: (shifted-)window attention followed by an MLP, each with LayerNorm and a residual connection
    def __init__(self, dim: int, input_resolution: _int_or_tuple_2_t, num_heads: int, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, shift_size: int = 0, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = to_2tuple(window_size)
        self.shift_size = to_2tuple(shift_size)
        self.window_area = self.window_size[0] * self.window_size[1]
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, num_heads=num_heads, head_dim=head_dim, window_size=self.window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if any(self.shift_size):
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            cnt = 0
            for h in ((0, -self.window_size[0]), (-self.window_size[0], -self.shift_size[0]), (-self.shift_size[0], None)):
                for w in ((0, -self.window_size[1]), (-self.window_size[1], -self.shift_size[1]), (-self.shift_size[1], None)):
                    img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_area)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def _attn(self, x):
        B, H, W, C = x.shape
        has_shift = any(self.shift_size)
        if has_shift:
            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
        else:
            shifted_x = x
        pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
        pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
        shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
        _, Hp, Wp, _ = shifted_x.shape
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_area, C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        shifted_x = shifted_x[:, :H, :W, :].contiguous()
        if has_shift:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
        else:
            x = shifted_x
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, C = x.shape
        x = x + self.drop_path1(self._attn(self.norm1(x)))
        x = x.reshape(B, -1, C)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        x = x.reshape(B, H, W, C)
        return x

class PatchMerging(nn.Module):
    # concatenate each 2x2 patch neighborhood and project: halve resolution, double channels
    def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, C = x.shape
        pad_values = (0, 0, 0, W % 2, 0, H % 2)
        x = nn.functional.pad(x, pad_values)
        _, H, W, _ = x.shape
        x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
        x = self.norm(x)
        x = self.reduction(x)
        return x

class SwinTransformerStage(nn.Module):
    # one stage: optional patch merging followed by a sequence of Swin blocks (alternating W-MSA / SW-MSA)
    def __init__(self, dim: int, out_dim: int, input_resolution: Tuple[int, int], depth: int, downsample: bool = True, num_heads: int = 4, head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., attn_drop: float = 0., drop_path: Union[List[float], float] = 0., norm_layer: Callable = nn.LayerNorm):
        super().__init__()
        self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
        window_size = to_2tuple(window_size)
        shift_size = tuple([w // 2 for w in window_size])
        if downsample:
            self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer)
        else:
            self.downsample = nn.Identity()
        self.blocks = nn.Sequential(*[
            SwinTransformerBlock(dim=out_dim, input_resolution=self.output_resolution, num_heads=num_heads, head_dim=head_dim, window_size=window_size, shift_size=0 if (i % 2 == 0) else shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.downsample(x)
        x = self.blocks(x)
        return x

class SwinTransformer(nn.Module):
    # four-stage hierarchical backbone with a global-average-pooled classification head
    def __init__(self, img_size: _int_or_tuple_2_t = 224, patch_size: int = 4, in_chans: int = 3, num_classes: int = 1000, embed_dim: int = 96, depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] = (3, 6, 12, 24), head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1, embed_layer: Callable = PatchEmbed, norm_layer: Union[str, Callable] = nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim[0], norm_layer=norm_layer, output_fmt='NHWC')
        patch_grid = self.patch_embed.grid_size
        head_dim = to_ntuple(self.num_layers)(head_dim)
        window_size = to_ntuple(self.num_layers)(window_size)
        mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        layers = []
        in_dim = embed_dim[0]
        scale = 1
        for i in range(self.num_layers):
            out_dim = embed_dim[i]
            layers += [SwinTransformerStage(dim=in_dim, out_dim=out_dim, input_resolution=(patch_grid[0] // scale, patch_grid[1] // scale), depth=depths[i], downsample=i > 0, num_heads=num_heads[i], head_dim=head_dim[i], window_size=window_size[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)]
            in_dim = out_dim
            if i > 0:
                scale *= 2
        self.layers = nn.Sequential(*layers)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self.layers(x)
        x = self.norm(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = x.flatten(1, 2)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

def swin_tiny_patch4_window7_224(**kwargs) -> SwinTransformer:
    """ Swin-T @ 224x224, ImageNet-1k """
    model_kwargs = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
    model = SwinTransformer(**model_kwargs)
    return model
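
To make the relative-position indexing in the listing above more concrete, here is a pure-Python restatement of the `get_relative_position_index` helper (an illustrative re-implementation for checking the math, not used by the model):

```python
def relative_position_index(win_h, win_w):
    # same computation as get_relative_position_index, without tensors:
    # each pair of window positions maps to one row of the (2*win_h-1)*(2*win_w-1) bias table
    coords = [(i, j) for i in range(win_h) for j in range(win_w)]
    index = []
    for i1, j1 in coords:
        row = []
        for i2, j2 in coords:
            dh = i1 - i2 + win_h - 1   # shift vertical offset into [0, 2*win_h-2]
            dw = j1 - j2 + win_w - 1   # shift horizontal offset into [0, 2*win_w-2]
            row.append(dh * (2 * win_w - 1) + dw)
        index.append(row)
    return index

idx = relative_position_index(7, 7)          # default 7x7 window
print(len(idx), max(max(r) for r in idx))    # 49 168 -> the bias table needs 169 rows
```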

Training Configuration

The training in this experiment is based on the PyTorch deep learning framework, with the pretrained Swin Transformer loaded via the timm library. Training proceeds in two phases: first the backbone is frozen and only the classification head is trained, then the whole network is unfrozen and fine-tuned. The hyperparameters are as follows: input image size 224×224, batch size 16, 35 training epochs in total (5 epochs of frozen-backbone training followed by 30 epochs of full fine-tuning). The optimizer is AdamW; during the frozen phase the learning rate is 1e-4 with weight decay 1e-4, and during full fine-tuning the learning rate drops to 5e-5 (0.5× the initial rate) with weight decay 1e-5. The loss function is cross-entropy (CrossEntropyLoss), and the random seed is fixed at 42 for reproducibility. GPU acceleration is used when available. The configuration code is shown below.

parser = argparse.ArgumentParser()
parser.add_argument('--data-root', default='datasets/caltech101')
parser.add_argument('--model-name', default='swin_tiny_patch4_window7_224')
parser.add_argument('--img-size', type=int, default=224)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=30)        # fine-tuning epochs (after the freeze phase)
parser.add_argument('--lr', type=float, default=1e-4)        # initial learning rate for the freeze phase
parser.add_argument('--freeze-epochs', type=int, default=5)  # frozen-backbone epochs, 35 epochs in total
parser.add_argument('--save-dir', type=str, default='./checkpoints')
args = parser.parse_args()

Model Training

The training steps are as follows:

(1) Model instantiation

# Model with pretrained weights
model = timm.create_model(args.model_name, pretrained=True, num_classes=num_classes)
model.to(device)

# Freeze the backbone (train only the head)
for name, p in model.named_parameters():
    if 'head' not in name.lower():
        p.requires_grad = False
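
The freezing rule can be isolated and sanity-checked without building the model (a pure-Python restatement; the parameter names below are made up for illustration):

```python
def trainable_names(named_params, keyword='head'):
    # parameters whose name contains the keyword stay trainable; everything else is frozen
    return [name for name, _ in named_params if keyword in name.lower()]

fake_params = [  # hypothetical Swin-T parameter names
    ('patch_embed.proj.weight', None),
    ('layers.0.blocks.0.attn.qkv.weight', None),
    ('head.weight', None),
    ('head.bias', None),
]
print(trainable_names(fake_params))  # ['head.weight', 'head.bias']
```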

(2) Configure the loss function

criterion = nn.CrossEntropyLoss()

(3) Configure the optimizer

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=1e-4)

(4) Start training. A log line is printed after each epoch, and only the weights with the best validation accuracy are saved. The log is shown below.

D:\AnacondaEnvs\pytorch\python.exe D:\桌面\编程相关\04_IDE练习项目\PyTorch\Computer-Vision\CV1\train.py
using device: cuda
Found 1567 training images, 205 val images, 0 test images. num_classes=16

[Head] Epoch 1/5 | time 50.8s | train_loss=2.4672 train_acc=0.3854 | val_loss=2.0683 val_acc=0.7268 | lr=1.00e-04
[Head] Epoch 2/5 | time 52.1s | train_loss=1.8048 train_acc=0.8003 | val_loss=1.4663 val_acc=0.8878 | lr=1.00e-04
[Head] Epoch 3/5 | time 51.8s | train_loss=1.3217 train_acc=0.8941 | val_loss=1.0550 val_acc=0.9171 | lr=1.00e-04
[Head] Epoch 4/5 | time 52.8s | train_loss=1.0018 train_acc=0.9196 | val_loss=0.7898 val_acc=0.9366 | lr=1.00e-04
[Head] Epoch 5/5 | time 53.5s | train_loss=0.7803 train_acc=0.9355 | val_loss=0.6165 val_acc=0.9366 | lr=1.00e-04

[Fine] Epoch 6/35 | time 59.1s | train_loss=0.2609 train_acc=0.9311 | val_loss=0.1569 val_acc=0.9463 | lr=5.00e-05
[Fine] Epoch 7/35 | time 56.4s | train_loss=0.0908 train_acc=0.9751 | val_loss=0.0923 val_acc=0.9707 | lr=5.00e-05
[Fine] Epoch 8/35 | time 57.6s | train_loss=0.0499 train_acc=0.9892 | val_loss=0.1051 val_acc=0.9659 | lr=5.00e-05
[Fine] Epoch 9/35 | time 55.5s | train_loss=0.0414 train_acc=0.9911 | val_loss=0.0927 val_acc=0.9512 | lr=5.00e-05
[Fine] Epoch 10/35 | time 55.2s | train_loss=0.0504 train_acc=0.9860 | val_loss=0.0531 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 11/35 | time 56.6s | train_loss=0.0253 train_acc=0.9943 | val_loss=0.0636 val_acc=0.9707 | lr=5.00e-05
[Fine] Epoch 12/35 | time 58.1s | train_loss=0.0244 train_acc=0.9943 | val_loss=0.0545 val_acc=0.9854 | lr=5.00e-05
[Fine] Epoch 13/35 | time 56.0s | train_loss=0.0084 train_acc=0.9981 | val_loss=0.0514 val_acc=0.9756 | lr=5.00e-05
[Fine] Epoch 14/35 | time 56.7s | train_loss=0.0285 train_acc=0.9923 | val_loss=0.0774 val_acc=0.9707 | lr=5.00e-05
[Fine] Epoch 15/35 | time 55.7s | train_loss=0.0418 train_acc=0.9892 | val_loss=0.0647 val_acc=0.9756 | lr=5.00e-05
[Fine] Epoch 16/35 | time 56.6s | train_loss=0.0225 train_acc=0.9962 | val_loss=0.0529 val_acc=0.9854 | lr=5.00e-05
[Fine] Epoch 17/35 | time 59.7s | train_loss=0.0283 train_acc=0.9930 | val_loss=0.1446 val_acc=0.9561 | lr=5.00e-05
[Fine] Epoch 18/35 | time 58.6s | train_loss=0.0126 train_acc=0.9987 | val_loss=0.0312 val_acc=0.9854 | lr=5.00e-05
[Fine] Epoch 19/35 | time 57.6s | train_loss=0.0170 train_acc=0.9955 | val_loss=0.0678 val_acc=0.9707 | lr=5.00e-05
[Fine] Epoch 20/35 | time 60.3s | train_loss=0.0135 train_acc=0.9974 | val_loss=0.0586 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 21/35 | time 58.6s | train_loss=0.0076 train_acc=0.9987 | val_loss=0.0701 val_acc=0.9756 | lr=5.00e-05
[Fine] Epoch 22/35 | time 60.8s | train_loss=0.0209 train_acc=0.9930 | val_loss=0.0522 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 23/35 | time 58.0s | train_loss=0.0278 train_acc=0.9923 | val_loss=0.3066 val_acc=0.9268 | lr=5.00e-05
[Fine] Epoch 24/35 | time 61.0s | train_loss=0.0903 train_acc=0.9668 | val_loss=0.0990 val_acc=0.9610 | lr=5.00e-05
[Fine] Epoch 25/35 | time 60.6s | train_loss=0.0236 train_acc=0.9923 | val_loss=0.0782 val_acc=0.9659 | lr=5.00e-05
[Fine] Epoch 26/35 | time 53.4s | train_loss=0.0133 train_acc=0.9949 | val_loss=0.0606 val_acc=0.9659 | lr=5.00e-05
[Fine] Epoch 27/35 | time 56.7s | train_loss=0.0264 train_acc=0.9917 | val_loss=0.1247 val_acc=0.9756 | lr=5.00e-05
[Fine] Epoch 28/35 | time 57.2s | train_loss=0.0069 train_acc=0.9987 | val_loss=0.0673 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 29/35 | time 61.7s | train_loss=0.0068 train_acc=0.9981 | val_loss=0.0508 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 30/35 | time 61.2s | train_loss=0.0053 train_acc=0.9987 | val_loss=0.0471 val_acc=0.9902 | lr=5.00e-05
[Fine] Epoch 31/35 | time 61.0s | train_loss=0.0080 train_acc=0.9974 | val_loss=0.0859 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 32/35 | time 60.2s | train_loss=0.0098 train_acc=0.9981 | val_loss=0.0614 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 33/35 | time 55.4s | train_loss=0.0195 train_acc=0.9962 | val_loss=0.0722 val_acc=0.9854 | lr=5.00e-05
[Fine] Epoch 34/35 | time 57.9s | train_loss=0.0273 train_acc=0.9936 | val_loss=0.1053 val_acc=0.9805 | lr=5.00e-05
[Fine] Epoch 35/35 | time 56.8s | train_loss=0.0104 train_acc=0.9968 | val_loss=0.0503 val_acc=0.9805 | lr=5.00e-05

(5) The loss and accuracy curves after training are shown below.

[Figure 5]

The plotting code is as follows:

# Plot train/val loss and accuracy curves
# (this snippet runs inside the training function, hence the early return)
try:
    import matplotlib
    matplotlib.use('Agg')  # headless backend; no display environment required
    import matplotlib.pyplot as plt
except ImportError:
    print("matplotlib is required to plot training curves. Install with: pip install matplotlib")
    return

epochs_range = list(range(1, len(train_losses) + 1))
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(epochs_range, train_losses, label='train_loss', linestyle='-')
axes[0].plot(epochs_range, val_losses, label='val_loss', linestyle='--')
axes[0].set_title('Loss')
axes[0].set_xlabel('epoch')
axes[0].set_ylabel('loss')
axes[0].legend()

axes[1].plot(epochs_range, train_accs, label='train_acc', linestyle='-')
axes[1].plot(epochs_range, val_accs, label='val_acc', linestyle='--')
axes[1].set_title('Accuracy')
axes[1].set_xlabel('epoch')
axes[1].set_ylabel('accuracy')
axes[1].legend()

plt.tight_layout()
plot_path = os.path.join(args.save_dir, 'train_val_curve.png')
plt.savefig(plot_path)
print(f"Saved training curves to: {plot_path}")

Model Optimization

In deep learning, hyperparameter choices play a crucial role in both final model performance and training efficiency. To find a better configuration for the Swin Transformer model, I took the run with batch_size 16 and initial learning rate 1e-4 as the baseline and ran a series of comparison experiments on these two key hyperparameters. The baseline run reached 99.02% validation accuracy after training, the highest among all runs.

1. Batch size tuning

First, I explored the effect of batch size on model performance. The batch size is the number of samples used in one training iteration; it directly affects the gradient update direction and training stability, and within a reasonable range a larger batch improves hardware parallelism. I tried batch sizes of 32, 64, and 128: with batch_size 32 the validation accuracy was 98.05%; with 64 it was 98.54%; and at 128 it fell back to 98.05%. These results show that, for this task, blindly increasing the batch size actually degraded performance. On the other hand, too small a batch introduces heavier gradient noise and hurts stable convergence: reducing the batch size to 8 gave a final validation accuracy of only 97.56%, clearly below the baseline. Among the batch-size experiments, batch_size 16 therefore performed best.
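
One mechanical consequence of the batch size is the number of gradient updates per epoch. With the 1567 training images reported in the log, the tested batch sizes give:

```python
import math

train_images = 1567  # from the training log

# gradient updates (iterations) per epoch for each tested batch size
iters_per_epoch = {bs: math.ceil(train_images / bs) for bs in (8, 16, 32, 64, 128)}
print(iters_per_epoch)  # {8: 196, 16: 98, 32: 49, 64: 25, 128: 13}
```

Smaller batches thus mean many more (noisier) updates per epoch, consistent with the instability observed at batch size 8.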

[Figure 5] Batch Size=16; Learning Rate=1e-4
[Figure 7] Batch Size=8; Learning Rate=1e-4
[Figure 8] Batch Size=32; Learning Rate=1e-4
[Figure 9] Batch Size=64; Learning Rate=1e-4
[Figure 10] Batch Size=128; Learning Rate=1e-4

2. Initial learning rate tuning

Next, fixing the best batch size at 16, I tuned the initial learning rate. The learning rate is the core hyperparameter controlling the step size of parameter updates; values that are too large or too small can both hurt convergence. The baseline used an initial learning rate of 1e-4. For comparison I tried smaller values: with 5e-5 the validation accuracy was 98.54%, and at 1e-5 it dropped to 97.56%. Lowering the initial learning rate thus reduced final performance, likely because too small a rate prevents the model from converging fully within the limited number of training epochs.

[Figure 11] Batch Size=16; Learning Rate=5e-5
[Figure 12] Batch Size=16; Learning Rate=1e-5

The results of all runs are summarized in the table below:

Experiment      Batch Size   Learning Rate   Val Acc   Val Loss
train6_bs8      8            1e-4            97.56%    0.1054
train1_bs16     16           1e-4            99.02%    0.0471
train3_bs32     32           1e-4            98.05%    0.0488
train4_bs64     64           1e-4            98.54%    0.0360
train5_bs128    128          1e-4            98.05%    0.0664
train8_lr5e-5   16           5e-5            98.54%    0.0563
train9_lr1e-5   16           1e-5            97.56%    0.1204

Putting all the experiments together, the combination of batch_size 16 and initial learning rate 1e-4 is the best configuration for this task, giving the highest validation accuracy of 99.02%. No other hyperparameter combination, whether a larger or smaller batch size or a lower learning rate, beat this baseline.

Model Evaluation

We evaluate the best model saved during training on the validation set: first load the model weights, then iterate over the validation set to predict and report the average accuracy.

# Evaluate the best model on the validation set
best_path = os.path.join(args.save_dir, 'best_full.pth')
if os.path.exists(best_path):
    model.load_state_dict(torch.load(best_path, map_location=device))
    vloss, vacc = evaluate(model, val_loader, criterion, device)
    print(f"\n[Best Model on Validation Set] Val Acc: {vacc:.4f} | Val Loss: {vloss:.4f}")
else:
    print("Warning: best_full.pth not found, skipping validation evaluation.")
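
The `evaluate` function called above is not reproduced in the original; a minimal sketch consistent with its call signature (model, loader, criterion, device → average loss, accuracy) might look like:

```python
import torch

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    # average the loss and accuracy over every batch of the validation loader
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total
```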

Best validation accuracy: 99.02%; validation loss: 0.0471.

[Best Model on Validation Set] Val Acc: 0.9902 | Val Loss: 0.0471
Saved training curves to: ./checkpoints\train_val_curve.png

Process finished with exit code 0

Model Inference Test

The test images are preprocessed with the same transforms used during training; the best saved model is then loaded, a single image is fed through the model for inference, and matplotlib displays the original image together with the predicted class and its confidence.

import torch
import timm
import torchvision.transforms as T
from PIL import Image
import argparse
import matplotlib.pyplot as plt
import os

# Fixed class mapping (index -> class name)
IDX2CLASS = {
    0: "ak47",
    1: "binoculars",
    2: "boom-box",
    3: "calculator",
    4: "cannon",
    5: "computer-keyboard",
    6: "computer-monitor",
    7: "computer-mouse",
    8: "doorknob",
    9: "dumb-bell",
    10: "flashlight",
    11: "head-phones",
    12: "joy-stick",
    13: "palm-pilot",
    14: "video-projector",
    15: "washing-machine"
}
def get_transforms(img_size=224):
    return T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])
@torch.no_grad()
def predict_single_image(model, image_path, transform, device, show=True):
    model.eval()
    img = Image.open(image_path).convert('RGB')
    tensor = transform(img).unsqueeze(0).to(device)
    with torch.amp.autocast(device.type, enabled=(device.type == 'cuda')):
        output = model(tensor)
        pred = torch.argmax(output, dim=1).item()
        conf = torch.softmax(output, dim=1)[0, pred].item()

    pred_name = IDX2CLASS.get(pred, f"Unknown({pred})")
    print(f"Predicted: {pred_name} | confidence={conf:.4f}")
    # Display the prediction result
    if show:
        plt.figure(figsize=(5, 5))
        plt.imshow(img)
        plt.axis('off')
        title = f"{pred_name} ({conf*100:.1f}%)"
        plt.title(title, fontsize=12, color='darkblue', pad=10)
        plt.tight_layout()
        plt.show()
    return pred_name, conf

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', default='swin_tiny_patch4_window7_224')
    parser.add_argument('--num-classes', type=int, default=16)
    parser.add_argument('--checkpoint', default='./checkpoints/best_full.pth')
    parser.add_argument('--image-path', default='./datasets/caltech101/Images/calculator/027_0002.jpg')
    parser.add_argument('--img-size', type=int, default=224)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model
    model = timm.create_model(args.model_name, pretrained=False, num_classes=args.num_classes)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)

    transform = get_transforms(args.img_size)
    predict_single_image(model, args.image_path, transform, device, show=True)

if __name__ == '__main__':
    main()
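
The confidence printed above is simply the softmax probability of the predicted class; in pure Python the computation looks like this (a standalone illustration with made-up logits):

```python
import math

def softmax(logits):
    # numerically stable softmax: subtract the max before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, 0.1]                              # made-up model outputs
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
print(pred, round(probs[pred], 4))  # 0 0.7285
```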

Below are predictions on four images I picked from the test set; all of them were classified correctly.

[Figure 13]
