RNN(Recurrent Neural Network)
์ฐ์ ๋ ์ต์ํ CNN์์ ์ถ๋ฐํด๋ณด์. CNN์ input(๋ค)์ ์ด์ฉํด output์ ์์ธกํ๋๋ฐ, ๊ทธ ๊ณผ์ ์์ data๊ฐ ์ฌ์ฌ์ฉ๋์ง ์๋๋ค. ๋น์ฐํ๋ค. CNN์ input ํ๋๋ฅผ ํ๊บผ๋ฒ์ ๋ฃ์ด์ฃผ๊ธฐ ๋๋ฌธ์ด๋ค. ํ๋๋ฅผ ํ๊บผ๋ฒ์? ์ด๋ฏธ์ง ํ๋๋ฅผ ๋ฃ์ ๋ ๊ฐ๊ฐ์ ํฝ์ ์ ์์๋๋ก ๋ฃ์ง ์๊ณ ํ๋ฒ์ Convolution layer๋ฅผ ๋ง๋๊ฒ ํด๋ฒ๋ฆฌ๋ ์ผ์ ์์ํด๋ณด๋ฉด ๋๋ค. ๋ฌผ๋ก Convolution layer์ kernel size๋๋ฌธ์ ๋จผ์ ์ฝํ๋ ๋ถ๋ถ์ด ์กด์ฌํ์ง ์๋๋ ์ถ์ ์ ์์ง๋ง, ๊ทธ ์์๊ฐ ์ค์ํ๊ฐ? ์ ๋ ๊ทธ๋ ์ง ์๋ค. ์ด๋ฏธ์ง์์ locality๊ฐ ์ค์ํ ๊ฒ์ sequence๊ฐ ์ค์ํ ๊ฒ๊ณผ๋ ๋ค๋ฅธ ์๋ฏธ๋ค.
RNN์ sequence data(์๊ณ์ด data, ์์ฐ์ด ๋ฑ)๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ์ ํนํ๋์ด ์๋ neural network์ธ๋ฐ, ์ด๋ฆ ๊ทธ๋๋ก '์ํ'ํ๋ '์ํ์ ๊ฒฝ๋ง'์ด๋ค.
์ ์ฌ์ง์์ input sequence data ํ๋๋ [x0, x1, x2, ..., xt]์ด๋ค. RNN์ ๊ธฐ๋ณธ์ ์ผ๋ก input data๋ฅผ time step์ผ๋ก ๋๋์ด ์ง์ด๋ฃ๊ฒ ๋๋ค. ๊ทธ๋ฆฌ๊ณ h_t๋ ์๊ฐ t์์์ hidden state๋ค. RNN์ ์ด hidden state๋ฅผ ์ด์ฉํด ์ด์ time step์ ์ ๋ณด๋ฅผ ๋ค์ time step์์๋ ๊ฐ์ง๊ณ ์์ ์ ์๊ฒ ํ๋ ๊ฒ์ธ๋ฐ, ์ ๊ทธ๋ฆผ์์ '๋ฑํธ(=)'๋ฅผ ์ดํดํ๋ ๊ฒ์ด ์ค์ํ๋ค. RNN์ ๊ฐ์ฅ ๊ฐ๋จํ๊ฒ ํํํ๋ฉด ๋ฑํธ ์ข๋ณ์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์๋ฐ, ์ํํ๋ ํ์ดํ๋ ์ฌ์ค h_t๊ฐ ์ํํ๋ ๊ฒ์ด๋ค. h_t์ ์ํ์ ์๊ฐ์ถ์ผ๋ก ํผ์น๋ฉด ์ฐ๋ณ์ ๊ทธ๋ฆผ์ฒ๋ผ ํํ๋๋ ์
์ด๋ค. ์ฝ๊ฒ ๋งํด output์ ๋ด๋ณด๋ด๋ฉด์ ๋์์ ๋ค์ time์๋ ์ ๋ฌํ๋ ๊ฒ์ด๋ค. (์๋ฐํ output๊ณผ hidden state๋ ๋ค๋ฅผ ์๋ ์๋๋ฐ, ์ด๋ฅผ ์๋์์ ์์ธํ ์ดํด๋ณด์.)
๋์ฑ ์ฝ๊ฒ ์ดํดํ ์ ์๋ gif๋ฅผ ์ฐพ์๋ณด์๋ค. ์ฌ์ค hidden state๋ ์ฐ๋ฆฌ๊ฐ ํํ ์๋ MLP์ hidden layer์ ๋์๋๋ค.
f๋ activation function์ด๊ณ (์ผ๋ฐ์ ์ผ๋ก tanh๋ฅผ ์ด๋ค.), y_hat์ด ์ต์ข output prediction์ด๋ค. ์์์ผ๋ก ํ์ธํ ์ ์๋ฏ h_t๊ฐ output์ธ ๊ฒ์ ์๋๊ณ , h_t๋ฅผ ๋ค์์ํ์ ์ ๋ฌํ ๋ ๊ทธ๋๋ก ์ฌ์ฉํ๊ณ , output์ผ๋ก ๋ด๋ณด๋ผ ๋๋ h_t๋ฅผ ๊ทธ๋๋ก ์ฐ๊ฑฐ๋ ์ฌ๊ธฐ์ softmax๋ sigmoid๊ฐ์ activation function์ ๊ฑฐ์น๊ฒ ํ๋ค. ์๋ฆฌ๋ฅผ ๋ณด๋ ๊ฐ๋จํ๋ค. ๊ทธ๋ ๊ทธ๋ด๊ฒ์ด RNN ๊ฐ๋ ์์ฒด๋ 1986๋ ์ ์ฒ์ ๋์๋ค.
๊ฒฐ๊ตญ RNN์ ์ข ๋ ๊น๋ณด๋ ค๊ณ ํด๋ ์์ ๊ทธ๋ฆผ์ด ์ ๋ถ๋ค. ๊ทธ์น๋ง input sequence์ hidden state, time step์ ์๋ฒฝํ ์ดํดํด์ผ LSTM๊ณผ GRU๋ฅผ ์ดํดํ ์ ์๋ ๊ฒ ๊ฐ๋ค.
'์๋ฐํ ๋งํ๋ฉด output๊ณผ hidden state๋ ๋ค๋ฅผ ์๋ ์๋ค๋ ๋ง์ ์?'
regression task๋ฉด h_t๋ฅผ ๊ทธ๋๋ก output์ผ๋ก๋ ์ฐ๋ฉด ๋๊ณ , binary classification task๋ฉด sigmoid๋ ReLU, multi-class classification task๋ฉด softmax๋ฅผ ๊ฑฐ์น๊ฒ ํด output์ผ๋ก ๋ผ ๊ฒ์ด๋ ๊ทธ๋ ๋ค.
๊ฐ์ธ์ ์ผ๋ก ์ฒ์ ์ธ๊ณต์ง๋ฅ ๊ณต๋ถํ ๋ ๊ฐ์ ๊ฐ๋ ์ ๋ค๋ฅธ ๊ทธ๋ฆผ์ด๋ ๋ง๋ก ์ค๋ช ํด๋ ๊ฒ๋ค, ๊ฐ์ ๊ฐ๋ ์ ๋ค๋ฅธ activation function์ ๊ฐ๋ค ๋ถ์ธ ๊ฒ๋ค์ด ํท๊ฐ๋ ธ์ด์ ๋ง๋ถ์ฌ ์ ์ด๋ณธ๋คโ!
LSTM(Long Short Term Memory) & GRU(Gated Recurrent Unit)
์ผ๋จ ์ด๊ฒ๋ค์ ์๊ธฐ ์ ์ RNN์ ๋ฌธ์ ์ ์ ํ์ ํด์ผ ๋ฑ์ฅ ๋ฐฐ๊ฒฝ์ ์ดํดํ๊ธฐ ์์ํ๋ค. RNN์ ๋ฌธ์ ๋ ๋ญ๊น? sequence data๊ฐ ๊ต์ฅํ ๊ธธ๋ค๊ณ ํด๋ณด์. ๊ทธ๋ผ data ํ๋ ์์ time step๋ ๊ต์ฅํ ๋ง์์ง๋ค. ๊ทธ๋ง์ ๊ณง activation function์ผ๋ก ๋ง๋ค์ด์ง๋ hidden state์ weight๊ฐ gradient vanishing problem์ ๊ฒช์ ํ๋ฅ ์ด ๋์์ง๋ค๋ ๋ป์ด ๋๋ค. (gradient vanishing problem์ด ๊ถ๊ธํ๋ค๋ฉด sigmoid->ReLU ๋ณ์ฒ์ฌ์ ๋ํด ์ฐพ์๋ณด๊ธธ ๋ฐ๋๋ค.)
์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ฑ์ฅํ ๊ฒ์ด LSTM, ์ด๋ฅผ time-consuming ์ธก๋ฉด์์ ๋์ฑ ๋ณด์ํ ๊ฒ์ด GRU๋ผ๊ณ ๋ณผ ์ ์๊ฒ ๋ค.
LSTM๋ถํฐ๋ gate๋ผ๋ ๊ฐ๋ ์ด ๋ฑ์ฅํ๋๋ฐ, ์ฐ์ ์ด gate์ ๋ํด ์ดํดํด๋ณด์.
gate๋ ํ๋ง๋๋ก 'element-wise coefficient multiplicator'์ธ๋ฐ, ๋ฅ๋ฌ๋์์์ coefficient, ์ฆ weight๋ฅผ ๋ค์ element wiseํ๊ฒ ์กฐ์ ํด์ฃผ๋ function์ด๋ผ๊ณ ๋ณด๋ฉด ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก sigmoid function์ ์ฌ์ฉํด 0๊ณผ 1์ฌ์ด ๊ฐ์ผ๋ก weight๋ฅผ ์กฐ์ ํ๋ค.
์์ ๊ทธ๋ฆผ์์ sigmoid๋ฅผ ํต๊ณผํด ๋ง๋ค์ด์ง f, i, o๊ฐ ๊ฐ๊ฐ forget gate, input gate, output gate์ด๋ค. ๊ฐ๊ฐ์ gate๋ input๊ณผ hidden state๋ฅผ ๊ฐ์ง๊ณ '๊ธฐ์กด์ ์ ๋ณด๋ฅผ ์ผ๋ง๋ ์์์ง', '์๋ก ๋ค์ด์จ input์ ์ผ๋ง๋ ์์ฉํ ์ง', 'output์ ์ด๋ฒ time step์์์ ์ ๋ณด๋ค์ ์ผ๋ง๋ ์ ๋ฌํ ์ง'๋ฅผ ํ์ตํ๊ฒ ๋๋ค. ์ด ์ธ ๊ฐ์ง๋ฅผ gate๋ค์ด, ๊ทธ๋ฆฌ๊ณ ๋ชจ๋ธ์ด '์์์ ํ์ตํ๋ค'๋๊ฒ ์ ๋ง ๋๋ผ์ด ํฌ์ธํธ๋ค.
๊ทธ๋ฆฌ๊ณ LSTM์์ ์ถ๊ฐ๋ ๋๋ค๋ฅธ ๊ฐ๋ ์ด ๋ฐ๋ก cell state๋ค. ์ ๊ทธ๋ฆผ์์ ์ด์ ๋จ๊ณ์ hidden state์ ํจ๊ป C(cell state)๊ฐ ๋ค์ด์ค๋ ๊ฒ์ ๋ณผ ์ ์๋ค. ์ฌ์ค LSTM์ด ๋ง๋ค์ด์ง ๊ณผ์ ์ ๋ค์ ๋ ์ฌ๋ ค๋ณด๋ฉด 'gradient vanishing problem' ๋ณด์ํ๊ธฐ๋ค. ์ง๊ธ๊น์ง gate ์ด์ผ๊ธฐ์๋ ์ ๋ถ activation function์ ์ฌ์ฉํ๋ค. ๊ทธ๋ผ ๋์ฒด ๋ญ๊ฐ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋๋ฐ? ๊ทธ๊ฒ ๋ฐ๋ก cell state๋ค.
๊ฒฐ๊ตญ ๋ชจ๋ธ์ด ๊ฐ์ง๊ณ ๊ฐ๋ main stream์ activation function์ ๊ฑฐ์น์ง ์์ cell state๋ค. ๊ทธ์น๋ง ์ ์ ํ ํ์ต์ ์ํด non-linearity๋ฅผ ์์ ์ ๊ฑฐํ ์๋ ์์ผ๋ ์ฌ๊ธฐ์ hidden state๋ฅผ ๋นผ๋ด๊ณ gate 3๊ฐ๋ก ์ง์ง๊ณ ๋ณถ๊ณ ํ ๋ค์ cell state์๋ element-wise multiplication์ด๋ element-wise summation๋ง ํด์ฃผ๋ ๊ฒ์ด๋ค. ์ฒ์ฌ๋ค.. ์ ๊ทธ๋ฆผ์์ pink box๋ก ํ์๋ input gate์์ it๋ฅผ ๋นผ๊ณ ๋ณด๋ฉด ์ ๊ฒ ์๋ basic RNN์ด๋ค. (๊ทธ๋์ g_t๋ฅผ temporal C_t๋ผ๊ณ ํํํ๊ธฐ๋ ํ๋ค.) ์ง์ง ๋ง๋ ์ฌ๋ ์ฒ์ฌ ์๋๋! LSTM ํต์ฌ ์ฐ์ฐ์ ์๋์ ๊ฐ๋ค.
GRU๋ ๋ญ๊น? ์ด๋ฆ๋ถํฐ Gated Recurrent Unit. (์๊ธฐ๋ค์ด gate๋ก๋ ๊ฒฐํ์ ์ง์๋ค๊ณ ์๊ฐํ๋ ๋ชจ์์ด๋ค.) LSTM์ ์ดํดํ๋ค๋ฉด GRU๋ ์ดํดํ๊ธฐ ํธํ๋ค. GRU๋ LSTM์ gate๋ฅผ 3๊ฐ์์ 2๊ฐ๋ก ์ค์๊ณ , ๋ค์ cell state๋ฅผ ์์ด๋ค. ๊ทธ๋ฌ๋ parameter๊ฐ ์ค๊ณ ๊ณ์ฐ๋์ด ์ค์ด๋ค ์๋ฐ์. ์ฑ๋ฅ๋ LSTM๋ณด๋ค ์ข ๋ ์ข๋ค.
sigmoid๋ฅผ ๊ฑฐ์น๋ gate ๋ ๊ฐ๋ ๊ฐ๊ฐ z(update gate)์ r(reset gate)๋ก, ๊ฐ๊ฐ ๊ธฐ์กด์ input gate์ forget gate์ ๋์๋๋ ์ผ์ ํ๋ค. cell state๋ฅผ ๋ฐ๋ก ๋์ง ์์์ผ๋ฉฐ, ๊ทธ๋ฌ๋ cell๊ณผ hidden์ ๋ถ๋ฆฌํ๊ธฐ ์ํด ํ์ํ๋ output gate๋ ๋ถํ์ํด์ก๋ค. ๋ฌผ๋ก ๋จ์ํ ์ด๋ ๊ฒ ์ค์ผ ์ ์๋ ๊ฒ์ ์๋๋ค. update gate z์ usage๋ฅผ ์ดํด๋ณด๋ฉด ๋ง์น binary cross entropy ๊ณ์ฐ์ฒ๋ผ (1-z)์ z๊ฐ ์ฐ์ธ๋ค. ์ด์ time step์์์ hidden state์ ๋ฐฉ๊ธ ๋ง๋ temporal hidden state์ ๋น์ค์ ์ด๋ป๊ฒ ์กฐ์ ํด์ ํ์ฌ์ hidden state๋ฅผ ๋ง๋ค์ง ์กฐ์ ํ๋ ๊ฒ์ด๋ค. ์๊ณ ๋ณด๋ฉด ์์ด๋์ด๋ ๊ฐ๋จํ๋ฐ ๋ง๋ ์ฌ๋์ ์ฒ์ฌ!!! ์ ๋ง ์ฒ์ฌ๋ค.
์ฌ๋ด) Computer vision์ด ์ข์์ RNN์ ์ซ์ด(?)ํ๋๋ฐ, ์ ๋๋ก ๋ฏ์ด๋ณด๋ ์ฒ์ CNN ๋ฐฐ์ ์ ๋๋ณด๋ค 10๋ฐฐ๋ ๋ ์ ๊ธฐํ๊ณ ์ฌ๋ฐ๋ ๊ฐ๋ ์ด๋ค. ๊ทธ๋ฆฌ๊ณ Video๋ sequence data์์! ๊ทธ๋ฆฌ๊ณ ViT๋ transformer์์ ๋์จ๊ฑฐ์์! ์ด๋ ๊ฒ ์๊ฐํ๋ฉด ๊ฐ์๊ธฐ RNN? ๋ ์ข ์ข์์ง๋ค. ํธํธ.
[์ฐธ๊ณ ์๋ฃ]
- (Blog) RNN (https://medium.com/@archit.saxena/recurrent-neural-network-rnn-92faf7c01fd4)
- (Blog) LSTM (http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
- (Blog) GRU (https://towardsdatascience.com/understanding-gru-networks-2ef37df6c9be)
- (Blog) [๋ฅ๋ฌ๋] ์ธ์ด๋ชจ๋ธ, RNN, GRU, LSTM, Attention, Transformer, GPT, BERT ๊ฐ๋ ์ ๋ฆฌ (velog.io)
'Study > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Deep learning] What is 'Style transfer'? (0) | 2024.07.07 |
---|---|
[Deep learning] Accelerating the Super-Resolution Convolutional Neural Network ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2024.07.07 |
[Deep Learning] Attention, Seq2Seq, Transformer (0) | 2024.02.19 |
[AI] What is ResNet? (1) | 2024.01.05 |
[AI] 2023-2ํ๊ธฐ ๊ณต๋ถํ ๋ด์ฉ ์์ฝ (0) | 2023.08.14 |