グループ正規化 ぐるーぷせいきか
グループ正規化Group Normalizationバッチ正規化正規化手法CNN画像認識
グループ正規化について教えて
グループ正規化とは
グループ正規化(Group Normalization) は2018年にKaiming He(ResNetの開発者)らが提案した正規化手法です。チャンネル(特徴マップの数)を複数のグループに分割し、各グループ内の値の平均と分散を使って正規化します。
バッチ正規化(Batch Normalization)はバッチサイズが小さいと推定する統計量の精度が落ち、学習が不安定になります。例えば物体検出や画像セグメンテーションでは高解像度処理のためにバッチサイズを2〜4と小さくせざるを得ない場面があります。グループ正規化はバッチサイズに依存しないため、このような状況で安定した性能を発揮します。
グループ数を特徴量数と同じにすると「インスタンス正規化」、グループ数を1にすると「層正規化」と等価になる、正規化手法の中間的な位置付けです。
正規化手法の分類
| 手法 | 正規化の単位 | バッチサイズ依存 | 主な用途 |
|---|---|---|---|
| バッチ正規化 | バッチ×特徴量ごと | あり(大バッチ必須) | CNN・画像分類 |
| 層正規化 | サンプル内全特徴量 | なし | Transformer・LLM |
| インスタンス正規化 | サンプル×チャンネルごと | なし | スタイル変換 |
| グループ正規化 | サンプル×グループごと | なし | 物体検出・医療画像 |
import torch.nn as nn
# グループ正規化の使い方
# num_groups: グループ数(チャンネル数の約数を指定)
# num_channels: チャンネル数
group_norm = nn.GroupNorm(num_groups=32, num_channels=256)
# ResNetブロックでの使用例(バッチサイズが小さい物体検出向け)
class SmallBatchResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.gn1 = nn.GroupNorm(32, channels) # BNの代わりにGN
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.gn2 = nn.GroupNorm(32, channels)
self.relu = nn.ReLU()
def forward(self, x):
out = self.relu(self.gn1(self.conv1(x)))
out = self.gn2(self.conv2(out))
return self.relu(out + x)
歴史と背景
- 2015年:バッチ正規化が発表されるが、物体検出など大バッチが使えない場面で問題に
- 2018年:Heら「Group Normalization」を発表。Faster R-CNN・Mask R-CNNなどへの適用で効果を確認
- 現在:物体検出・インスタンスセグメンテーション・医療画像処理でバッチ正規化の代替として定着
グループ数の選び方
| チャンネル数 | 推奨グループ数 | 1グループあたりのチャンネル数 |
|---|---|---|
| 64 | 32 | 2 |
| 128 | 32 | 4 |
| 256 | 32 | 8 |
| 512 | 32 | 16 |
一般にグループ数は32が推奨されています(Heらの論文より)。ただし小さいモデルではグループ数を小さくして調整します。
関連用語
- 層正規化(Layer Normalization) — グループ数=1の場合と等価な正規化手法
- CNN(畳み込みニューラルネットワーク) — グループ正規化をよく使うモデル
- Faster R-CNN — グループ正規化を採用した物体検出モデル
- Mask R-CNN — グループ正規化でバッチサイズ問題を解決したモデル