自然言語処理

RNN(再帰型ニューラルネットワーク) あーるえぬえぬ(さいきがたにゅーらるねっとわーく)

RNN再帰型ニューラルネットワーク系列データ時系列隠れ状態自然言語処理
RNN(再帰型ニューラルネットワーク)について教えて

簡単に言うとこんな感じ!

RNNは「前の情報を記憶しながら次の入力を処理する」ニューラルネットワークだよ。文章の「私は昨日___を食べた」のように、前の文脈を踏まえて次を予測するのが得意!人間が文章を読む「文脈を積み上げながら理解する」動きをモデル化したものなんだ!


RNN(再帰型ニューラルネットワーク)とは

RNN(Recurrent Neural Network:再帰型ニューラルネットワーク) は、系列データ(テキスト・音声・時系列データ)を処理するために設計されたニューラルネットワークです。通常のニューラルネットワークと異なり、前のステップの出力(隠れ状態 h)を現在のステップの入力と合わせて使うことで、過去の情報を「記憶」しながら処理します。

テキスト処理では「猫が魚を食べた」という文章を1単語ずつ読み込み、各ステップで前の単語の文脈を保持しながら次の単語を予測します。音声認識では音声波形を時系列として処理し、時系列予測では過去のデータを記憶しながら次の値を予測します。

ただし、長い系列では遠い過去の情報が薄れる勾配消失問題があり、これを解決するためにLSTMGRUが開発されました。現在では多くのNLPタスクでTransformerに置き換えられていますが、時系列や組み込みデバイスでの利用はまだ多いです。


RNNの計算ステップ

時刻 t での計算:
  h_t = tanh(W_h × h_{t-1} + W_x × x_t + b)
  y_t = W_y × h_t + b_y

  h_t     : 現在の隠れ状態(記憶)
  h_{t-1} : 前の隠れ状態(過去の記憶)
  x_t     : 現在の入力
  y_t     : 現在の出力
import torch.nn as nn

# PyTorchのRNN
rnn = nn.RNN(
    input_size=100,   # 入力特徴量の次元数
    hidden_size=256,  # 隠れ状態の次元数
    num_layers=2,     # RNN層の積み重ね数
    batch_first=True  # (batch, seq, feature)の形式
)

# output: 各ステップの出力, h_n: 最終隠れ状態
output, h_n = rnn(x)

歴史と背景

  • 1986年:Rumelhart・Hintonがバックプロパゲーションを提案。RNNの学習が可能に
  • 1991年:勾配消失問題が報告される。長い系列の学習が困難と判明
  • 1997年:LSTMが提案され、勾配消失問題に対処
  • 2014年:GRUがよりシンプルな代替手法として提案
  • 2017年:Transformerの登場でRNN系はNLPの主役を譲るが、時系列での利用は継続

RNNのバリアント

モデル特徴向いている用途
Vanilla RNNシンプルだが勾配消失問題あり短い系列の学習
LSTMゲート機構で長期依存関係を保持長文・音声認識
GRULSTMを簡略化。速いLSTMと同用途で高速版
双方向RNN過去と未来の両方向から処理感情分析・固有表現認識

関連用語