人間のように自然な会話ができる人工知能を作りたい。でも会話データがないから何にも始まらない。じゃ、あきらめよ。と、思っている方がいるかもしれません。
そんなあなたにTwitter APIを使ったSeq2Seq用の会話データを収集する方法について紹介したいと思います。
前準備
TwitterのAPIを利用するにはAPI keyやAccess tokenが必要です。これらはDeveloper登録をすると取得できます。
Developer登録の方法は次の記事に書いてあります。
www.pytry3g.com
すでにTwitter APIを利用できる方は次にtweepyをインストールしてください。pip install tweepy
でインストールできます。
今回集めるデータの形式
今回集める会話データはSeq2Seq用のものです。私は現在、自然な会話ができる人工知能を開発中で、下の画像のような形でSeq2Seqの会話データを使いたいと考えています。

開発中のシステムの流れはこんな感じです。
- 人工知能が適当に呟く。
- あるユーザーがその呟きにリプライする。
- 人工知能があるユーザーのリプライにリプライする。
集めたSeq2Seqの会話データのペア(message, response)を収集し、それぞれを2と3に当てはめてシステムを作る予定です。
ちなみに、今回は2つの発話をペアとして一つの会話にしていますが3つの発話を一つの会話にする方法もあります。
www.pytry3g.com
会話データを収集する
会話データを収集する流れはこんな感じです。
- Streaming APIを使って現在世界で投稿されているツイートを収集する
- 収集したツイートにフィルタリングをかける
- フィルタリングにかけたツイートの返信先のツイートを収集する
- 収集した会話データをデータベースに保存。
会話データを収集する方法は次の通りです。
収集したツイートには固有のIDが含まれています。もし、そのツイートがあるユーザーが投稿したツイートに対するリプライだった場合にはin_reply_to_status_id
には返信先のツイートのIDが入っています。

したがって今回のケースの場合、会話データを集めるためにはまずStreaming APIを使って③にあたるツイートを収集し、そのツイートに含まれるin_reply_to_status_id
を抜き出して、今度はRest APIを使って抜き出したIDから過去に投稿されたツイートを回収します。
今回のプログラムのメインであるreply2reply.py
を簡単に説明していきます。(※ソースコードは下に置いてあります。)
1. Streaming APIを使って世界で投稿されているツイートを収集する
tweepyとStreaming APIを使ってツイートを収集します。
まずはデータベースのパスと中身を用意する。(※パスは各自変えてください。)
if __name__ == "__main__":
db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"
sql = """create table seq2seq(
reply1 text NOT NULL,
reply2 text NOT NULL
);"""
argparse
を使って新しくデータベースを作るか作らないのかを決める。デフォルトではすでにデータベースが存在し、引き続きそのデータベースを使うものとしている。
新しく作るには0以外の数値を与えてプログラムを実行すればOK。
parser = argparse.ArgumentParser()
parser.add_argument("--new", type=int, default=0,
help="0 indicates the database is already there.")
args = parser.parse_args()
if args.new:
util.build_database(db_path, sql)
次にStreaming APIを使うための準備をする。
auth = twitter_api.get_authentication()
api = tweepy.API(auth)
listener = StreamingListener(api, db_path)
streaming = tweepy.Stream(auth, listener)
この辺のことは以前紹介しました。
用意ができたら、sample()
を使ってツイートの収集を始めます。
while True:
try:
streaming.sample()
except KeyboardInterrupt:
streaming.disconnect()
break
except Exception as e:
streaming.disconnect()
print(e)
収集したツイートはStreamingLister
に定義したon_status()
に渡されます。
2. 収集したツイートにフィルタリングをかける
リアルタイムで投稿されているツイートがリプライか確認し、フィルタリングをかけます。
def on_status(self, status):
if self.is_status_tweet(status):
return
if self.is_invalid_tweet(status):
return
is_status_tweet()
でツイートがリプライか確認。ツイートがリプライだった場合、status.in_reply_to_status_id
には数値(ツイートの返信先ツイートのID)が入っていて、リプライでなければNone
が入っています。
is_invalid_tweet()
でツイートを以下の項目よりフィルタリングをかけます。
- ツイートが日本語以外の言語か?
- スクリーンネームにbotが入っているか?
- URLを含んでいる?
- ハッシュタグを含んでいる?
- 複数の相手にリプライしている?
- ツイートの文字数が30以上?
この6つの項目のいずれかに当てはまる場合、そのツイートとはさよならです。
3. フィルタリングにかけたツイートの返信先のツイートを収集する
フィルタリングにかけて生き残ったツイートはlookup_ids
行きとなります。
lookup_ids
はリスト、reply_list
は辞書になっていて、lookup_ids
には生き残ったツイートの返信先のツイートのIDをぶち込み、reply_list
には生き残ったツイートとそのツイートの返信先を紐づける役割があり、key
にはリプライ先のツイートのID、value
には生き残ったツイートの(ツイート、スクリーンネームやユーザIDが入っている。)
if self.is_invalid_tweet(status):
return
self.lookup_ids.append(status.in_reply_to_status_id)
self.reply_list[status.in_reply_to_status_id] = Tweet(status)
print(".", end='', flush=True)
if len(self.lookup_ids) >= 100:
print("\nCalling statuses_lookup API...")
statuses = self.api.statuses_lookup(self.lookup_ids)
フィルタリングにツイートをかけて生き残るたびに返信先のツイートを探しているとAPIの利用制限にすぐに引っかかってしまう恐れがあるので、それを回避するためにstatuses_lookup()
を使う。
これは、与えられたツイートのIDからそのツイートの情報(ユーザーネーム、スクリーンネーム、ツイート、投稿時間など)を取ってきてくれるメソッドで、一度に最大で100個のツイートを回収できる。
したがって、lookup_ids
に100ツイート入るまで何もせず待つ。
4. 収集した会話データをデータベースに保存
100ツイート集まれば、返信先のツイートもフィルタリングにかける。そして、生き残った返信先のツイートと返信元のツイートをデータベースに保存する。
if len(self.lookup_ids) >= 100:
print("\nCalling statuses_lookup API...")
statuses = self.api.statuses_lookup(self.lookup_ids)
for status in statuses:
if self.is_status_tweet(status):
continue
if self.is_invalid_tweet(status):
continue
reply = self.reply_list[status.id]
if status.user.id == reply.user_id:
continue
self.add_conversation(status, reply)
self.print_conversation(status, reply)
集めた会話データを眺める
以下のコードを実行すると集めた会話データを見ることができます。Enterを押すと次の会話が見れるようにしました。
import sqlite3
from pathlib import Path
db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("select reply1, reply2 from seq2seq")
for row in cursor:
try:
reply1, reply2 = row
print('------------ 会話 ------------')
print(" User:", reply1)
print("Agent:", reply2)
input()
except KeyboardInterrupt:
break
except Exception as e:
print(e)
break
今回使ったソースコード。
import tweepy
consumer_key = "consumer_key"
consumer_secret = "consumer_secret"
access_token = "access_token"
access_token_secret = "access_token_secret"
def get_authentication():
auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
auth.set_access_token(access_token, access_token_secret)
return auth
util.py
import sqlite3
def build_database(db_path, sql):
conn = sqlite3.connect(db_path)
cur = conn.cursor()
cur.execute(sql)
conn.commit()
conn.close()
reply2reply.py
このプログラムを実行すると会話データを集めることができます。
データベースをまだ作ってない場合。
python reply2reply.py --new 1
データベースを既に作っていて、引き続き同じデータベースにデータを保存する場合。
python reply2reply.py
長いから閉じています。(クリックすると開きます。)
import argparse
import twitter_api
import tweepy
import re
import sqlite3
import util
from datetime import timedelta
from pathlib import Path
class Tweet:
def __init__(self, status):
self.in_reply_to_status_id = status.in_reply_to_status_id
self.text = status.text
self.created_at = status.created_at
self.screen_name = status.user.screen_name
self.username = status.user.name
self.user_id = status.user.id
class StreamingListener(tweepy.StreamListener):
def __init__(self, api, db_path):
super(StreamingListener, self).__init__()
self.api = api
self.db = db_path
self.lookup_ids = []
self.reply_list = {}
def on_status(self, status):
if self.is_status_tweet(status):
return
if self.is_invalid_tweet(status):
return
self.lookup_ids.append(status.in_reply_to_status_id)
self.reply_list[status.in_reply_to_status_id] = Tweet(status)
print(".", end='', flush=True)
if len(self.lookup_ids) >= 100:
print("\nCalling statuses_lookup API...")
statuses = self.api.statuses_lookup(self.lookup_ids)
for status in statuses:
if self.is_status_tweet(status):
continue
if self.is_invalid_tweet(status):
continue
reply = self.reply_list[status.id]
if status.user.id == reply.user_id:
continue
self.add_conversation(status, reply)
self.print_conversation(status, reply)
self.lookup_ids = []
self.reply_list = {}
def print_conversation(self, reply1, reply2):
print('------------ 会話 ------------')
print("reply1:@{}({}): {}".format(
reply1.user.screen_name,
reply1.created_at + timedelta(hours=+9),
reply1.text)
)
print("reply2:@{}({}): {}".format(
reply2.screen_name,
reply2.created_at + timedelta(hours=+9),
reply2.text)
)
print('-'*30)
def is_status_tweet(self, status):
if status.in_reply_to_status_id is None:
return True
def is_reply_tweet(self, status):
if isinstance(status.in_reply_to_status_id, int):
return True
def is_invalid_tweet(self, status):
if status.user.lang != "ja":
return True
if "bot" in status.user.screen_name:
return True
if re.search(r"https?://", status.text):
return True
if re.search(r"#(\w+)", status.text):
return True
tweet = re.sub(r"@([A-Za-z0-9_]+)", "<unk>", status.text)
if tweet.split().count("<unk>") > 1:
return True
if len(tweet.replace("<unk>", "")) > 30:
return True
return False
def cleanup_text(self, status):
text = re.sub(r"@([A-Za-z0-9_]+) ", "", status.text)
text = re.sub("\s+", ' ', text).strip()
return text.replace(">", ">").replace("<", "<").replace("&", "&")
def on_error(self, code):
pass
def add_conversation(self, reply1, reply2):
reply1 = self.cleanup_text(reply1)
reply2 = self.cleanup_text(reply2)
conn = sqlite3.connect(self.db)
cur = conn.cursor()
cur.execute(
"insert into seq2seq"
"(reply1, reply2)"
"values (?, ?)",
[reply1, reply2]
)
conn.commit()
conn.close()
if __name__ == "__main__":
db_path = str(Path.home()) + "/Archive/blog/reply2reply.db"
sql = """create table seq2seq(
reply1 text NOT NULL,
reply2 text NOT NULL
);"""
parser = argparse.ArgumentParser()
parser.add_argument("--new", type=int, default=0,
help="0 indicates the database is already there.")
args = parser.parse_args()
if args.new:
util.build_database(db_path, sql)
auth = twitter_api.get_authentication()
api = tweepy.API(auth)
listener = StreamingListener(api, db_path)
streaming = tweepy.Stream(auth, listener)
while True:
try:
streaming.sample()
except KeyboardInterrupt:
streaming.disconnect()
break
except Exception as e:
streaming.disconnect()
print(e)