残差接続(Residual Connection) ざんさせつぞく(れじじゅあるこねくしょん)
残差接続Residual ConnectionResNet勾配消失スキップ接続深層学習
残差接続(Residual Connection)について教えて
簡単に言うとこんな感じ!
層を通った後の出力に「元の入力をそのまま足す」のが残差接続だよ。「変化分(残差)だけ学習すればいい」から学習しやすくなるんだ。おかげで100層以上の超深いネットワークが学習できるようになったよ!
残差接続(Residual Connection)とは
残差接続(Residual Connection) とは、ある層の入力 x を、その層を通った後の出力 F(x) に足し合わせる構造です。式で表すと 出力 = F(x) + x となります。この「+x」の部分が「残差(residual)」を学習するという意味の名前の由来です。
2015年にKaiming Heらが提案したResNetで採用され、深層学習に革命をもたらしました。残差接続なしでは20〜30層を超えると勾配消失により学習が困難になりますが、残差接続を使うことで100層・1000層といった超深いネットワークでも安定して学習できるようになりました。
また残差接続は勾配を直接浅い層に伝達する「高速道路」の役割も果たします。深い層から浅い層へ勾配が短絡的に伝わるため、勾配消失問題が大幅に緩和されます。現在ではResNetに限らず、Transformer・BERT・GPTなどすべての主要な深層学習モデルで残差接続が使われています。
残差接続の仕組み
入力 x ─────────────────────────┐
│ │(ショートカット)
↓ │
[畳み込み層 / 全結合層] │
↓ │
[バッチ正規化 / 層正規化] │
↓ │
[活性化関数(ReLU等)] │
↓ │
F(x) ←──────────────────┘
↓(足し算)
出力 H(x) = F(x) + x
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU()
def forward(self, x):
residual = x # 入力を保存
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = out + residual # 残差接続(足し算)
return self.relu(out)
歴史と背景
- 2015年:HeらがResNetを発表。152層のネットワークでImageNetを制覇(従来の最高記録を大幅更新)
- 2017年:Transformerが残差接続を採用(Add & Norm構造)
- 2018年以降:BERT・GPTなどLLMのTransformerブロックに残差接続が不可欠な要素として組み込まれる
- 現在:深層学習の設計原則として、ほぼすべての大規模モデルで採用
残差接続の効果
| 効果 | 内容 |
|---|---|
| 勾配消失の防止 | ショートカットで勾配が直接浅い層に届く |
| 恒等写像の学習 | F(x)=0 を学習すれば入力をそのまま通せる |
| 深層化が可能 | 100層以上のネットでも安定学習 |
| アンサンブル効果 | 異なる深さの部分ネットワークの集合として機能 |
関連用語
- スキップ接続 — 残差接続を含む広義の「スキップ」概念
- ResNet — 残差接続を採用して深層学習を革新したモデル
- 勾配消失 / 勾配爆発 — 残差接続が解決する問題
- 層正規化(Layer Normalization) — Transformerで残差接続と組み合わせる
- 自己注意機構(Self-Attention) — 残差接続でつながれたTransformerの中核