ニューラルネットワーク

グループ正規化 ぐるーぷせいきか

グループ正規化Group Normalizationバッチ正規化正規化手法CNN画像認識
グループ正規化について教えて

簡単に言うとこんな感じ!

グループ正規化は、チャンネルをいくつかのグループに分けてグループ内で正規化する方法だよ。バッチサイズが小さくてもバッチ正規化のように安定して動いて、物体検出や医療画像など「大バッチが使えない場面」で活躍するんだ!


グループ正規化とは

グループ正規化(Group Normalization) は2018年にKaiming He(ResNetの開発者)らが提案した正規化手法です。チャンネル(特徴マップの数)を複数のグループに分割し、各グループ内の値の平均と分散を使って正規化します

バッチ正規化(Batch Normalization)はバッチサイズが小さいと推定する統計量の精度が落ち、学習が不安定になります。例えば物体検出や画像セグメンテーションでは高解像度処理のためにバッチサイズを2〜4と小さくせざるを得ない場面があります。グループ正規化はバッチサイズに依存しないため、このような状況で安定した性能を発揮します。

グループ数を特徴量数と同じにすると「インスタンス正規化」、グループ数を1にすると「層正規化」と等価になる、正規化手法の中間的な位置付けです。


正規化手法の分類

手法正規化の単位バッチサイズ依存主な用途
バッチ正規化バッチ×特徴量ごとあり(大バッチ必須)CNN・画像分類
層正規化サンプル内全特徴量なしTransformerLLM
インスタンス正規化サンプル×チャンネルごとなしスタイル変換
グループ正規化サンプル×グループごとなし物体検出・医療画像
import torch.nn as nn

# グループ正規化の使い方
# num_groups: グループ数(チャンネル数の約数を指定)
# num_channels: チャンネル数

group_norm = nn.GroupNorm(num_groups=32, num_channels=256)

# ResNetブロックでの使用例(バッチサイズが小さい物体検出向け)
class SmallBatchResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn1 = nn.GroupNorm(32, channels)  # BNの代わりにGN
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn2 = nn.GroupNorm(32, channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))
        return self.relu(out + x)

歴史と背景

  • 2015年:バッチ正規化が発表されるが、物体検出など大バッチが使えない場面で問題に
  • 2018年:Heら「Group Normalization」を発表。Faster R-CNNMask R-CNNなどへの適用で効果を確認
  • 現在:物体検出・インスタンスセグメンテーション・医療画像処理でバッチ正規化の代替として定着

グループ数の選び方

チャンネル数推奨グループ数1グループあたりのチャンネル数
64322
128324
256328
5123216

一般にグループ数は32が推奨されています(Heらの論文より)。ただし小さいモデルではグループ数を小さくして調整します。


関連用語