-- ๋ณธ ํฌ์คํ
์ ํ์ดํ ์น๋ก ๋ฐฐ์ฐ๋ ์์ฐ์ด ์ฒ๋ฆฌ (ํ๋น๋ฏธ๋์ด) ์ ๋ฅ๋ฌ๋์ ์ด์ฉํ ์์ฐ์ด ์ฒ๋ฆฌ ์
๋ฌธ(์ํค๋
์ค) ์ ์๋ฅผ ์ฐธ๊ณ ํด์ ์์ฑ๋ ๊ธ์
๋๋ค.
1. Sequence-to-Sequence (seq2seq)
- ๋ฒ์ญ๊ธฐ์์ ๋ํ์ ์ผ๋ก ์ฌ์ฉ๋๋ ๋ชจ๋ธ
- ์ธ์ฝ๋์ ๋์ฝ๋๋ก ๊ตฌ์ฑ๋ ์ธ์ฝ๋ - ๋์ฝ๋ ๋ชจ๋ธ์ ์ผ์ข
- ์กฐ๊ฑด๋ถ ์์ฑ ๋ชจ๋ธ (conditioned generation model)์ ์ผ์ข
์ด๊ธฐ๋ ํจ
- ์กฐ๊ฑด๋ถ ์์ฑ ๋ชจ๋ธ์ด๋? ์ ๋ ฅํํ ๋์ ์ผ๋ฐ์ ์ธ ์กฐ๊ฑด ๋ฌธ๋งฅ์ ํ์ฉํ์ฌ ๋์ฝ๋๊ฐ ์ถ๋ ฅ์ ๋ง๋๋ ๋ชจ๋ธ
1.1 seq2seq ๋ชจ๋ธ์ ๊ตฌ์ฑ
- ์ธ์ฝ๋๋ ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ชจ๋ ๋จ์ด๋ฅผ ์์ฐจ์ ์ผ๋ก ์ ๋ ฅ๋ฐ์ ๋ค์ ๋ง์ง๋ง์ ๋ชจ๋ ๋จ์ด์ ์ ๋ณด๋ฅผ ์์ถํด์ ํ๋์ ๋ฒกํฐ๋ก ๋ง๋ค์ด์ค (์ฆ, context vector๋ฅผ ์์ฑ)
- context ๋ฒกํฐ๋ก ๋ชจ๋ ๋ฌธ์ฅ์ ์ ๋ณด๊ฐ ์์ถ๋๋ฉด ๋์ฝ๋๋ก ์ ์ก
- ๋์ฝ๋๋ context ๋ฒกํฐ๋ฅผ ๋ฐ์์ ๋ฒ์ญ๋ ๋จ์ด๋ฅผ ํ๋์ฉ ์์ฐจ์ ์ผ๋ก ์ถ๋ ฅ
- ์ธ์ฝ๋์ ๋์ฝ๋๋ฅผ ์์ธํ ๋ณด๋ฉด ์ด๋ ๊ฒ RNN ์ํคํ
์ณ๋ก ์ด๋ฃจ์ด์ ธ ์๋ค.
- ์ฑ๋ฅ์ ๋ฌธ์ ๋ก RNN๋ณด๋ค๋ ์ฃผ๋ก ๋ฐ์ ๋ ํํ์ธ LSTM๊ณผ GRU์ ์ ์ฌ์ฉํ๋ค.
- ๊ธฐ๊ณ๋ ๋จ์ด๋ณด๋ค ์ซ์๋ฅผ ๋ ์ ์ธ์ํ๊ธฐ ๋๋ฌธ์ ๊ฐ ์ ์์ ์๋ฒ ๋ฉ ๊ณผ์ ์ ๊ฑฐ์น๋ค.
- ํ๋์ RNN(LSTM)์ ์์๋ t-1 ์์์ ์๋์ํ์, t์์์ ์ ๋ ฅ๋ฒกํฐ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ t์์์ ์๋์ํ ๋ฒกํฐ๋ฅผ ์์ฑํด์ค๋ค. (์ด์ ํฌ์คํ ์ฐธ๊ณ )
โถ ์ธ์ฝ๋
- ์ ๋ ฅ ๋ฌธ์ฅ์ ๋จ์ด ํ ํฐํ๋ฅผ ํตํด ๋จ์ด ๋จ์๋ก ์ชผ๊ฐ์ง๊ณ , ๋ชจ๋ ๋จ์ด๋ฅผ ์๋ฒ ๋ฉํ๋ค.
- ๊ทธ๋ฆฌ๊ณ ๋จ์ด ํ ํฐ์ ๊ฐ๊ฐ RNN์ ์ ๊ฐ ์์ ์ ์ ๋ ฅ์ด ๋๋ค.
- RNN ๊ฐ ์ ์ ๋ง์ง๋ง ์์ ์ ์๋์ํ๋ฅผ context ๋ฒกํฐ๋ก ๋ง๋ ํ, ๋์ฝ๋๋ก ๋๊ฒจ์ค๋ค.
โถ ๋์ฝ๋
- ์ด๊ธฐ ์ ๋ ฅ์ผ๋ก ๋ฌธ์ฅ์ ์์์ ์๋ฏธํ๋ <sos>๊ฐ ์ ๋ ฅ๋จ
- ๋์ฝ๋๋ <sos>๊ฐ ์ ๋ ฅ๋๋ฉด ๋ค์์ ๋ฑ์ฅํ ํ๋ฅ ์ด ๋์ ๋จ์ด๋ฅผ ์์ธก
- ์ฌ๊ธฐ์๋ ์ฒซ๋ฒ์งธ ์์ ์ ๋์ฌ ๋จ์ด๋ฅผ Je๋ก ์์ธก
- ์ด๋ ๊ฒ ๋์ฝ๋๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ค์์ ์ฌ ๋จ์ด๋ฅผ ์์ธกํ๊ณ , ๊ทธ ์์ธกํ ๋จ์ด๋ฅผ ๋ค์ ์์ ์ RNN์ ์ ์ ๋ ฅ์ผ๋ก ๋ฃ๋ ํ์๋ฅผ ๋ฐ๋ณต
- ๋ฌธ์ฅ์ด ๋๋ฌ๋ค๋ ์ฌ๋ณผ์ธ <eos>๊ฐ ์์ธก๋ ๋๊น์ง ๋ฐ๋ณต๋๋ค.
2. seq2seq๋ก ๊ธฐ๊ณ ๋ฒ์ญ๊ธฐ ๊ตฌํํ๊ธฐ (์ฐธ๊ณ )
๋ณธ ์ฝ๋๋ https://wikidocs.net/24996 ํด๋น ๋งํฌ๋ฅผ ์ฐธ๊ณ ํ์ฌ ์์ฑํ๋ค. ์ฐ์ ์ ๊ธ์ ์ฐจ์์์์ ๋ฒ์ญ๊ธฐ๋ฅผ ๊ตฌํํ ๊ฒ์ด๋ค. (ํ ํฐ์ ๋จ์๊ฐ ๋จ์ด๊ฐ ์๋๋ผ ๊ธ์(์ํ๋ฒณ) ์ด๋ผ๋ ์๋ฏธ์ด๋ค!)
2.1 ๋ฐ์ดํฐ์
โถ ๋ฐ์ดํฐ ์ถ์ฒ
๊ธฐ๊ณ ๋ฒ์ญ์ ํ๋ จ์ํค๊ธฐ ์ํด์ ํ๋ จ ๋ฐ์ดํฐ๋ก ๋ณ๋ ฌ ์ฝํผ์ค๊ฐ ํ์ํ๋ค. http://www.manythings.org/anki ํด๋น ๋งํฌ์์ ๋ค์ด๋ฐ์ ํ๋์ค-์์ด ๋ณ๋ ฌ ์ฝํผ์ค์ธ fran-eng.zip ์ด๋ผ๋ ํ์ผ์ ์ฌ์ฉํ ๊ฒ์ด๋ค. ํด๋น ํ์ผ์ ์์ถ์ ํ๊ณ , fra.txt ๋ผ๋ ํ์ผ์ ์ฌ์ฉํด์ ์ค์ตํด๋ณด์.
โถ ๋ณ๋ ฌ ์ฝํผ์ค
๋ณ๋ ฌ ์ฝํผ์ค๋ 'ํ๊น '์์ ๊ณผ๋ ์ด์ง ๋ค๋ฅด๋ค. ํ๊น ์์ ์ ๋ชจ๋ ๋ฐ์ดํฐ์ ์์ ๊ธธ์ด๊ฐ ๊ฐ๋ค๋ ํน์ง์ด ์์ง๋ง ๋ณ๋ ฌ ๋ฐ์ดํฐ๋ ๊ทธ๋ ์ง ์๋ค. ์๋ฅผ๋ค์ด '๋๋ ํ์์ด๋ค' ๋๊ฐ์ ํ ํฐ์ผ๋ก ์ด๋ฃจ์ด์ง ๋ฌธ์ฅ์ 'I am a student' 4๊ฐ์ ํ ํฐ์ผ๋ก ๊ตฌ์ฑ๋ ๋ฌธ์ฅ์ผ๋ก ๋ฒ์ญ๋๋ค. ์ด์ฒ๋ผ seq2seq๋ ์ ๋ ฅ ์ํ์ค์ ์ถ๋ ฅ ์ํ์ค์ ๊ธธ์ด๊ฐ ๋ค๋ฅผ ์ ์๋ค๊ณ ๊ฐ์ ํ๊ณ ์ค์ต์ ์งํํ๋ค.
โถ fra.txt์ ๊ตฌ์ฑ
Watch me. Regardez-moi !
์ด๋ ๊ฒ ์ผ์ชฝ์ ์์ด ๋ฌธ์ฅ๊ณผ ์ค๋ฅธ์ชฝ์ ํ๋์ค์ด ๋ฌธ์ฅ ์ฌ์ด์ ํญ์ผ๋ก ๊ตฌ๋ณ๋๋ ๊ตฌ์กฐ๊ฐ ํ๋์ ์ํ์ด๋ค.
์ด์ ๊ฐ์ ๋ฐ์ดํฐ๊ฐ 16๋ง๊ฐ์ ๋ณ๋ ฌ ๋ฌธ์ฅ ์ํ์ ํฌํจํ๊ณ ์๋ค.
2.2 import
import pandas as pd
import urllib3
import zipfile
import shutil
import os
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
2.3 ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
from google.colab import drive
drive.mount('/content/gdrive/')
PATH = '/content/gdrive/MyDrive/Colab Notebooks/NLPstudy/'
- google colab์ ์ฌ์ฉํ๊ธฐ์ google drive๋ฅผ ๋ง์ดํธํด์ฃผ๊ณ , ๊ฒฝ๋ก๋ฅผ PATH์ ์ ์ฅํด์ค๋ค.
lines = pd.read_csv(PATH+'fra.txt', names = ['src', 'tar', 'lic'], sep='\t')
# src๋ source์ ์ค์๋ง๋ก ์
๋ ฅ ๋ฌธ์ฅ์ ์๋ฏธ / tar์ target์ผ๋ก ๋ฒ์ญํ๊ณ ์ํ๋ ๋ฌธ์ฅ์ ์๋ฏธ
del lines['lic']
len(lines) #๋ฐ์ดํฐ์ ๊ฐฏ์ 19๋ง๊ฐ์ ๋
- ์ ๋ ฅ๋ฌธ์ฅ์ src, ๋ฒ์ญํ๊ณ ์ ํ๋ ๋ฌธ์ฅ์ tar ์ผ๋ก ์ง์ ํ๋ค.
- ๋ฐ์ดํฐ๋ ์ด 19๋ง๊ฐ ์ ๋ ๋๋ค.
2.4 ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
โถ ๋ฐ์ดํฐ ์ ํ
lines = lines.loc[:, 'src':'tar']
lines = lines[0:60000] # 6๋ง๊ฐ๋ง ์ ์ฅ
lines.sample(10) #๋๋ค์ผ๋ก ๋ฝ์ 10๊ฐ์ ์ํ
- ์ค์ต์ 6๋ง๊ฐ์ ๋ฐ์ดํฐ๋ง ์ฌ์ฉํ ๊ฒ์ด๋ค.
โถ ๋ฐ์ดํฐ ํํ ๋ง์ ธ์ฃผ๊ธฐ
lines.tar = lines.tar.apply(lambda x : '\t '+ x + ' \n')
lines.sample(10)
- ์์ seq2seq๋ชจ๋ธ ์ค๋ช ์ ๋ณด๋ฉด ๋ฒ์ญํ๊ณ ์ถ์ ๋ฌธ์ฅ์๋ ์์ ์ฌ๋ณผ์ธ <sos>์ ๋ฌธ์ฅ์ด ๋๋๋ ์ฌ๋ณผ์ธ <eos>๊ฐ ์กด์ฌํ๋ค.
- ํด๋น ๋ฌธ์ฅ์๋ ์์๊ณผ ๋์ ์ฌ๋ณผ์ด ์ ํด์ ธ ์์ง ์์ผ๋ ๋ฐ๋ก ์ง์ ํด์ ๋ฃ์ด์ฃผ์.
- <sos> : \t / <eos> : \n ์ผ๋ก ๋ฃ์ด์ฃผ์๊ณ , tar ๋ฌธ์ฅ์ ์ฌ๋ณผ์ด ์ ์์ ์ผ๋ก ์ ๋ ฅ๋จ์ ํ์ธํ์.
โถ ๊ธ์ ์งํฉ ๊ตฌ์ถ
- ์์ด๋ 79๊ธ์, ํ๋์ค์ด๋ 105๊ธ์๋ก ๊ตฌ์ฑ๋์ด์๋ค. ๊ธ์๋ฅผ ์ผ๋ถ๋ง ์ถ๋ ฅํด๋ณด๋ฉด, ์๋์ ๊ฐ๋ค.
# ๊ธ์ ์งํฉ ์์ฑ (ํ ํฐ๋จ์๊ฐ ์๋ '๊ธ์'๋จ์๋ก ์งํฉ์ ๊ตฌ์ถ)
src_vocab=set()
for line in lines.src: # 1์ค์ฉ ์ฝ์
for char in line: # 1๊ฐ์ ๊ธ์์ฉ ์ฝ์
src_vocab.add(char)
tar_vocab=set()
for line in lines.tar:
for char in line:
tar_vocab.add(char)
src_vocab_size = len(src_vocab)+1
tar_vocab_size = len(tar_vocab)+1
print(src_vocab_size)
print(tar_vocab_size)
src_vocab = sorted(list(src_vocab))
tar_vocab = sorted(list(tar_vocab))
print(src_vocab[45:75]) #์ผ๋ถ๋ง ์ถ๋ ฅํด๋ณด์
print(tar_vocab[45:75])
- ๊ธ์ฅ ์ธ๋ฑ์ค๋ฅผ ๋ถ์ฌํด์ dictionary๋ก ๋ง๋ค ๊ฒ์ด๋ค.
# ๊ธ์์ ์ธ๋ฑ์ค๋ฅผ ๋ถ์ฌํ์ฌ dict์ผ๋ก ํํ
src_to_index = dict([(word, i+1) for i, word in enumerate(src_vocab)])
tar_to_index = dict([(word, i+1) for i, word in enumerate(tar_vocab)])
print(src_to_index)
print(tar_to_index)
โถ ์ ์ ์ธ์ฝ๋ฉ ์งํ
# ์ธ๋ฑ์ค๊ฐ ๋ถ์ฌ๋ ๊ธ์ ์งํฉ์ผ๋ก ๋ถํฐ ๊ฐ๊ณ ์๋ ํ๋ จ ๋ฐ์ดํฐ์ ์ ์ ์ธ์ฝ๋ฉ์ ์ํ
# ์
๋ ฅ์ด ๋ ์์ด ๋ฌธ์ฅ ์ํ์ ๋ํด ์ธ์ฝ๋ฉ์ ์ํ
encoder_input = []
for line in lines.src: #์
๋ ฅ ๋ฐ์ดํฐ์์ 1์ค์ฉ ๋ฌธ์ฅ์ ์ฝ์
temp_X = []
for w in line: #๊ฐ ์ค์์ 1๊ฐ์ฉ ๊ธ์๋ฅผ ์ฝ์
temp_X.append(src_to_index[w]) # ๊ธ์๋ฅผ ํด๋น๋๋ ์ ์๋ก ๋ณํ
encoder_input.append(temp_X)
#์์ 5๊ฐ๋ง ์ถ๋ ฅํด๋ณด์
print(encoder_input[:5])
# ๋์ฝ๋์ ์
๋ ฅ์ด ๋ ํ๋์ค์ด ๋ฐ์ดํฐ์ ๋ํด ์ ์ ์ธ์ฝ๋ฉ ์ํ
decoder_input = []
for line in lines.tar:
temp_X = []
for w in line:
temp_X.append(tar_to_index[w])
decoder_input.append(temp_X)
#์์ 5๊ฐ๋ง ์ถ๋ ฅํด๋ณด์
print(decoder_input[:5])
# ๋์ฝ๋์ ์์ธก๊ฐ๊ณผ ๋น๊ตํ๊ธฐ ์ํ ์ค์ ๊ฐ์ด ํ์ํจ
# ์ค์ ๊ฐ์๋ ์์ ์ฌ๋ณผ <sos>๊ฐ ์์ ํ์๊ฐ ์์
# ์์์ฌ๋ณผ์ธ \t ๋ฅผ ์ ๊ฑฐํด์ฃผ์
decoder_target = []
for line in lines.tar:
t=0 # t๊ฐ 0์ธ ์ฒ์์ ์ ์ธํ๊ณ temp_X์ appendํด์ฃผ๋๊ณผ์ ์ ๊ฑฐ์น๋ค.
temp_X = []
for w in line:
if t>0:
temp_X.append(tar_to_index[w])
t=t+1
decoder_target.append(temp_X)
print(decoder_target[:5])
- decoder input์์๋ ๋ชจ๋ ์๋ฒ ๋ฉ๋ ๊ฒฐ๊ณผ๊ฐ 1๋ก ์์ํ๋ค. (์์์ฌ๋ณผ๋๋ฌธ)
- decoder target ์์๋ 1์ด ์ ์ธ๋ ๊ฒ์ ๋ณด๋ฉด, ์ ์์ ์ผ๋ก ์์ ์ฌ๋ณผ์ด ์ ๊ฑฐ๋จ์ ์ ์ ์๋ค.
โถ ํจ๋ฉ
# ์์ด์ ํ๋์ค์ด์ ๊ฐ์ฅ ๊ธด ๋จ์ด ํ์
max_src_len = max([len(line) for line in lines.src])
max_tar_len = max([len(line) for line in lines.tar])
print(max_src_len)
print(max_tar_len)
- ์์ด 24 / ํ๋์ค์ด 76 ๋ก ํจ๋ฉ ์งํ
- ์ด๋ฒ ๋ณ๋ ฌ ๋ฐ์ดํฐ๋ ํ๋์ ์์ด๋๋ผ๋ ์ ๋ถ ๊ธธ์ด๊ฐ ๋ค๋ฅผ ์ ์์ผ๋ฏ๋ก (์ฒ์์ ์ธ๊ธ)
- ํจ๋ฉ์ ํ ๋๋ ๋ ๊ฐ์ ๋ฐ์ดํฐ ๊ธธ์ด๋ฅผ ์ ๋ถ ๋์ผํ๊ฒ ํ ํ์ ์์
# ํจ๋ฉ ์งํ
encoder_input = pad_sequences(encoder_input, maxlen=max_src_len, padding='post')
decoder_input = pad_sequences(decoder_input, maxlen=max_tar_len, padding='post')
decoder_target = pad_sequences(decoder_target, maxlen=max_tar_len, padding='post')
โถ ์ํซ ์ธ์ฝ๋ฉ
# ๊ธ์ ๋จ์ ๋ฒ์ญ๊ธฐ ์ด๋ฏ๋ก ์๋ ์๋ฒ ๋ฉ์ ๋ณ๋๋ก ์ฌ์ฉํ์ง ์์ ๊ฒ
# ์์ธก๊ณผ ์ค์ฐจ ์ธก์ ์ ์ฌ์ฉ๋๋ ์ค์ ๊ฐ ๋ฟ๋ง ์๋๋ผ ์
๋ ฅ๊ฐ๋ ์ํซ๋ฒกํฐ๋ฅผ ์ฌ์ฉํจ
encoder_input = to_categorical(encoder_input)
decoder_input = to_categorical(decoder_input)
decoder_target = to_categorical(decoder_target)
2.5 ๊ต์ฌ ๊ฐ์ (Teacher Forcing)
- ์ด์ seq2seq ๋ชจ๋ธ ์ค๋ช ์ ๋ณด๋ฉด, ํ์ฌ ์์ ์ ๋์ฝ๋ ์ ์ ์ ๋ ฅ์ ์ด์ ๋์ฝ๋์ ์ถ๋ ฅ์ ์ ๋ ฅ์ผ๋ก ๋ฐ๋๋ค๊ณ ๋ฐฐ์ ๋ค. ๊ทผ๋ฐ ์ decoder_input์ด ํ์ํ๊ฐ?
- ์ด์ ์์ ์ ์ค์ ๊ฐ์ ํ์ฌ ์์ ์ ๋์ฝ๋ ์ ์ ์ ๋ ฅ๊ฐ์ผ๋ก ํ๋ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ๊ฒ
- ์ด์ ์์ ์ ๋์ฝ๋ ์ ์ ์์ธก์ด ํ๋ ธ๋๋ฐ ์ด๋ฅผ ํ์ฌ ์์ ์ ๋์ฝ๋ ์ ์ ์ ๋ ฅ๊ฐ์ผ๋ก ์ฌ์ฉํ๊ฒ๋๋ฉด ํ์ฌ ์์ ์ ๋์ฝ๋ ์ ์ ์์ธก๊น์ง ์๋ชป๋ ๊ฐ๋ฅ์ฑ์ด ๋๊ณ , ์ด๋ ์ฐ์์ ์ผ๋ก ์์ฉํ์ฌ ๋์ฝ๋ ์ ์ฒด ์์ธก์ ์ด๋ ต๊ฒ ํ๊ธฐ ๋๋ฌธ
- ์ด์ ๊ฐ์ด RNN์ ๋ชจ๋ ์์ ์ ๋ํด์ ์ด์ ์์ ์ ์์ธก๊ฐ ๋์ ์ค์ ๊ฐ์ ์ ๋ ฅ์ผ๋ก ์ฃผ๋ ๋ฐฉ๋ฒ์ ๊ต์ฌ๊ฐ์ ๋ผ๊ณ ํจ
2.6 seq2seq ๊ธฐ๊ณ ๋ฒ์ญ ํ๋ จ์ํค๊ธฐ
from tensorflow.keras.layers import Input, LSTM, Embedding, Dense
from tensorflow.keras.models import Model
import numpy as np
encoder_inputs = Input(shape=(None, src_vocab_size))
encoder_lstm = LSTM(units=256, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)
# encoder_outputs๋ ๊ฐ์ด ๋ฆฌํด๋ฐ๊ธฐ๋ ํ์ง๋ง ์ฌ๊ธฐ์๋ ํ์์์ผ๋ฏ๋ก ์ด ๊ฐ์ ๋ฒ๋ฆผ.
encoder_states = [state_h, state_c]
# LSTM์ ๋ฐ๋๋ผ RNN๊ณผ๋ ๋ฌ๋ฆฌ ์ํ๊ฐ ๋ ๊ฐ. ๋ฐ๋ก ์๋ ์ํ์ ์
์ํ.
- LSTM์ ์๋์ํ ํฌ๊ธฐ๋ 256์ผ๋ก ์ ํ
- ์ธ์ฝ๋ ๋ด๋ถ์ํ๋ฅผ ๋์ฝ๋๋ก ์ ๋ฌํด์ผํ๊ธฐ์ return_state = True๋ก ์ค์
- LSTM์์ state_h, state_c๋ฅผ ๋ฆฌํด๋ฐ๋๋ฐ state_h๋ ์๋์ํ๊ณ state_c๋ ์ ์ํ์ ํด๋น
- ์ฆ ์๋์ํ์ ์ ์ํ๋ฅผ ์ ๋ฌํด์ค๋ค.
- ์ด ๋๊ฐ๋ฅผ ecoder states์ ์ ์ฅํ๊ณ , ์ด๋ฅผ ๋์ฝ๋์ ์ ๋ฌํ๋ฏ๋ก์ ๋๊ฐ์ง ์ํ๋ฅผ ๋ชจ๋ ๋์ฝ๋๋ก ์ ๋ฌํ ๊ฒ.
- ์์ ๋ฐฐ์ด ๋ฌธ๋งฅ๋ฒกํฐ(context vector)๊ฐ encoder_state ์ ํด๋นํ๋ ๊ฒ!
decoder_inputs = Input(shape=(None, tar_vocab_size))
decoder_lstm = LSTM(units=256, return_sequences=True, return_state=True)
decoder_outputs, _, _= decoder_lstm(decoder_inputs, initial_state=encoder_states)
# ๋์ฝ๋์ ์ฒซ ์ํ๋ฅผ ์ธ์ฝ๋์ ์๋ ์ํ, ์
์ํ๋ก ํฉ๋๋ค.
decoder_softmax_layer = Dense(tar_vocab_size, activation='softmax')
decoder_outputs = decoder_softmax_layer(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer="rmsprop", loss="categorical_crossentropy")
- ๋์ฝ๋๋ ์ธ์ฝ๋์ ๋ง์ง๋ง ์ํ๋ฅผ ์ด๊ธฐ ์๋์ํ๋ก ์ฌ์ฉ ์ฆ, initial_state์ ์ธ์๊ฐ์ผ๋ก encoder_state๋ฅผ ๋ฐ๋๊ฒ์ด ์ด์ ํด๋น
- ๋์ฝ๋์ ์๋์ํ๋ 256์ผ๋ก ์ฃผ์ด์ง
- ๋์ฝ๋๋ ์๋์ํ, ์ ์ํ๋ฅผ ๋ฆฌํดํ๊ธฐ๋ ํ์ง๋ง ํ๋ จ ๊ณผ์ ์์๋ ์ฌ์ฉํ์ง ์์
- ๊ทธ ํ ์ถ๋ ฅ์ธต์ ํ๋์ค์ด์ ๋จ์ด ์งํฉ์ ํฌ๊ธฐ๋งํผ ๋ด๋ฐ์ ๋ฐฐ์นํ ํ, ์ํํธ๋งฅ์ค ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ค์ ๊ฐ๊ณผ์ ์ค์ฐจ๋ฅผ ๊ตฌํจ
model.fit(x=[encoder_input, decoder_input], y=decoder_target, batch_size=64, epochs=3, validation_split=0.2)
- ์๊ฐ๊ด๊ณ์ epochs๋ฅผ 3์ผ๋ก๋ง ํ์ตํ๋ค. (๋ณธ ์ฝ๋๋ 50์ผ๋ก ์งํํ๋ค.)
- ์ ๋ ฅ์ผ๋ก๋ encoder_input, ๋์ฝ๋์ ์ค์ ๊ฐ์ธ decoder_input์ ๋ฃ๋๋ค.
2.7 seq2seq ๊ธฐ๊ณ ๋ฒ์ญ๊ธฐ ๋์์ํค๊ธฐ
ํ๋ จ๊ณผ์ ๊ณผ ๋์๊ณผ์ ์ ๋ค๋ฅด๋ค. ๋์ ๊ณผ์ ์์๋ encoder model๊ณผ decoder model์ ๋ฐ๋ก ๋ง๋ค์ด์ ์
๋ ฅํ ๋ฌธ์ฅ์ ๋ํด์ ๊ธฐ๊ณ๋ฒ์ญ์ ํ๋๋ก ๋ชจ๋ธ์ ์กฐ์ ํ ํ, ๋์์์ผ๋ณผ ๊ฒ์ด๋ค. (ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ง ์๋ ๊ฒ์ธ๊ฐ? ํ .. )
์ ๋ฐ์ ์ธ ๋ฒ์ญ ๋์ ๋จ๊ณ๋ฅผ ์ ๋ฆฌํ๋ฉด ์๋์ ๊ฐ๋ค.
- ๋ฒ์ญํ๊ณ ์ ํ๋ ์ ๋ ฅ ๋ฌธ์ฅ์ด ์ธ์ฝ๋์ ๋ค์ด๊ฐ์ ์๋ ์ํ์ ์ ์ํ๋ฅผ ์ป๋๋ค.
- ์ํ์ <sos>์ ํด๋นํ๋ \t ๋ฅผ ๋์ฝ๋๋ก ๋ณด๋ธ๋ค.
- ๋์ฝ๋๊ฐ <eos>์ ํด๋นํ๋ \n์ด ๋์ฌ ๋๊น์ง ๋ค์ ๋ฌธ์๋ฅผ ์์ธกํ๋ ํ๋์ ๋ฐ๋ณตํ๋ค.
โถ ์ธ์ฝ๋ ๋ชจ๋ธ ์ ์
# ์์์ ์ ์ํ encoder_input = Input(shape=(None, src_vocab_size))
# outputs = encoder_states : encoder_lstm์ผ๋ก ๋ถํฐ ๋ฐ์ ์๋์ํ์ ์
์ํ๊ฐ [state_h, state_c]
encoder_model = Model(inputs=encoder_inputs, outputs=encoder_states)
- ์ฐ์ ์ธ์ฝ๋๋ฅผ encoder_model ๋ก ์ ์ํ์.
โถ ๋์ฝ๋ ๋ชจ๋ธ ์ ์
# ์ด์ ์์ ์ ์ํ๋ค์ ์ ์ฅํ๋ ํ
์
decoder_state_input_h = Input(shape=(256,))
decoder_state_input_c = Input(shape=(256,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
# ๋ฌธ์ฅ์ ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํด์ ์ด๊ธฐ ์ํ(initial_state)๋ฅผ ์ด์ ์์ ์ ์ํ๋ก ์ฌ์ฉ. ์ด๋ ๋ค์ ํจ์ decode_sequence()์ ๊ตฌํ
decoder_states = [state_h, state_c]
# ํ๋ จ ๊ณผ์ ์์์ ๋ฌ๋ฆฌ LSTM์ ๋ฆฌํดํ๋ ์๋ ์ํ์ ์
์ํ์ธ state_h์ state_c๋ฅผ ๋ฒ๋ฆฌ์ง ์์.
decoder_outputs = decoder_softmax_layer(decoder_outputs)
decoder_model = Model(inputs=[decoder_inputs] + decoder_states_inputs, outputs=[decoder_outputs] + decoder_states)
- ์ด์ ์์ ์ ์ํ๋ฅผ ์ ์ฅํ๋ ํ ์๋ฅผ ๋ง๋ค๊ณ , decoder_lstm ์ผ๋ก ๋ถํฐ ๋
index_to_src = dict((i, char) for char, i in src_to_index.items())
index_to_tar = dict((i, char) for char, i in tar_to_index.items())
- ์ธ๋ฑ์ค๋ก ๋ถํฐ ๋จ์ด๋ฅผ ์ป์ ์ ์๋ index_to_src / index_to_tar์ ๋ง๋ค์ด์ค๋ค.
def decode_sequence(input_seq):
# ์
๋ ฅ์ผ๋ก๋ถํฐ ์ธ์ฝ๋์ ์ํ๋ฅผ ์ป์
states_value = encoder_model.predict(input_seq)
# <SOS>์ ํด๋นํ๋ ์-ํซ ๋ฒกํฐ ์์ฑ
target_seq = np.zeros((1, 1, tar_vocab_size))
target_seq[0, 0, tar_to_index['\t']] = 1.
stop_condition = False
decoded_sentence = ""
# stop_condition์ด True๊ฐ ๋ ๋๊น์ง ๋ฃจํ ๋ฐ๋ณต
while not stop_condition:
# ์ด์ ์์ ์ ์ํ states_value๋ฅผ ํ ์์ ์ ์ด๊ธฐ ์ํ๋ก ์ฌ์ฉ
output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
# ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฌธ์๋ก ๋ณํ
sampled_token_index = np.argmax(output_tokens[0, -1, :])
sampled_char = index_to_tar[sampled_token_index]
# ํ์ฌ ์์ ์ ์์ธก ๋ฌธ์๋ฅผ ์์ธก ๋ฌธ์ฅ์ ์ถ๊ฐ
decoded_sentence += sampled_char
# <eos>์ ๋๋ฌํ๊ฑฐ๋ ์ต๋ ๊ธธ์ด๋ฅผ ๋์ผ๋ฉด ์ค๋จ.
if (sampled_char == '\n' or
len(decoded_sentence) > max_tar_len):
stop_condition = True
# ํ์ฌ ์์ ์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ค์ ์์ ์ ์
๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ ์ฅ
target_seq = np.zeros((1, 1, tar_vocab_size))
target_seq[0, 0, sampled_token_index] = 1.
# ํ์ฌ ์์ ์ ์ํ๋ฅผ ๋ค์ ์์ ์ ์ํ๋ก ์ฌ์ฉํ๊ธฐ ์ํด ์ ์ฅ
states_value = [h, c]
return decoded_sentence
for seq_index in [3,50,100,300,1001]: # ์
๋ ฅ ๋ฌธ์ฅ์ ์ธ๋ฑ์ค
input_seq = encoder_input[seq_index: seq_index + 1]
decoded_sentence = decode_sequence(input_seq)
print(35 * "-")
print('์
๋ ฅ ๋ฌธ์ฅ:', lines.src[seq_index])
print('์ ๋ต ๋ฌธ์ฅ:', lines.tar[seq_index][1:len(lines.tar[seq_index])-1]) # '\t'์ '\n'์ ๋นผ๊ณ ์ถ๋ ฅ
print('๋ฒ์ญ๊ธฐ๊ฐ ๋ฒ์ญํ ๋ฌธ์ฅ:', decoded_sentence[:len(decoded_sentence)-1]) # '\n'์ ๋นผ๊ณ ์ถ๋ ฅ
์ด๋ ๊ฒ ๊ธ์ ์์ค์์์ ๊ธฐ๊ณ๋ฒ์ญ์ ๊ตฌํํด ๋ณด์๋ค. ๋จ์ด ์์ค์์์ ๋ฒ์ญ๊ธฐ ๊ตฌํ ๋ฐฉ๋ฒ์ด ๊ถ๊ธํ๋ค๋ฉด ์ฌ๊ธฐ๋ฅผ ํด๋ฆญํ๊ณ ๋ณ๋๋ก ๊ณต๋ถ๋ฅผ ์งํํ์.