エイエイレトリック

なぐりがき

NLTKのngram言語モデルを日本語データで使う

以前の記事で、古典的ngram言語モデルについて、NLTKを利用し、英語データセットの結果をまとめました。

eieito.hatenablog.com

単語を分かち書きさえすれば日本語でも実行可能なので、日本語データセットでパープレキシティを算出していきます。

データ

学習データとテストデータの準備。

ngramデータさえあればplain textがなくとも実装できますが、評価のためには学習データと同じ性質のテストデータ(plain text)が必要です。 Wikitext-JAWikipediaデータがtrain/testに分割されているので今回利用します。

データも「秀逸な記事」と「良質な記事」のに分かれているので個々に使う場合とどちらも使う場合を考えます。

前処理

モデルに渡す前の処理を以下の通り統一しました。

  • テキスト処理
    • タイトルおよび項目 (e.g. =ノストラダムス=) の = を削除
    • <block><block> タグは削除
    • *1* のようにシンボル化された文字は元の文字に変換する
      • 分かち書き*/1/* のように細かく分割されてしまい、本来の単語が得られないと考えたため
  • 文分割
  • 分かち書き

NLTKモデル向けの前処理・設定

  • padding
    • nltk.lm.preprocessing.pad_both_ends
  • ngram order
    • bigram (n=2)
  • vocab (Vocabulary)
    • 学習時のcutoff = 2
    • =>1回だけ出現した単語は <UNK>
  • ngram 頻度 (NgramCounter )
    • vocabを使って <UNK> に修正する
    • 具体的には、vocab.lookup(word_list) でvocabにない単語を<UNK>に置き換えた単語リストを生成する

モデル

以前の英語の場合と比較できるので、NLTKのモデル ( https://www.nltk.org/api/nltk.lm.models.html )を利用。

現時点で最新の version 3.7 で実装されているモデルを使います。

  • Lidstone
  • Laplace
  • StupidBackoff
  • AbsoluteDiscountingInterpolated
  • WittenBellInterpolated

KneserNeyInterpolated は以前の記事でも指摘した通り、動作がかなり重いので除外しました。

評価方法

パープレキシティにはnltk.lm.api.LanguageModel.perplexity の実装に従います。

つまり pow(2.0, self.entropy(text_ngrams))

実験

以下の設定で実験。 項目数は Wikitext-JA の統計情報から引用しました。

sentence数は ja_sentence_segmenter での分割結果。タイトルおよび項目(一単語だけの行)も1つのsentenceとみなしている。

学習データ 学習データの項目数 学習データのsentence数 評価データ 評価データの項目数
F (秀逸な記事) Train_Data_F.txt 69 30,975 Test_Data_F.txt 8
G (良質な記事) Train_Data_G.txt 1,139 301,019 Test_Data_G.txt 142
F+G Train_Data_F.txt + Train_Data_G.txt 1,208 331,994 Test_Data_F.txt, Test_Data_G.txt 150

コードはGoogle Colabで実行。実際のNotebook は gistにアップロードした。

モデルごと比較

表にまとめる。小数点3桁で丸めています。

F: Test_Data_F G: Test_Data_G F+G: Test_Data_F F+G: Test_Data_G PentreeBank (参考)
Lidstone 9.643 9.989 9.780 9.978 9.291
Laplace 10.110 10.458 10.260 10.449 9.687
StupidBackoff 7.021 7.285 7.100 7.266 7.416
AbsoluteDiscountingInterpolated 7.267 7.504 7.302 7.480 7.622
WittenBellInterpolated 7.286 7.501 7.305 7.477

スムージング手法 (StupidBackoff, AbsoluteDiscountingInterpolated, WittenBellInterpolated) のパープレキシティが低く、StupidBackoff が一番低い結果となりました。

英語 (Pentreebank) データを使った場合( gist )と同様の傾向といえます。

また、データが増えてもモデルごとのパープレキシティの傾向に影響はないです。

StupidBackoff

紹介していなかったので簡単に説明をします。

StupidBackoff は単純なBackoffモデルです。bigramについて、  w_i, w_{i+1} のスコアは以下で求めます。

学習データにある場合、   count(w_i, w_{i+1}) / N という単純な出現確率です。

  • count() は 学習データにおける 出現回数
  • N は学習データにおける全ての単語ngramの合計

NLTK実装では分母に NgramCounter.N を使っているので、unigramの頻度の合計数 + bigramの頻度の合計数となりそうです。

  w_i, w_{i+1} が未知のペアだった場合  \alpha * freq(w_i)

  •  \alpha は定数
  • freq() は楽手データにおける出現 確率

つまり w_i のみでスコアを算出します。 Backoff という手法の名前は、今見ているものより1つorderが小さいngramを使う、というイメージから来ています。

計算の最適化

今回、F (秀逸な記事) とF+G (秀逸な記事+良質な記事)はパープレキシティの算出にかなり時間がかかりました。特に AbsoluteDiscountingInterpolatedが一番時間がかかっていました。

Pentreebankのテストデータは3761 sentence ( torchtext より)で一瞬だったのであまり意識していませんでしたが、基本的にはスケールしない実装なのでデータ量が増えるほど重くなるといえそうです。

特に入力 ( ngram ) に対して毎回スコアを算出しているのがボトルネックになっています。 スコアはある程度キャッシュしたいです。 また、テストデータのngramは出現順に渡していますが、ngramとその頻度を渡す方が効率的な気もします。

   def entropy(self, text_ngrams):
        """Calculate cross-entropy of model for given evaluation text.

        :param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
        :rtype: float

        """
        return -1 * _mean(
            [self.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
        )

https://www.nltk.org/_modules/nltk/lm/api.html#LanguageModel.entropy

簡単にラッパークラスを作ってみました。 entropy関数で最適化した実装を実行するだけのクラスです。

class LanguageModelWrapper:
  def __init__(self, model):
    self.model = model

  def entropy(self, text_ngrams):
    """Calculate cross-entropy of model for given evaluation text."""
    score_list = []
    text_ngrams_counter = Counter(text_ngrams)
    for ngram, freq in text_ngrams_counter.items():
        score = self.model.logscore(ngram[-1], ngram[:-1])
        score_list.append([score] * freq)
    return -1 * _mean(list(chain.from_iterable(score_list)))

実際に実行してみると、元々の実装を使うよりかなり速くなりました。 以下の表は %%time で測った Wall timeを使ってます。

F G F+G
default setting 1min 31s 59min 19s 1h 8min 42s
LanguageModelWrapper 36.2 s 13min 55s 16min 56s

データ量が多い場合(GとF+G) は時間が 1/6ぐらいに短縮できています。

Gist

gist.github.com