CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: TensorFlow/Keras で Word2Vec の SGNS を実装してみる

以前のエントリで、Word2Vec の CBOW (ContinuousBagOfWords) モデルを TensorFlow/Keras で実装した。 CBOW は、コンテキスト (周辺語) からターゲット (入力語) を推定する多値分類のタスクが考え方のベースになっている。

blog.amedama.jp

今回扱うのは、CBOW と対を成すモデルの Skip Gram をベースにした SGNS (Skip Gram with Negative Sampling) になる。 Skip Gram では、CBOW とは反対にターゲット (入力語) からコンテキスト (周辺語) を推定する多値分類のタスクを扱う。 ただし、with Negative Sampling と付くことで、タスクを多値分類から二値分類にして計算量を削減している。 SGNS では、ターゲットとコンテキストを入力にして、それらが共起 (Co-occurrence) するか否かを推定することになる。 コーパスを処理して実際に共起する単語ペアを正例、出現頻度を元にランダムにサンプルした単語ペアを共起していない負例としてモデルに与える。

使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 11.2.1
BuildVersion:   20D74
$ python -V  
Python 3.8.7

下準備

まずは、必要なパッケージをインストールする。

$ pip install tensorflow gensim scipy tqdm

そして、コーパスとして PTB (Penn Treebank) データセットをダウンロードしておく。

$ wget https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt

サンプルコード

早速だけど、サンプルコードを以下に示す。 いくらかマジックナンバーがコードに残ってしまっていて、あんまりキレイではないかも。 各エポックの終了時には、WordSim353 データセットを使って単語間類似度で単語埋め込みを評価している。 また、学習が終わった後には、いくつかの単語で類似する単語や類推語の結果を確認している。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import re
from itertools import count
from typing import Iterable
from typing import Iterator
from functools import reduce
from functools import partial
from collections import Counter

import numpy as np  # type: ignore
import tensorflow as tf  # type: ignore
from tensorflow.keras import Model  # type: ignore
from tensorflow.keras.layers import Embedding  # type: ignore
from tensorflow import Tensor  # type: ignore
from tensorflow.data import Dataset  # type: ignore
from tensorflow.keras.optimizers import Adam  # type: ignore
from tensorflow.keras.layers import Dense  # type: ignore
from tensorflow.keras.models import Sequential  # type: ignore
from tensorflow.keras.callbacks import Callback  # type: ignore
from tensorflow.keras.layers import Dot  # type: ignore
from tensorflow.keras.layers import Flatten  # type: ignore
from tensorflow.keras.losses import BinaryCrossentropy  # type: ignore
from gensim.test.utils import datapath  # type: ignore
from scipy.stats import pearsonr  # type: ignore
from tqdm import tqdm  # type: ignore


class SkipGramWithNegativeSampling(Model):
    """Word2Vec の SGNS モデルを実装したクラス"""

    def __init__(self, vocab_size: int, embedding_size: int):
        super().__init__()

        # ターゲット (入力語) の埋め込み
        self.target_embedding = Embedding(input_dim=vocab_size,
                                          input_shape=(1, ),
                                          output_dim=embedding_size,
                                          name='word_embedding',
                                          )
        # コンテキスト (周辺語) の埋め込み
        self.context_embedding = Embedding(input_dim=vocab_size,
                                           input_shape=(1, ),
                                           output_dim=embedding_size,
                                           name='context_embedding',
                                           )

        self.dot = Dot(axes=1)
        self.output_layer = Sequential([
            Flatten(),
            Dense(1, activation='sigmoid'),
        ])

    def call(self, inputs: Tensor) -> Tensor:
        # ターゲットのベクトルを取り出す
        target_label = inputs[:, 0]
        target_vector = self.target_embedding(target_label)
        # コンテキストのベクトルを取り出す
        context_label = inputs[:, 1]
        context_vector = self.context_embedding(context_label)
        # ターゲットとコンテキストの内積を計算する
        x = self.dot([target_vector, context_vector])
        # 共起したか・していないかを二値の確率にする
        prediction = self.output_layer(x)
        return prediction


def cosine_similarity_one_to_one(x, y):
    """1:1 のコサイン類似度"""
    nx = x / np.sqrt(np.sum(x ** 2))
    ny = y / np.sqrt(np.sum(y ** 2))
    return np.dot(nx, ny)


class WordSimilarity353Callback(Callback):
    """WordSim353 データセットを使って単語間の類似度を評価するコールバック"""

    def __init__(self, word_id_table: dict[str, int]):
        super().__init__()

        self.word_id_table = word_id_table
        self.model = None

        # 評価用データを読み込む
        self.eval_data = []
        wordsim_filepath = datapath('wordsim353.tsv')
        with open(wordsim_filepath, mode='r') as fp:
            # 最初の 2 行はヘッダなので読み飛ばす
            fp.readline()
            fp.readline()
            for line in fp:
                word1, word2, sim_score = line.strip().split('\t')
                self.eval_data.append((word1, word2, float(sim_score)))

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        # モデルから学習させたレイヤーの重みを取り出す
        model_layers = {layer.name: layer for layer in self.model.layers}
        embedding_layer = model_layers['word_embedding']
        word_vectors = embedding_layer.weights[0].numpy()

        # 評価用データセットに含まれる単語間の類似度を計算する
        labels = []
        preds = []
        for word1, word2, sim_score in self.eval_data:
            # Out-of-Vocabulary な単語はスキップ
            if word1 not in self.word_id_table or word2 not in self.word_id_table:
                continue

            # コサイン類似度を計算する
            word1_vec = word_vectors[self.word_id_table[word1]]
            word2_vec = word_vectors[self.word_id_table[word2]]
            pred = cosine_similarity_one_to_one(word1_vec, word2_vec)
            preds.append(pred)
            # 正解ラベル
            labels.append(sim_score)

        # ピアソンの相関係数を求める
        r_score = pearsonr(labels, preds)[0]
        print(f'Pearson\'s r score with WordSim353: {r_score}')


def load_corpus(filepath: str) -> Iterator[str]:
    """テキストファイルからコーパスを読み出す"""
    with open(filepath, mode='r') as fp:
        for line in fp:
            # 改行コードは取り除く
            yield line.rstrip()


def sentences_to_words(sentences: Iterable[str], lower: bool = True) -> Iterator[list[str]]:
    """文章を単語に分割する"""
    for sentence in sentences:
        if lower:
            sentence = sentence.lower()
        words = re.split('\\W+', sentence)
        yield [word for word in words if len(word) > 0]  # 空文字は取り除く


def word_id_mappings(sentences: Iterable[Iterable[str]]) -> dict[str, int]:
    """単語を ID に変換する対応テーブルを作る"""
    counter = count(start=0)

    word_to_id = {}
    for sentence in sentences:
        for word in sentence:

            if word in word_to_id:
                # 登録済みの単語はスキップする
                continue

            # 単語の識別子を採番する
            word_id = next(counter)
            word_to_id[word] = word_id

    return word_to_id


def words_to_ids(sentences: Iterable[list[str]], word_to_id: dict[str, int]) -> Iterator[list[int]]:
    # 単語を対応するインデックスに変換する
    for words in sentences:
        # NOTE: Out-of-Vocabulary への対応がない
        yield [word_to_id[word] for word in words]


def extract_contexts(word_ids: Tensor, window_size: int) -> Tensor:
    """コンテキストの単語をラベル形式で得る"""
    target_ids = word_ids[:-window_size]
    context_ids = word_ids[window_size:]
    # ウィンドウサイズ分ずらした Tensor 同士をくっつける
    co_occurrences = tf.transpose([target_ids, context_ids])
    # 逆順でも共起したのは同じ
    reversed_co_occurrences = tf.transpose([context_ids, target_ids])
    concat_co_occurrences = tf.concat([co_occurrences,
                                       reversed_co_occurrences],
                                      axis=0)
    # ラベル (正例なので 1)
    labels = tf.ones_like(concat_co_occurrences[:, 0],
                          dtype=tf.int8)
    return concat_co_occurrences, labels


def positive_pipeline(ds: Dataset, window_size: int) -> Dataset:
    """正例を供給するパイプライン"""

    ctx_ds_list = []
    for window in range(1, window_size + 1):
        partialed = partial(extract_contexts, window_size=window)
        # ウィンドウサイズごとに共起した単語を抽出する
        mapped_ds = ds.map(partialed,
                           num_parallel_calls=tf.data.AUTOTUNE,
                           deterministic=False)
        ctx_ds_list.append(mapped_ds)

    # すべての Dataset をつなげる
    context_ds = reduce(lambda l_ds, r_ds: l_ds.concatenate(r_ds),
                        ctx_ds_list)
    return context_ds


def word_frequency(sentences: Iterable[Iterable[str]], word_id_table: dict[str, int]) -> dict[str, int]:
    """単語の出現頻度を調べる"""
    counter = Counter(word for words in sentences for word in words)
    id_count = {word_id_table[word]: count for word, count in counter.items()}
    # ID 順でソートされた出現頻度
    sorted_freq = np.array([count for _, count in sorted(id_count.items(), key=lambda x: x[0])],
                           dtype=np.int32)
    return sorted_freq


def noisy_word_pairs(word_proba: list[float], eps: float = 1e-6) -> Iterator[Tensor]:
    """単語の出現頻度を元にネガティブサンプルの単語ペアを生成するジェネレータ関数"""
    p = tf.constant(word_proba) + eps
    logits = tf.math.log([p, p])
    while True:
        word_pair = tf.random.categorical(logits, num_samples=2**12)
        word_pair_t = tf.transpose(word_pair)
        # ラベル (負例なので 0)
        labels = tf.zeros_like(word_pair_t[:, 0], dtype=tf.int8)
        yield word_pair_t, labels


def negative_pipeline(sentences: Iterable[Iterable[str]], word_id_table: dict[str, int]) -> Dataset:
    """負例を供給するパイプライン"""
    # 単語の出現頻度からサンプリングテーブルを求める
    word_freq = word_frequency(sentences, word_id_table)
    word_proba = word_freq / np.sum(word_freq)
    # 0.75 乗することで、出現頻度の低い単語をちょっとだけ選ばれやすくする
    ADJUST_FACTOR = 0.75
    adjusted_word_proba = np.power(word_proba, ADJUST_FACTOR)
    adjusted_word_proba /= np.sum(adjusted_word_proba)
    # 単語の出現頻度を元にノイジーワードペアを生成する
    negative_ds = Dataset.from_generator(lambda: noisy_word_pairs(adjusted_word_proba),
                                         (tf.int32, tf.int8))

    return negative_ds


def batched_concat(pos_tensor: Tensor, neg_tensor: Tensor) -> Tensor:
    """正例と負例を直列に結合する関数"""
    pos_word_pairs, pos_labels = pos_tensor
    neg_word_pairs, neg_labels = neg_tensor
    word_pairs = tf.concat((pos_word_pairs, neg_word_pairs), axis=0)
    labels = tf.concat((pos_labels, neg_labels), axis=0)
    return word_pairs, labels


def skip_grams_with_negative_sampling_dataset(positive_ds: Dataset,
                                              negative_ds: Dataset,
                                              negative_sampling_ratio: int):
    """データセットで共起した単語ペアを正例、出現頻度を元にランダムに選んだ単語ペアを負例として供給するパイプライン"""
    positive_batch_size = 1024  # 正例の供給単位
    batched_pos_ds = positive_ds.unbatch().batch(positive_batch_size)
    batched_neg_ds = negative_ds.unbatch().batch(positive_batch_size * negative_sampling_ratio)
    zipped_ds = tf.data.Dataset.zip((batched_pos_ds, batched_neg_ds))
    concat_ds = zipped_ds.map(batched_concat,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False).unbatch()
    # バッチサイズ単位でシャッフルする
    shuffle_buffer_size = positive_batch_size * (negative_sampling_ratio + 1)
    shuffled_ds = concat_ds.shuffle(buffer_size=shuffle_buffer_size)
    return shuffled_ds


def cosine_similarity_matrix(word_vectors: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """N:N のコサイン類似度を計算する"""
    word_norm = np.sqrt(np.sum(word_vectors ** 2, axis=1)).reshape(word_vectors.shape[0], -1)
    normalized_word_vectors = word_vectors / (word_norm + eps)
    cs_matrix = np.dot(normalized_word_vectors, normalized_word_vectors.T)
    return cs_matrix


def most_similar_words(similarities: np.ndarray, top_n: int = 5):
    """コサイン類似度が最も高い単語の ID を得る"""
    similar_word_ids = np.argsort(similarities)[::-1]
    top_n_word_ids = similar_word_ids[:top_n]
    top_n_word_sims = similarities[similar_word_ids][:top_n]
    return zip(top_n_word_ids, top_n_word_sims)


def cosine_similarity_one_to_many(word_vector: np.ndarray,
                                  word_vectors: np.ndarray,
                                  eps: float = 1e-8):
    """1:N のコサイン類似度"""
    normalized_word_vector = word_vector / np.sqrt(np.sum(word_vector ** 2))
    word_norm = np.sqrt(np.sum(word_vectors ** 2, axis=1)).reshape(word_vectors.shape[0], -1)
    normalized_word_vectors = word_vectors / (word_norm + eps)
    return np.dot(normalized_word_vector, normalized_word_vectors.T)


def main():
    # Penn Treebank コーパスを読み込む
    train_sentences = load_corpus('ptb.train.txt')

    # コーパスを単語に分割する
    train_corpus_words = list(sentences_to_words(train_sentences))

    # 単語に ID を振る
    word_to_id = word_id_mappings(train_corpus_words)

    # コーパスの語彙数
    vocab_size = len(word_to_id.keys())

    # データセットを準備する
    # ID に変換したコーパスを行ごとに読み出せるデータセット
    train_word_ids_ds = Dataset.from_generator(lambda: words_to_ids(train_corpus_words, word_to_id),
                                               tf.int32,
                                               output_shapes=[None])

    # 共起したと判断する単語の距離
    CONTEXT_WINDOW_SIZE = 5
    positive_ds = positive_pipeline(train_word_ids_ds, window_size=CONTEXT_WINDOW_SIZE)
    negative_ds = negative_pipeline(train_corpus_words, word_to_id)

    # 正例に対する負例の比率 (一般的に 5 ~ 10)
    NEGATIVE_SAMPLING_RATIO = 5
    train_ds = skip_grams_with_negative_sampling_dataset(positive_ds,
                                                         negative_ds,
                                                         NEGATIVE_SAMPLING_RATIO)

    # モデルとタスクを定義する
    EMBEDDING_SIZE = 100  # 埋め込み次元数
    criterion = BinaryCrossentropy()
    optimizer = Adam(learning_rate=1e-2)
    model = SkipGramWithNegativeSampling(vocab_size, EMBEDDING_SIZE)
    model.compile(optimizer=optimizer,
                  loss=criterion,
                  )

    # データセットを準備する
    TRAIN_BATCH_SIZE = 2 ** 14
    train_ds = train_ds.batch(TRAIN_BATCH_SIZE)
    train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    train_ds = train_ds.cache()

    print('caching train data...')
    num_of_steps_per_epoch = sum(1 for _ in tqdm(train_ds))
    print(f'{num_of_steps_per_epoch=}')
    train_ds = train_ds.repeat()

    callbacks = [
        # WordSim353 データセットを使って単語間の類似度を相関係数で確認する
        WordSimilarity353Callback(word_to_id),
    ]
    # 学習する
    model.fit(train_ds,
              steps_per_epoch=num_of_steps_per_epoch,
              epochs=5,
              callbacks=callbacks,
              verbose=1,
              )

    # モデルから学習させたレイヤーの重みを取り出す
    model_layers = {layer.name: layer for layer in model.layers}
    embedding_layer = model_layers['word_embedding']
    word_vectors = embedding_layer.weights[0].numpy()

    # 単語を表すベクトル間のコサイン類似度を計算する
    cs_matrix = cosine_similarity_matrix(word_vectors)
    # ID -> 単語
    id_to_word = {value: key for key, value in word_to_id.items()}

    # いくつか似ているベクトルを持った単語を確認してみる
    example_words = ['you', 'year', 'car', 'toyota']
    for target_word in example_words:
        # ID に変換した上で最も似ている単語とそのベクトルを取り出す
        print(f'The most similar words of "{target_word}"')
        target_word_id = word_to_id[target_word]
        similarities = cs_matrix[target_word_id, :]
        top_n_most_similars = most_similar_words(similarities, top_n=6)
        # 先頭は自分自身になるので取り除く
        next(top_n_most_similars)
        # 単語と類似度を表示する
        for rank, (similar_word_id, similarity) in enumerate(top_n_most_similars, start=1):
            similar_word = id_to_word[similar_word_id]
            print(f'TOP {rank}: {similar_word} = {similarity}')
        print('-' * 50)

    # いくつか類推語を確認してみる
    analogies = [
        ('king', 'man', 'woman'),
        ('took', 'take', 'go'),
        ('cars', 'car', 'child'),
        ('better', 'good', 'bad'),
    ]
    for word1, word2, word3 in analogies:
        print(f'The most similar words of "{word1}" - "{word2}" + "{word3}"')
        word1_vec = word_vectors[word_to_id[word1]]
        word2_vec = word_vectors[word_to_id[word2]]
        word3_vec = word_vectors[word_to_id[word3]]
        new_vec = word1_vec - word2_vec + word3_vec
        similarities = cosine_similarity_one_to_many(new_vec, word_vectors)
        top_n_most_similars = most_similar_words(similarities)
        # 単語と類似度を表示する
        for rank, (similar_word_id, similarity) in enumerate(top_n_most_similars, start=1):
            similar_word = id_to_word[similar_word_id]
            print(f'TOP {rank}: {similar_word} = {similarity}')
        print('-' * 50)


if __name__ == '__main__':
    main()

上記を実行してみよう。 コーパスの前処理に結構時間がかかる。 今回使った環境では 1 分 47 秒かかった。 学習に関しては、WordSim353 データセットを使った内省的評価で 4 エポック目にはサチる感じ。

$ python sgns.py
caching train data...
2802it [01:47, 26.01it/s]
num_of_steps_per_epoch=2802
Epoch 1/5
2802/2802 [==============================] - 50s 18ms/step - loss: 0.3845
Pearson's r score with WordSim353: 0.2879142572631919
Epoch 2/5
2802/2802 [==============================] - 47s 17ms/step - loss: 0.3412
Pearson's r score with WordSim353: 0.367370159567898
Epoch 3/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3307
Pearson's r score with WordSim353: 0.3898624474454972
Epoch 4/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3248
Pearson's r score with WordSim353: 0.39416965929977094
Epoch 5/5
2802/2802 [==============================] - 48s 17ms/step - loss: 0.3211
Pearson's r score with WordSim353: 0.39503500234447125
The most similar words of "you"
TOP 1: your = 0.6509251594543457
TOP 2: i = 0.6414255499839783
TOP 3: we = 0.569475531578064
TOP 4: re = 0.5692735314369202
TOP 5: someone = 0.5565952658653259
--------------------------------------------------
The most similar words of "year"
TOP 1: earlier = 0.5841510891914368
TOP 2: month = 0.5817509293556213
TOP 3: period = 0.572060763835907
TOP 4: ago = 0.5633273720741272
TOP 5: last = 0.5298276543617249
--------------------------------------------------
The most similar words of "car"
TOP 1: cars = 0.6105561852455139
TOP 2: luxury = 0.5986034870147705
TOP 3: truck = 0.563898503780365
TOP 4: ford = 0.5133273005485535
TOP 5: auto = 0.5039612054824829
--------------------------------------------------
The most similar words of "toyota"
TOP 1: infiniti = 0.6949869394302368
TOP 2: honda = 0.6433103084564209
TOP 3: mazda = 0.6296555995941162
TOP 4: lexus = 0.6275536417961121
TOP 5: motor = 0.6175971627235413
--------------------------------------------------
The most similar words of "king" - "man" + "woman"
TOP 1: king = 0.7013309001922607
TOP 2: woman = 0.6033017039299011
TOP 3: burger = 0.4819037914276123
TOP 4: md = 0.46715104579925537
TOP 5: egg = 0.45565301179885864
--------------------------------------------------
The most similar words of "took" - "take" + "go"
TOP 1: took = 0.6121359467506409
TOP 2: go = 0.6096295714378357
TOP 3: stands = 0.4905664920806885
TOP 4: hammack = 0.45641931891441345
TOP 5: refuge = 0.4478893578052521
--------------------------------------------------
The most similar words of "cars" - "car" + "child"
TOP 1: child = 0.8112377524375916
TOP 2: women = 0.5026379823684692
TOP 3: cars = 0.4889577627182007
TOP 4: patients = 0.4796864092350006
TOP 5: custody = 0.47176921367645264
--------------------------------------------------
The most similar words of "better" - "good" + "bad"
TOP 1: bad = 0.6880872249603271
TOP 2: better = 0.510699987411499
TOP 3: involved = 0.45395123958587646
TOP 4: serious = 0.4192639887332916
TOP 5: hardest = 0.41672468185424805
--------------------------------------------------

CBOW のときは WordSim353 の評価が 0.25 前後だったことを考えると、今回の 0.39 前後という結果はかなり良く見える。 ただし、CBOW の実験ではコンテキストウィンドウサイズが 1 だったのに対し、上記の SGNS では 5 を使っている。 同じコンテキストウィンドウサイズの 1 に揃えると、評価指標は 0.28 前後まで落ちる。

学習が終わってから確認している類似語は、CBOW のときと同じでなかなか良い感じに見える。 一方で、類推語の方はほとんど上手くいっていない。 類推語を解ける位の単語埋め込みを学習するには、もっと大きなコーパスが必要なのだろうか?

ゼロから作るDeep Learning ❷ ―自然言語処理編

ゼロから作るDeep Learning ❷ ―自然言語処理編

  • 作者:斎藤 康毅
  • 発売日: 2018/07/21
  • メディア: 単行本(ソフトカバー)