PyTorchによるSeq2seqの実装

 

PyTorchを使ってSeq2seqの実装をしてみたので、その実装方法について書いてみます。学習に用いるデータセットには実装したSeq2seqが正しく学習できているか簡単に評価できる、トイ・プロブレム(足し算)を使っています。

Google Colaboratoryで動作確認をしたので、誰でも簡単にプログラムを実行できるようにしています。

 

 

環境

今回書いたプログラムはGoogle Colaboratoryで実装から実行まですることができるので環境構築に手間取ることはないと思います。さらにGoogle ColaboratoryのGPUを利用するので短時間で学習からテストをし、実装したSeq2seqの性能を見ることができます。

Google Colaboratoryの導入の仕方についてはこちら

Seq2seq

プログラムを実装する前にSeq2seqについて少し書いてみます。

Seq2seqとは Sequence to Sequence モデル、モデルの入力と出力が時系列データになっており、時系列データを別の時系列データに変換してくれます。

今回は最大で3桁の数字どうしを足し算する式を入力とし、出力にはその計算結果を与えます。例えば、入力は100+10のような式の場合、出力にはその計算結果である110を与えるようにします。Seq2seqモデルは100+10が与えられたとき計算結果の110が出力されるように学習していきます。

Seq2seqはEncoder Decoder モデルとも呼ばれ、イメージとしては以下のようになります。

f:id:pytry3g:20181124192523p:plain

Encoderに入力として変換したいデータを渡しEncoderでそれを処理したのちDecoderに処理結果を渡して入力したデータの変換結果をDecoderが出力します。

EncoderとDecoderにはRNNが使われており与えられた時系列データをそれぞれ処理します。

f:id:pytry3g:20181125190343p:plain

上の例ではRNNとしてGRUを使っています。GRUではなく、シンプルなRNNやLSTMを用いることも可能です。

Encoderには入力として文字単位か単語単位でデータを与えますが、今回は与える入力データを文字単位に分割してEncoderに与えます。入力データが100+10だとすると、”1”、”0”、”0”、”+”、”1”、”0”に分割します。

ただ、ここで注意しなければならないことがあります。

今回は最大で3桁の数字の足し算を入力とします。つまり、入力データには73+2や112+999が与えられます。ここで問題となるのが入力データの長さの異なる可変長のデータという点です。73+2や112+999を文字単位に分割すると長さは4と7になります。

今回実装するプログラムではミニバッチ学習をするので入力データの形状を揃える必要があります。そのための方法としてパディングを使います。パディングをすることにより、可変長のデータの長さを全て同じ長さに揃えミニバッチ学習が可能になります。

入力データはパディングを使って長さを7に統一します。長さが7未満の場合、余ったところにはパディングを表す記号 <pad> を入れることにします。

73+2にパディングを適用すると、

”7”、”3”、”+”、”2”、”<pad>”、”<pad>”、”<pad>”

 

112+999だと、

”1”、”1”、”2”、”+”、”9”、”9”、”9”

のようになります。

Encoderに入力データを与えると、h(隠れ状態)が出力されます。この隠れ状態には入力したデータから正しい出力結果を得るための必要な情報に変換したものが入っており、これを初期状態としてDecoderに渡しています。

DecoderはEncoderから渡された情報と教師データから損失を計算していき学習を進めていきます。

学習データの用意

学習データを用意します。今回は2万のデータを作成し訓練用とテスト用に分けて使用します。学習データを用意するコードは下に置いてあります。ここでは、コードの説明をざっくりしていきます。

まず、文字がkey、文字IDがvalueの辞書を用意します。

word2id = {str(i): i for i in range(10)}
word2id.update({"<pad>": 10, "+": 11, "<eos>": 12})

この辞書には0~9の数字とパディングのための<pad>、演算子+、区切り文字である<eos>が登録されています。(※区切り文字は後ほど出てきます。)

Encoderに入力データを与えるときは100+10のような式を文字単位で分解し、辞書をもとにしてIDに変換します。

 

生成したデータセットのひとつを見てみると、

print(train_x[0])
# [3, 11, 7, 10, 10, 10, 10]

このデータは入力データで3+7を変換したものになります。

 

print(train_t[0])
# [12, 1, 0, 12, 10, 10]

このデータは出力データで計算結果の10を変換したものになります。12は区切り文字の<eos>です。

EncoderとDecoderの実装

PyTorchを使ってEncoderとDecoderの実装をしていきます。コード全体は下に置いてあります。

Encoderのイメージとしてはこんな感じです。

3+11を入力としたときのEncoderです。文字IDに変換したものをEmbedding Layerに渡していきます。

f:id:pytry3g:20181126171103p:plain

このイメージをコードで実装すると以下になります。

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, indices):
        embedding = self.word_embeddings(indices)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        _, state = self.gru(embedding, torch.zeros(1, self.batch_size, self.hidden_dim, device=device))
        
        return state

Encoderでは受け取った文字IDをEmbeding Layerに与え単語ベクトルに変換したものをGRUに渡しています。GRUの出力stateh(隠れ状態)です。Encoderは与えられた文字IDを順番に処理していき、最後の文字を処理した結果、stateを出力します。このstateには入力データを変換するための必要な情報が入っており、これをDecoderへと渡します。

 

DecoderはEncoderから出力されたstateと教師データをもとに目的とする出力を生成します。

f:id:pytry3g:20181126180743p:plain

出力が10のときのDecoderの図です。12は<eos>、区切り文字です。

この図をコードで実装すると以下のようになります。

class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, index, state):
        embedding = self.word_embeddings(index)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        gruout, state = self.gru(embedding, state)
        output = self.output(gruout)
        return output, state

区切り文字を入れるには2つの理由があります。

  1. Decoderに文字列の生成開始を知らせる。
  2. 文字列生成の終了を知らせる。

出力を10としたとき、文字IDに変換すると[12, 1, 0, 12, 10, 10] になります。今回は3桁どうしの足し算をしているので計算結果は最大で4桁になり、ここに区切り文字を入れることにより出力の長さは6になります。

出力が10のときのDecoderには教師データとして[12, 1, 0, 12, 10, 10] を使いますが、入力には[12, 1, 0, 12, 10] を与え、それに対応する出力として[1, 0, 12, 10, 10] を与えます。

学習

コードは下に置いてあります。だいだい100epochでそれなりの正答率がとれます。

学習コードはここのPyTorchのTutorialを参考にして実装しました。

Encoderには入力データを文字列単位で与えていますが、Decoderには1文字ずつ与えています。

テスト

コードは下に置いてあります。100epochで正答率はだいだい8割くらいです。10epochごとに保存したモデルにテストデータを渡してテストをしています。

学習するときのパラメータをいじったり、epoch数を増やせばもっといい結果が得られるかもしれません。

ソースコード

データの用意

import random
from sklearn.model_selection import train_test_split

word2id = {str(i): i for i in range(10)}
word2id.update({"<pad>": 10, "+": 11, "<eos>": 12})
id2word = {v: k for k, v in word2id.items()}

def load_dataset(N=20000):
    def generate_number():
        number = [random.choice(list("0123456789")) for _ in range(random.randint(1, 3))] 
        # a <= N <= b random.randint(a, b)
        return int("".join(number))
    
    def padding(string, training=True):
        string = "{:*<7s}".format(string) if training else "{:*<6s}".format(string)
        return string.replace("*", "<pad>")
    
    def transform(string, seq_len=7):
        tmp = []
        for i, c in enumerate(string):
            try:
                tmp.append(word2id[c])
            except:
                tmp += [word2id["<pad>"]] * (seq_len - i)
                break
        return tmp
        
    data = []
    target = []    
    for _ in range(N):
        x = generate_number()
        y = generate_number()
        z = x + y
        left = padding(str(x) + "+" + str(y))
        right = padding(str(z), training=False)
        data.append(transform(left))
        right = transform(right, seq_len=6)
        right = [12] + right[:5]
        right[right.index(10)] = 12
        target.append(right)
        
    return data, target

data, target = load_dataset()
train_x, test_x, train_t, test_t = train_test_split(data, target, test_size=0.1)

EncoderとDecoder

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


embedding_dim = 16
hidden_dim = 128
vocab_size = len(word2id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, indices):
        embedding = self.word_embeddings(indices)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        _, state = self.gru(embedding, torch.zeros(1, self.batch_size, self.hidden_dim, device=device))
        
        return state


class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, index, state):
        embedding = self.word_embeddings(index)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        gruout, state = self.gru(embedding, state)
        output = self.output(gruout)
        return output, state


encoder = Encoder(vocab_size, embedding_dim, hidden_dim).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word2id["<pad>"])

# Initialize opotimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)

学習する

from datetime import datetime
from sklearn.utils import shuffle

batch_size=100
def train2batch(data, target, batch_size=100):
    input_batch = []
    output_batch = []
    data, target = shuffle(data, target)
    
    for i in range(0, len(data), batch_size):
        input_tmp = []
        output_tmp = []
        for j in range(i, i+batch_size):
            input_tmp.append(data[j])
            output_tmp.append(target[j])
        input_batch.append(input_tmp)
        output_batch.append(output_tmp)
    return input_batch, output_batch

def get_current_time():
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


print("Training...")
n_epoch = 100
for epoch in range(1, n_epoch+1):

    
    input_batch, output_batch = train2batch(train_x, train_t)
    for i in range(len(input_batch)):
        # Zero gradients
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        # Prepare tensor
        inputs = torch.tensor(input_batch[i], device=device)
        outputs = torch.tensor(output_batch[i], device=device)
        # Forward pass through encoder
        encoder_hidden = encoder(inputs)
        # Create source and target
        source = outputs[:, :-1]
        target = outputs[:, 1:]
        decoder_hidden = encoder_hidden
        
        # Forward batch of sequences through decoder one time step at a time
        loss = 0
        for i in range(source.size(1)):
            decoder_output, decoder_hidden = decoder(source[:, i], decoder_hidden)
            decoder_output = torch.squeeze(decoder_output)
            loss += criterion(decoder_output, target[:, i])

        # Perform backpropagation
        loss.backward()
        
        # Adjust model weights
        encoder_optimizer.step()
        decoder_optimizer.step()
    
    if epoch % 10 == 0:
        print(get_current_time(), "Epoch %d: %.2f" % (epoch, loss.item()))        
        
    if epoch % 10 == 0:
        model_name = "seq2seq_calculator_v{}.pt".format(epoch)
        torch.save({
            'encoder_model': encoder.state_dict(),
            'decoder_model': decoder.state_dict(),
        }, model_name)
        print("Saving the checkpoint...")

テストする

import numpy as np


result = """---------------
Q:{:>10s}
A:{:>10s}
T/F: {}
---------------"""

encoder = Encoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)


for epoch in range(10, 101, 10):
    model_name = "seq2seq_calculator_v{}.pt".format(epoch)
    checkpoint = torch.load(model_name)
    encoder.load_state_dict(checkpoint["encoder_model"])
    decoder.load_state_dict(checkpoint["decoder_model"])
    
    print("Checkpoint {:>3d}".format(epoch))
    print("-"*30)
    accuracy = 0
    with torch.no_grad():
        for i in range(len(test_x)):
            x = test_x[i]
            input_tensor = torch.tensor([x], device=device)
            state = encoder(input_tensor)
            token = "<eos>"
            try:
                padded_idx_x = x.index(word2id["<pad>"])
            except ValueError:
                padded_idx_x = len(x)
            left = "".join(map(lambda c: str(id2word[c]), x[:padded_idx_x]))
            right = []
            for _ in range(7):
                index = word2id[token]
                input_tensor = torch.tensor([index], device=device)
                output, state = decoder(input_tensor, state)
                prob = F.softmax(torch.squeeze(output))
                index = torch.argmax(prob.cpu().detach()).item()
                token = id2word[index]
                if token == "<eos>":
                    break
                right.append(token)
            right = "".join(right)
            flag = ["F", "T"][eval(left) == int(right)]
            #print(result.format(left, right, flag))
            if flag == "T":
                accuracy += 1
    print("Accuracy: {:.2f}".format(accuracy / len(test_x)))
    print("-"*30)

Deep Learningによる自然言語処理に興味がある方は

Deep Learningによる自然言語処理を勉強したい方へのおすすめの本です。

tensorflow, chainerやPyTorchといったフレームワークを使わずにゼロからnumpyを使ってディープラーニングの実装をしています。

扱っている内容はword2vec, RNN, GRU, seq2seqやAttentionなど、、、

Google Colaboratoryでプログラムを動かす

最後にGoogle Colaboratoryを使った学習方法を紹介します。(※ランタイムのタイプはPython 3、GPUに設定しておきます。)

まずは、drive内のディレクトリをマウントします。

www.pytry3g.com

 

以下のコードをセルにコピペして実行します。

from google.colab import drive
drive.mount('/content/drive')

これによりdriveのMy Drive以下にGoogle Driveのルートディレクトリがマウントされます。

cd drive/My\ Drive/任意の作業ディレクト

次にGoogle Driveにある任意の作業ディレクトリに移動します。この作業ディレクトリには学習が進んでいくとモデルが保存されます。

!pip install torch > /dev/null

次にPyTorchをインストールします。

あとはソースコードを順に実行していくだけです。