どん底から這い上がるまでの記録

どん底から這い上がりたいけど這い上がれない人がいろいろ書くブログ(主にプログラミング)

Twitter APIを使ってSeq2Seq用の会話データを収集してみる

 

人間のように自然な会話ができる人工知能を作りたい。でも会話データがないから何にも始まらない。じゃ、あきらめよ。と、思っている方がいるかもしれません。

そんなあなたにTwitter APIを使ったSeq2Seq用の会話データを収集する方法について紹介したいと思います。

 

 

前準備

TwitterAPIを利用するにはAPI keyやAccess tokenが必要です。これらはDeveloper登録をすると取得できます。

Developer登録の方法は次の記事に書いてあります。

www.pytry3g.com

すでにTwitter APIを利用できる方は次にtweepyをインストールしてください。pip install tweepyでインストールできます。

今回集めるデータの形式

今回集める会話データはSeq2Seq用のものです。私は現在、自然な会話ができる人工知能を開発中で、下の画像のような形でSeq2Seqの会話データを使いたいと考えています。

f:id:pytry3g:20190306075016p:plain

開発中のシステムの流れはこんな感じです。

  1. 人工知能が適当に呟く。
  2. あるユーザーがその呟きにリプライする。
  3. 人工知能があるユーザーのリプライにリプライする。

集めたSeq2Seqの会話データのペア(message, response)を収集し、それぞれを2と3に当てはめてシステムを作る予定です。

ちなみに、今回は2つの発話をペアとして一つの会話にしていますが3つの発話を一つの会話にする方法もあります。

www.pytry3g.com

会話データを収集する

会話データを収集する流れはこんな感じです。

  1. Streaming APIを使って現在世界で投稿されているツイートを収集する
  2. 収集したツイートにフィルタリングをかける
  3. フィルタリングにかけたツイートの返信先のツイートを収集する
  4. 収集した会話データをデータベースに保存。

 

会話データを収集する方法は次の通りです。

収集したツイートには固有のIDが含まれています。もし、そのツイートがあるユーザーが投稿したツイートに対するリプライだった場合にはin_reply_to_status_idには返信先のツイートのIDが入っています。

f:id:pytry3g:20190307063941p:plain

したがって今回のケースの場合、会話データを集めるためにはまずStreaming APIを使って③にあたるツイートを収集し、そのツイートに含まれるin_reply_to_status_idを抜き出して、今度はRest APIを使って抜き出したIDから過去に投稿されたツイートを回収します。

今回のプログラムのメインであるreply2reply.pyを簡単に説明していきます。(※ソースコードは下に置いてあります。)

1. Streaming APIを使って世界で投稿されているツイートを収集する

tweepyとStreaming APIを使ってツイートを収集します。

まずはデータベースのパスと中身を用意する。(※パスは各自変えてください。)

if __name__ == "__main__":
    # データベースのpath
    db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"
    sql = """create table seq2seq(
    reply1 text NOT NULL,
    reply2 text NOT NULL
    );"""

 

argparseを使って新しくデータベースを作るか作らないのかを決める。デフォルトではすでにデータベースが存在し、引き続きそのデータベースを使うものとしている。

新しく作るには0以外の数値を与えてプログラムを実行すればOK。

    parser = argparse.ArgumentParser()
    parser.add_argument("--new", type=int, default=0,
                        help="0 indicates the database is already there.")
    args = parser.parse_args()
    if args.new:
        util.build_database(db_path, sql)

 

次にStreaming APIを使うための準備をする。

    # 認証
    auth = twitter_api.get_authentication()
    api = tweepy.API(auth)
    # Streaming API
    listener = StreamingListener(api, db_path)
    streaming = tweepy.Stream(auth, listener)

この辺のことは以前紹介しました。

用意ができたら、sample()を使ってツイートの収集を始めます。

    while True:
        try:
            streaming.sample()
        except KeyboardInterrupt:
            streaming.disconnect()
            break
        except Exception as e:
            streaming.disconnect()
            print(e)

 

収集したツイートはStreamingListerに定義したon_status()に渡されます。

2. 収集したツイートにフィルタリングをかける

リアルタイムで投稿されているツイートがリプライか確認し、フィルタリングをかけます。

    def on_status(self, status):
        # リプライか?
        if self.is_status_tweet(status):
            return
        # フィルタリング
        if self.is_invalid_tweet(status):
            return

is_status_tweet()でツイートがリプライか確認。ツイートがリプライだった場合、status.in_reply_to_status_idには数値(ツイートの返信先ツイートのID)が入っていて、リプライでなければNoneが入っています。

is_invalid_tweet()でツイートを以下の項目よりフィルタリングをかけます。

  1. ツイートが日本語以外の言語か?
  2. スクリーンネームbotが入っているか?
  3. URLを含んでいる?
  4. ハッシュタグを含んでいる?
  5. 複数の相手にリプライしている?
  6. ツイートの文字数が30以上?

この6つの項目のいずれかに当てはまる場合、そのツイートとはさよならです。

3. フィルタリングにかけたツイートの返信先のツイートを収集する

フィルタリングにかけて生き残ったツイートはlookup_ids行きとなります。

lookup_idsはリスト、reply_listは辞書になっていて、lookup_idsには生き残ったツイートの返信先のツイートのIDをぶち込み、reply_listには生き残ったツイートとそのツイートの返信先を紐づける役割があり、keyにはリプライ先のツイートのID、valueには生き残ったツイートの(ツイート、スクリーンネームやユーザIDが入っている。)

        # フィルタリング
        if self.is_invalid_tweet(status):
            return
        self.lookup_ids.append(status.in_reply_to_status_id)
        self.reply_list[status.in_reply_to_status_id] = Tweet(status)
        print(".", end='', flush=True)
        if len(self.lookup_ids) >= 100:
            print("\nCalling statuses_lookup API...")
            statuses = self.api.statuses_lookup(self.lookup_ids)

フィルタリングにツイートをかけて生き残るたびに返信先のツイートを探しているとAPIの利用制限にすぐに引っかかってしまう恐れがあるので、それを回避するためにstatuses_lookup()を使う。

これは、与えられたツイートのIDからそのツイートの情報(ユーザーネーム、スクリーンネーム、ツイート、投稿時間など)を取ってきてくれるメソッドで、一度に最大で100個のツイートを回収できる。

したがって、lookup_idsに100ツイート入るまで何もせず待つ。

4. 収集した会話データをデータベースに保存

100ツイート集まれば、返信先のツイートもフィルタリングにかける。そして、生き残った返信先のツイートと返信元のツイートをデータベースに保存する。

        if len(self.lookup_ids) >= 100:
            print("\nCalling statuses_lookup API...")
            statuses = self.api.statuses_lookup(self.lookup_ids)

            for status in statuses:
                if self.is_status_tweet(status):
                    continue

                if self.is_invalid_tweet(status):
                    continue

                reply = self.reply_list[status.id] 
                # リプライ先が同じユーザー?
                if status.user.id == reply.user_id:
                    continue
                
                self.add_conversation(status, reply)
                self.print_conversation(status, reply)

集めた会話データを眺める

以下のコードを実行すると集めた会話データを見ることができます。Enterを押すと次の会話が見れるようにしました。

import sqlite3
from pathlib import Path

# パスは各自変えてください。
db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"

conn = sqlite3.connect(db_path)
cursor = conn.cursor()

cursor.execute("select reply1, reply2 from seq2seq")
for row in cursor:
    try:
        reply1, reply2 = row
        print('------------ 会話 ------------')
        print(" User:", reply1)
        print("Agent:", reply2)
        input()
    except KeyboardInterrupt:
        break
    except Exception as e:
        print(e)
        break

ソースコード

今回使ったソースコード

twitter_api.py

import tweepy

# 自分のkeyとtokenに置き換えてください。
consumer_key = "consumer_key"
consumer_secret = "consumer_secret"
access_token = "access_token"
access_token_secret = "access_token_secret"

def get_authentication():
    auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
    auth.set_access_token(access_token, access_token_secret)
    return auth

util.py

import sqlite3

def build_database(db_path, sql):
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute(sql)
    conn.commit()
    conn.close()

reply2reply.py

このプログラムを実行すると会話データを集めることができます。

データベースをまだ作ってない場合。

python reply2reply.py --new 1

データベースを既に作っていて、引き続き同じデータベースにデータを保存する場合。

python reply2reply.py
長いから閉じています。(クリックすると開きます。)
import argparse
import twitter_api
import tweepy
import re
import sqlite3
import util
from datetime import timedelta
from pathlib import Path


class Tweet:
    def __init__(self, status):
        self.in_reply_to_status_id = status.in_reply_to_status_id
        self.text = status.text
        self.created_at = status.created_at
        self.screen_name = status.user.screen_name
        self.username = status.user.name
        self.user_id = status.user.id


class StreamingListener(tweepy.StreamListener):
    def __init__(self, api, db_path):
        super(StreamingListener, self).__init__()
        self.api = api
        self.db = db_path 
        self.lookup_ids = []
        self.reply_list = {}

    def on_status(self, status):
        # リプライか?
        if self.is_status_tweet(status):
            return
        # フィルタリング
        if self.is_invalid_tweet(status):
            return
        self.lookup_ids.append(status.in_reply_to_status_id)
        self.reply_list[status.in_reply_to_status_id] = Tweet(status)
        print(".", end='', flush=True)
        if len(self.lookup_ids) >= 100:
            print("\nCalling statuses_lookup API...")
            statuses = self.api.statuses_lookup(self.lookup_ids)

            for status in statuses:
                if self.is_status_tweet(status):
                    continue

                if self.is_invalid_tweet(status):
                    continue

                reply = self.reply_list[status.id] 
                # リプライ先が同じユーザー?
                if status.user.id == reply.user_id:
                    continue
                
                self.add_conversation(status, reply)
                self.print_conversation(status, reply)

            self.lookup_ids = []
            self.reply_list = {}

    def print_conversation(self, reply1, reply2):
        print('------------ 会話 ------------')
        print("reply1:@{}({}): {}".format(
            reply1.user.screen_name,
            reply1.created_at + timedelta(hours=+9),
            reply1.text)
        )
        print("reply2:@{}({}): {}".format(
            reply2.screen_name,
            reply2.created_at + timedelta(hours=+9),
            reply2.text)
        )
        print('-'*30)

    def is_status_tweet(self, status):
        # リプライではないただのツイートか確認
        if status.in_reply_to_status_id is None:
            return True

    def is_reply_tweet(self, status):
        # リプライならTrue
        if isinstance(status.in_reply_to_status_id, int):
            return True

    def is_invalid_tweet(self, status):
        # いらないツイートか調べる
        if status.user.lang != "ja":
            # 日本語か確認
            return True
        
        if "bot" in status.user.screen_name:
            return True

        if re.search(r"https?://", status.text):
            return True

        if re.search(r"#(\w+)", status.text):
            # ハッシュタグ
            return True

        # 複数の相手にリプライしているか?
        tweet = re.sub(r"@([A-Za-z0-9_]+)", "<unk>", status.text)
        if tweet.split().count("<unk>") > 1:
            return True

        # 長いツイートか?
        if len(tweet.replace("<unk>", "")) > 30:
            return True

        return False

    def cleanup_text(self, status):
        text = re.sub(r"@([A-Za-z0-9_]+) ", "", status.text)
        text = re.sub("\s+", ' ', text).strip()
        return text.replace("&gt;", ">").replace("&lt;", "<").replace("&amp;", "&")

    def on_error(self, code):
        pass

    def add_conversation(self, reply1, reply2):
        reply1 = self.cleanup_text(reply1)
        reply2 = self.cleanup_text(reply2)
        conn = sqlite3.connect(self.db)
        cur = conn.cursor()
        cur.execute(
            "insert into seq2seq"
            "(reply1, reply2)"
            "values (?, ?)",
            [reply1, reply2]
        )
        conn.commit()
        conn.close()


if __name__ == "__main__":
    # データベースのpath
    db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"
    sql = """create table seq2seq(
    reply1 text NOT NULL,
    reply2 text NOT NULL
    );"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--new", type=int, default=0,
                        help="0 indicates the database is already there.")
    args = parser.parse_args()
    if args.new:
        util.build_database(db_path, sql)
    # 認証
    auth = twitter_api.get_authentication()
    api = tweepy.API(auth)
    # Streaming API
    listener = StreamingListener(api, db_path)
    streaming = tweepy.Stream(auth, listener)
    while True:
        try:
            streaming.sample()
        except KeyboardInterrupt:
            streaming.disconnect()
            break
        except Exception as e:
            streaming.disconnect()
            print(e)