๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
Study/Deep Learning

[Deep Learning] RNN, LSTM, GRU ์ด์ •๋ฆฌโ˜… (+ํŒ์„œ)

by ์œ ๋ฏธ๋ฏธYoomimi 2024. 1. 12.

+ ์น ํŒ์— ์ •๋ฆฌํ•ด๋ณธ RNN, LSTM, GRU ๊ฐœ๋…

 

 

RNN(Recurrent Neural Network)

์šฐ์„  ๋” ์ต์ˆ™ํ•œ CNN์—์„œ ์ถœ๋ฐœํ•ด๋ณด์ž. CNN์€ input(๋“ค)์„ ์ด์šฉํ•ด output์„ ์˜ˆ์ธกํ•˜๋Š”๋ฐ, ๊ทธ ๊ณผ์ •์—์„œ data๊ฐ€ ์žฌ์‚ฌ์šฉ๋˜์ง€ ์•Š๋Š”๋‹ค. ๋‹น์—ฐํ•˜๋‹ค. CNN์€ input ํ•˜๋‚˜๋ฅผ ํ•œ๊บผ๋ฒˆ์— ๋„ฃ์–ด์ฃผ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ํ•˜๋‚˜๋ฅผ ํ•œ๊บผ๋ฒˆ์—? ์ด๋ฏธ์ง€ ํ•˜๋‚˜๋ฅผ ๋„ฃ์„ ๋•Œ ๊ฐ๊ฐ์˜ ํ”ฝ์…€์„ ์ˆœ์„œ๋Œ€๋กœ ๋„ฃ์ง€ ์•Š๊ณ  ํ•œ๋ฒˆ์— Convolution layer๋ฅผ ๋งŒ๋‚˜๊ฒŒ ํ•ด๋ฒ„๋ฆฌ๋Š” ์ผ์„ ์ƒ์ƒํ•ด๋ณด๋ฉด ๋œ๋‹ค. ๋ฌผ๋ก  Convolution layer์˜ kernel size๋•Œ๋ฌธ์— ๋จผ์ € ์ฝํžˆ๋Š” ๋ถ€๋ถ„์ด ์กด์žฌํ•˜์ง€ ์•Š๋Š๋ƒ ์‹ถ์„ ์ˆ˜ ์žˆ์ง€๋งŒ, ๊ทธ ์ˆœ์„œ๊ฐ€ ์ค‘์š”ํ•œ๊ฐ€? ์ ˆ๋Œ€ ๊ทธ๋ ‡์ง€ ์•Š๋‹ค. ์ด๋ฏธ์ง€์—์„œ locality๊ฐ€ ์ค‘์š”ํ•œ ๊ฒƒ์€ sequence๊ฐ€ ์ค‘์š”ํ•œ ๊ฒƒ๊ณผ๋Š” ๋‹ค๋ฅธ ์˜๋ฏธ๋‹ค.

 

RNN์€ sequence data(์‹œ๊ณ„์—ด data, ์ž์—ฐ์–ด ๋“ฑ)๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐ์— ํŠนํ™”๋˜์–ด ์žˆ๋Š” neural network์ธ๋ฐ, ์ด๋ฆ„ ๊ทธ๋Œ€๋กœ '์ˆœํ™˜'ํ•˜๋Š” '์ˆœํ™˜์‹ ๊ฒฝ๋ง'์ด๋‹ค.

 

Basic of basic rnn


์œ„ ์‚ฌ์ง„์—์„œ input sequence data ํ•˜๋‚˜๋Š” [x0, x1, x2, ..., xt]์ด๋‹ค. RNN์€ ๊ธฐ๋ณธ์ ์œผ๋กœ input data๋ฅผ time step์œผ๋กœ ๋‚˜๋ˆ„์–ด ์ง‘์–ด๋„ฃ๊ฒŒ ๋œ๋‹ค. ๊ทธ๋ฆฌ๊ณ  h_t๋Š” ์‹œ๊ฐ„ t์—์„œ์˜ hidden state๋‹ค. RNN์€ ์ด hidden state๋ฅผ ์ด์šฉํ•ด ์ด์ „ time step์˜ ์ •๋ณด๋ฅผ ๋‹ค์Œ time step์—์„œ๋„ ๊ฐ€์ง€๊ณ  ์žˆ์„ ์ˆ˜ ์žˆ๊ฒŒ ํ•˜๋Š” ๊ฒƒ์ธ๋ฐ, ์œ„ ๊ทธ๋ฆผ์—์„œ '๋“ฑํ˜ธ(=)'๋ฅผ ์ดํ•ดํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค. RNN์„ ๊ฐ€์žฅ ๊ฐ„๋‹จํ•˜๊ฒŒ ํ‘œํ˜„ํ•˜๋ฉด ๋“ฑํ˜ธ ์ขŒ๋ณ€์˜ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์€๋ฐ, ์ˆœํ™˜ํ•˜๋Š” ํ™”์‚ดํ‘œ๋„ ์‚ฌ์‹ค h_t๊ฐ€ ์ˆœํ™˜ํ•˜๋Š” ๊ฒƒ์ด๋‹ค. h_t์˜ ์ˆœํ™˜์„ ์‹œ๊ฐ„์ถ•์œผ๋กœ ํŽผ์น˜๋ฉด ์šฐ๋ณ€์˜ ๊ทธ๋ฆผ์ฒ˜๋Ÿผ ํ‘œํ˜„๋˜๋Š” ์…ˆ์ด๋‹ค. ์‰ฝ๊ฒŒ ๋งํ•ด output์„ ๋‚ด๋ณด๋‚ด๋ฉด์„œ ๋™์‹œ์— ๋‹ค์Œ time์—๋„ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ด๋‹ค. (์—„๋ฐ€ํžˆ output๊ณผ hidden state๋Š” ๋‹ค๋ฅผ ์ˆ˜๋„ ์žˆ๋Š”๋ฐ, ์ด๋ฅผ ์•„๋ž˜์—์„œ ์ž์„ธํžˆ ์‚ดํŽด๋ณด์ž.)

 

๋”์šฑ ์‰ฝ๊ฒŒ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” gif๋ฅผ ์ฐพ์•„๋ณด์•˜๋‹ค. ์‚ฌ์‹ค hidden state๋Š” ์šฐ๋ฆฌ๊ฐ€ ํ”ํžˆ ์•„๋Š” MLP์˜ hidden layer์— ๋Œ€์‘๋œ๋‹ค.

 

 

f๋Š” activation function์ด๊ณ (์ผ๋ฐ˜์ ์œผ๋กœ tanh๋ฅผ ์“ด๋‹ค.), y_hat์ด ์ตœ์ข… output prediction์ด๋‹ค. ์ˆ˜์‹์œผ๋กœ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋“ฏ h_t๊ฐ€ output์ธ ๊ฒƒ์€ ์•„๋‹ˆ๊ณ , h_t๋ฅผ ๋‹ค์Œ์ƒํƒœ์— ์ „๋‹ฌํ•  ๋•Œ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉํ•˜๊ณ , output์œผ๋กœ ๋‚ด๋ณด๋‚ผ ๋•Œ๋Š” h_t๋ฅผ ๊ทธ๋Œ€๋กœ ์“ฐ๊ฑฐ๋‚˜ ์—ฌ๊ธฐ์— softmax๋‚˜ sigmoid๊ฐ™์€ activation function์„ ๊ฑฐ์น˜๊ฒŒ ํ•œ๋‹ค. ์›๋ฆฌ๋ฅผ ๋ณด๋‹ˆ ๊ฐ„๋‹จํ•˜๋‹ค. ๊ทธ๋„ ๊ทธ๋Ÿด๊ฒƒ์ด RNN ๊ฐœ๋… ์ž์ฒด๋Š” 1986๋…„์— ์ฒ˜์Œ ๋‚˜์™”๋‹ค.

 

 

 

 

๊ฒฐ๊ตญ RNN์€ ์ข€ ๋” ๊นŒ๋ณด๋ ค๊ณ  ํ•ด๋„ ์œ„์˜ ๊ทธ๋ฆผ์ด ์ „๋ถ€๋‹ค. ๊ทธ์น˜๋งŒ input sequence์™€ hidden state, time step์„ ์™„๋ฒฝํžˆ ์ดํ•ดํ•ด์•ผ LSTM๊ณผ GRU๋ฅผ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒƒ ๊ฐ™๋‹ค.

 

'์—„๋ฐ€ํžˆ ๋งํ•˜๋ฉด output๊ณผ hidden state๋Š” ๋‹ค๋ฅผ ์ˆ˜๋„ ์žˆ๋‹ค๋Š” ๋ง์€ ์™œ?'

regression task๋ฉด h_t๋ฅผ ๊ทธ๋Œ€๋กœ output์œผ๋กœ๋„ ์“ฐ๋ฉด ๋˜๊ณ , binary classification task๋ฉด sigmoid๋‚˜ ReLU, multi-class classification task๋ฉด softmax๋ฅผ ๊ฑฐ์น˜๊ฒŒ ํ•ด output์œผ๋กœ ๋‚ผ ๊ฒƒ์ด๋‹ˆ ๊ทธ๋ ‡๋‹ค.

๊ฐœ์ธ์ ์œผ๋กœ ์ฒ˜์Œ ์ธ๊ณต์ง€๋Šฅ ๊ณต๋ถ€ํ•  ๋•Œ ๊ฐ™์€ ๊ฐœ๋…์„ ๋‹ค๋ฅธ ๊ทธ๋ฆผ์ด๋‚˜ ๋ง๋กœ ์„ค๋ช…ํ•ด๋‘” ๊ฒƒ๋“ค, ๊ฐ™์€ ๊ฐœ๋…์— ๋‹ค๋ฅธ activation function์„ ๊ฐ–๋‹ค ๋ถ™์ธ ๊ฒƒ๋“ค์ด ํ—ท๊ฐˆ๋ ธ์–ด์„œ ๋ง๋ถ™์—ฌ ์ ์–ด๋ณธ๋‹คโ˜†!

 

 

LSTM(Long Short Term Memory) & GRU(Gated Recurrent Unit)

์ผ๋‹จ ์ด๊ฒƒ๋“ค์„ ์•Œ๊ธฐ ์ „์— RNN์˜ ๋ฌธ์ œ์ ์„ ํŒŒ์•…ํ•ด์•ผ ๋“ฑ์žฅ ๋ฐฐ๊ฒฝ์„ ์ดํ•ดํ•˜๊ธฐ ์ˆ˜์›”ํ•˜๋‹ค. RNN์˜ ๋ฌธ์ œ๋Š” ๋ญ˜๊นŒ? sequence data๊ฐ€ ๊ต‰์žฅํžˆ ๊ธธ๋‹ค๊ณ  ํ•ด๋ณด์ž. ๊ทธ๋Ÿผ data ํ•˜๋‚˜ ์•ˆ์— time step๋„ ๊ต‰์žฅํžˆ ๋งŽ์•„์ง„๋‹ค. ๊ทธ๋ง์€ ๊ณง activation function์œผ๋กœ ๋งŒ๋“ค์–ด์ง€๋Š” hidden state์˜ weight๊ฐ€ gradient vanishing problem์„ ๊ฒช์„ ํ™•๋ฅ ์ด ๋†’์•„์ง„๋‹ค๋Š” ๋œป์ด ๋œ๋‹ค. (gradient vanishing problem์ด ๊ถ๊ธˆํ•˜๋‹ค๋ฉด sigmoid->ReLU ๋ณ€์ฒœ์‚ฌ์— ๋Œ€ํ•ด ์ฐพ์•„๋ณด๊ธธ ๋ฐ”๋ž€๋‹ค.)

์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋“ฑ์žฅํ•œ ๊ฒƒ์ด LSTM, ์ด๋ฅผ time-consuming ์ธก๋ฉด์—์„œ ๋”์šฑ ๋ณด์™„ํ•œ ๊ฒƒ์ด GRU๋ผ๊ณ  ๋ณผ ์ˆ˜ ์žˆ๊ฒ ๋‹ค.

 

LSTM๋ถ€ํ„ฐ๋Š” gate๋ผ๋Š” ๊ฐœ๋…์ด ๋“ฑ์žฅํ•˜๋Š”๋ฐ, ์šฐ์„  ์ด gate์— ๋Œ€ํ•ด ์ดํ•ดํ•ด๋ณด์ž.

gate๋Š” ํ•œ๋งˆ๋””๋กœ 'element-wise coefficient multiplicator'์ธ๋ฐ, ๋”ฅ๋Ÿฌ๋‹์—์„œ์˜ coefficient, ์ฆ‰ weight๋ฅผ ๋‹ค์‹œ element wiseํ•˜๊ฒŒ ์กฐ์ ˆํ•ด์ฃผ๋Š” function์ด๋ผ๊ณ  ๋ณด๋ฉด ๋œ๋‹ค. ์ผ๋ฐ˜์ ์œผ๋กœ sigmoid function์„ ์‚ฌ์šฉํ•ด 0๊ณผ 1์‚ฌ์ด ๊ฐ’์œผ๋กœ weight๋ฅผ ์กฐ์ ˆํ•œ๋‹ค.

 

 

์œ„์˜ ๊ทธ๋ฆผ์—์„œ sigmoid๋ฅผ ํ†ต๊ณผํ•ด ๋งŒ๋“ค์–ด์ง„ f, i, o๊ฐ€ ๊ฐ๊ฐ forget gate, input gate, output gate์ด๋‹ค. ๊ฐ๊ฐ์˜ gate๋Š” input๊ณผ hidden state๋ฅผ ๊ฐ€์ง€๊ณ  '๊ธฐ์กด์˜ ์ •๋ณด๋ฅผ ์–ผ๋งˆ๋‚˜ ์žŠ์„์ง€', '์ƒˆ๋กœ ๋“ค์–ด์˜จ input์„ ์–ผ๋งˆ๋‚˜ ์ˆ˜์šฉํ• ์ง€', 'output์— ์ด๋ฒˆ time step์—์„œ์˜ ์ •๋ณด๋“ค์„ ์–ผ๋งˆ๋‚˜ ์ „๋‹ฌํ• ์ง€'๋ฅผ ํ•™์Šตํ•˜๊ฒŒ ๋œ๋‹ค. ์ด ์„ธ ๊ฐ€์ง€๋ฅผ gate๋“ค์ด, ๊ทธ๋ฆฌ๊ณ  ๋ชจ๋ธ์ด '์•Œ์•„์„œ ํ•™์Šตํ•œ๋‹ค'๋Š”๊ฒŒ ์ •๋ง ๋†€๋ผ์šด ํฌ์ธํŠธ๋‹ค.

 

๊ทธ๋ฆฌ๊ณ  LSTM์—์„œ ์ถ”๊ฐ€๋œ ๋˜๋‹ค๋ฅธ ๊ฐœ๋…์ด ๋ฐ”๋กœ cell state๋‹ค. ์œ„ ๊ทธ๋ฆผ์—์„œ ์ด์ „ ๋‹จ๊ณ„์˜ hidden state์™€ ํ•จ๊ป˜ C(cell state)๊ฐ€ ๋“ค์–ด์˜ค๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ์‚ฌ์‹ค LSTM์ด ๋งŒ๋“ค์–ด์ง„ ๊ณผ์ •์„ ๋‹ค์‹œ ๋– ์˜ฌ๋ ค๋ณด๋ฉด 'gradient vanishing problem' ๋ณด์™„ํ•˜๊ธฐ๋‹ค. ์ง€๊ธˆ๊นŒ์ง€ gate ์ด์•ผ๊ธฐ์—๋Š” ์ „๋ถ€ activation function์„ ์‚ฌ์šฉํ–ˆ๋‹ค. ๊ทธ๋Ÿผ ๋Œ€์ฒด ๋ญ๊ฐ€ ์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋Š”๋ฐ? ๊ทธ๊ฒŒ ๋ฐ”๋กœ cell state๋‹ค.

๊ฒฐ๊ตญ ๋ชจ๋ธ์ด ๊ฐ€์ง€๊ณ  ๊ฐ€๋Š” main stream์€ activation function์„ ๊ฑฐ์น˜์ง€ ์•Š์€ cell state๋‹ค. ๊ทธ์น˜๋งŒ ์ ์ ˆํ•œ ํ•™์Šต์„ ์œ„ํ•ด non-linearity๋ฅผ ์•„์˜ˆ ์ œ๊ฑฐํ•  ์ˆ˜๋Š” ์—†์œผ๋‹ˆ ์—ฌ๊ธฐ์„œ hidden state๋ฅผ ๋นผ๋‚ด๊ณ  gate 3๊ฐœ๋กœ ์ง€์ง€๊ณ  ๋ณถ๊ณ  ํ•œ ๋‹ค์Œ cell state์—๋Š” element-wise multiplication์ด๋‚˜ element-wise summation๋งŒ ํ•ด์ฃผ๋Š” ๊ฒƒ์ด๋‹ค. ์ฒœ์žฌ๋“ค.. ์œ„ ๊ทธ๋ฆผ์—์„œ pink box๋กœ ํ‘œ์‹œ๋œ input gate์—์„œ it๋ฅผ ๋นผ๊ณ  ๋ณด๋ฉด ์ €๊ฒŒ ์›๋ž˜ basic RNN์ด๋‹ค. (๊ทธ๋ž˜์„œ g_t๋ฅผ temporal C_t๋ผ๊ณ  ํ‘œํ˜„ํ•˜๊ธฐ๋„ ํ•œ๋‹ค.) ์ง„์งœ ๋งŒ๋“  ์‚ฌ๋žŒ ์ฒœ์žฌ ์•„๋‹ˆ๋ƒ! LSTM ํ•ต์‹ฌ ์—ฐ์‚ฐ์€ ์•„๋ž˜์™€ ๊ฐ™๋‹ค.

element-wise multiplication๊ณผ  element-wise summation
hidden state๋Š” Ct์™€ output gate๋ฅผ ๊ฐ€์ง€๊ณ  ๋งŒ๋“ ๋‹ค.

 

 

 

GRU๋Š” ๋ญ˜๊นŒ? ์ด๋ฆ„๋ถ€ํ„ฐ Gated Recurrent Unit. (์ž๊ธฐ๋“ค์ด gate๋กœ๋Š” ๊ฒฐํŒ์„ ์ง€์—ˆ๋‹ค๊ณ  ์ƒ๊ฐํ–ˆ๋˜ ๋ชจ์–‘์ด๋‹ค.) LSTM์„ ์ดํ•ดํ–ˆ๋‹ค๋ฉด GRU๋„ ์ดํ•ดํ•˜๊ธฐ ํŽธํ•˜๋‹ค. GRU๋Š” LSTM์˜ gate๋ฅผ 3๊ฐœ์—์„œ 2๊ฐœ๋กœ ์ค„์˜€๊ณ , ๋‹ค์‹œ cell state๋ฅผ ์—†์•ด๋‹ค. ๊ทธ๋Ÿฌ๋‹ˆ parameter๊ฐ€ ์ค„๊ณ  ๊ณ„์‚ฐ๋Ÿ‰์ด ์ค„์–ด๋“ค ์ˆ˜๋ฐ–์—. ์„ฑ๋Šฅ๋„ LSTM๋ณด๋‹ค ์ข€ ๋” ์ข‹๋‹ค.

 

sigmoid๋ฅผ ๊ฑฐ์น˜๋Š” gate ๋‘ ๊ฐœ๋Š” ๊ฐ๊ฐ z(update gate)์™€ r(reset gate)๋กœ, ๊ฐ๊ฐ ๊ธฐ์กด์˜ input gate์™€ forget gate์— ๋Œ€์‘๋˜๋Š” ์ผ์„ ํ•œ๋‹ค. cell state๋ฅผ ๋”ฐ๋กœ ๋‘์ง€ ์•Š์•˜์œผ๋ฉฐ, ๊ทธ๋Ÿฌ๋‹ˆ cell๊ณผ hidden์„ ๋ถ„๋ฆฌํ•˜๊ธฐ ์œ„ํ•ด ํ•„์š”ํ–ˆ๋˜ output gate๋„ ๋ถˆํ•„์š”ํ•ด์กŒ๋‹ค. ๋ฌผ๋ก  ๋‹จ์ˆœํžˆ ์ด๋ ‡๊ฒŒ ์ค„์ผ ์ˆ˜ ์žˆ๋˜ ๊ฒƒ์€ ์•„๋‹ˆ๋‹ค. update gate z์˜ usage๋ฅผ ์‚ดํŽด๋ณด๋ฉด ๋งˆ์น˜ binary cross entropy ๊ณ„์‚ฐ์ฒ˜๋Ÿผ (1-z)์™€ z๊ฐ€ ์“ฐ์ธ๋‹ค. ์ด์ „ time step์—์„œ์˜ hidden state์™€ ๋ฐฉ๊ธˆ ๋งŒ๋“  temporal hidden state์˜ ๋น„์ค‘์„ ์–ด๋–ป๊ฒŒ ์กฐ์ ˆํ•ด์„œ ํ˜„์žฌ์˜ hidden state๋ฅผ ๋งŒ๋“ค์ง€ ์กฐ์ ˆํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์•Œ๊ณ  ๋ณด๋ฉด ์•„์ด๋””์–ด๋Š” ๊ฐ„๋‹จํ•œ๋ฐ ๋งŒ๋“  ์‚ฌ๋žŒ์€ ์ฒœ์žฌ!!! ์ •๋ง ์ฒœ์žฌ๋‹ค.

 

 

์—ฌ๋‹ด) Computer vision์ด ์ข‹์•„์„œ RNN์€ ์‹ซ์–ด(?)ํ–ˆ๋Š”๋ฐ, ์ œ๋Œ€๋กœ ๋œฏ์–ด๋ณด๋‹ˆ ์ฒ˜์Œ CNN ๋ฐฐ์› ์„ ๋•Œ๋ณด๋‹ค 10๋ฐฐ๋Š” ๋” ์‹ ๊ธฐํ•˜๊ณ  ์žฌ๋ฐŒ๋Š” ๊ฐœ๋…์ด๋‹ค. ๊ทธ๋ฆฌ๊ณ  Video๋Š” sequence data์ž–์•„! ๊ทธ๋ฆฌ๊ณ  ViT๋„ transformer์—์„œ ๋‚˜์˜จ๊ฑฐ์ž–์•„! ์ด๋ ‡๊ฒŒ ์ƒ๊ฐํ•˜๋ฉด ๊ฐ‘์ž๊ธฐ RNN? ๋„ˆ ์ข€ ์ข‹์•„์ง„๋‹ค. ํ˜ธํ˜ธ.

 


[์ฐธ๊ณ ์ž๋ฃŒ]