エイエイレトリック

なぐりがき

都道府県・市区町村をgeopandasで可視化

市区町村の読みって意外と曖昧性があるなと気づき、調べてみることにしました。

市区町村の位置を可視化するにあたり、geopandas を使ってみたのでメモします。

詳細については一番下にgistで埋め込んでます。

要点

geopandas で日本地図を扱うには。

都道府県だけであれば japanmap のポリゴンデータをつかう

市区町村は 国土数値情報 | 行政区域データシェープファイルを使う

課題 (未解決)

geopandasで試行錯誤した部分について

cmap でカテゴリ色づけが順序通りされない

島を含むと画像が縦長・横長になってしまう

  • 沖縄県・鹿児島県の奄美・東京都の島しょ部が外れ値のように存在しているのが原因
  • 日本列島をメインに出力する場合は除外する必要がある

可視化結果

上記点を踏まえて可視化しました。

都道府県の可視化

市区町村の数

「町」の割合、「町」の読みは「まち」か「ちょう」か

「町」の可視化

「町」の割合が高い都道府県 TOP5

都道府県
鳥取県     0.736842
北海道     0.697297
和歌山県    0.666667
徳島県     0.625000
宮城県     0.571429

「町」の読み方は西日本・東日本で分かれている印象。 九州は混在している。

市区町村の可視化

「町」率が高い鳥取県の読み方を可視化。全部読みは「ちょう」。

鳥取県の「町」の読み方

「まち」「ちょう」の読みが混在している九州を可視化

九州地方の「町」の読み方

福岡県・大分県あたりは「ちょう」が集中していて、それ以外は「まち」になっている。

gist

可視化の出力に使ったコードは以下。

japan_map_geopandas_都道府県_市区町村.ipynb · GitHub

もう少し深掘りできそう。

DjangoCongress2022 でドキュメント生成について発表しました

https://django.connpass.com/event/259310/ より

11/12に DjangoCongress2022 で「OAI3を使ったDjango REST frameworkのドキュメント生成とカスタマイズ」というタイトルで発表しました。

朝イチの発表にも関わらず、聴講してくださった方々、ありがとうございました。

資料

スライドとサンプルコードは公開しています。

録画も今後サイトに公開されるようです。

SpeakerDeck

speakerdeck.com

Github

github.com

雑感

テック系のイベントで発表するのははじめてだったので、また、コロナ後にオフラインで発表するのははじめてだったので、少し緊張しました。

Djangoは学生時代のインターンではじめて使いました。 インターンの時は機能の修正がメインで、アプリケーションがどうやって動いているかなど意識せず開発していた気がします。

今では一からDjangoのアプリケーションを作成したり、コアとなる部分を開発したり、Djangoの開発がメインの仕事です。 今現在の集大成として、こうしてDjangoの発表の機会がもらえたのは大変ありがたく思います。

発表したテーマに取り組むきっかけは、既存のDRFのアプリケーションのリファクタリングでした。

そのコードでは、すでに開発が止まっている django-rest-swagger を使っていたので移行が必要でした。 既存の設定をどうにか維持しながら最新のDRFにも対応させるため四苦八苦しました。

OAI3に対応した drf-spectacular に移行するのが難しいカスタマイズがあったため、DRF の標準機能をカスタマイズすることで移行しました。

その際に調べたことをまとめたのが以下のブログです。

eieito.hatenablog.com

自分の理解もこの頃に比べれば深まっていたので、ブログ記事ではわかりにく部分をほぐして発表できたと思います。

誰かの参考になれば嬉しいです。

最後に、素晴らしいイベントを運営してくださった方々に感謝です。

djangoのFileBasedCacheの使い方と仕組み

FileBasedCache の使い方と仕組みについてメモします。

データベースを用意するまでもないけど、情報をキャッシュしたいときによさそうです。

ファイルで出力するのでキャッシュしたかどうか、キャッシュが残っているかどうかがわかりやすいためデバッグしやすい気がします。

また、キャッシュの中身も比較的簡単に確認できます。

動作環境

使い方

設定

設定ファイル (settings.py, settings/*.py など) でCACHESのバックエンド (BACKEND) とファイルの保存先 (LOCATION) を指定します。

LOCATIONは書き込み可能な場所を指定します。 ちなみに存在しないディレクトリでもapp側で作成します。

dockerの場合はvolumesで指定した場所の必要があります。 BASE_DIR ではなくroot近くの場所になりそうです。

CACHES = {
    'default': {
        'BACKEND': 'django.core.cache.backends.filebased.FileBasedCache',
        'LOCATION': BASE_DIR / 'tmp/django_cache',
        # 'KEY_PREFIX': 'prefix',
    }
}

開発・本番といった複数のサーバーを起動していて、それぞれのキャッシュ結果を区別したい時は KEY_PREFIX を設定できます。

参考: Cache key prefixing

キャッシュを保存・呼び出す

バックエンドの仕組みは関係なく、 django.core.cache.cache からキャッシュを利用できます。

値のセットにはオプションとして timeout と version も設定できます。

cache.set(key, value, timeout=DEFAULT_TIMEOUT, version=None)
cache.get(key)

参考: Basic usage

FileBasedCache の仕組み

ソースコード ( django/filebased.py at main · django/django · GitHub ) を確認しながら仕組みを理解していきます。

実際に下記のようにキャッシュをセットしました。

cache.set('text', 'cache_text')

.djcache ファイルが指定した保存場所に確認できます。 読み込み・書き込み権限はあるみたいです。

1つのkeyごとに1つのファイルが作成されます。

ls -al BASE_DIR/tmp/django_cache
-rw-------  1 USER  staff  49 Oct DD HH:MM 4720993faefa6c12c92c4b03540e3e17.djcache

ファイル名

この長いファイル名は MD5 で暗号化された key です。

ですが、そのままkeyを渡しただけでは再現しません。

from django.utils.crypto import md5
md5("text".encode(), usedforsecurity=False).hexdigest()
'1cb251ec0d568de6a929b520c4aed8d1'

ファイル名のオプションとして prefix と version が指定できます。

つまり key, prefix, version の3つをフォーマットした文字列をファイル名とします。

フォーマット方法は KEY_FUNCTION で設定できますが、デフォルトでは django.core.cache.backends.base.default_key_func を使っています。 具体的には "%s:%s:%s" % (key_prefix, version, key) です。

keyprefix のデフォルトは '' (空文字)、 versionのデフォルトは 1 なので、その通り試してみると無事ファイル名が再現しました。

md5(":1:text".encode(), usedforsecurity=False).hexdigest()
'4720993faefa6c12c92c4b03540e3e17'

保存方法

.djcache ファイルは読み込みできるようなので、value を確認します。

ファイルは timeout (expired) の情報 と value を持ちます。

timeout は pickle, value は pickle と zlib で圧縮しています。

その順に読み込むと数値 (エポック秒) とvalueが取得できました。

with open("BASE_DIR/tmp/django_cache/4720993faefa6c12c92c4b03540e3e17.djcache", "rb")as f:
     exp = pickle.load(f)
     value = pickle.loads(zlib.decompress(f.read()))

print(exp, value)
>> 16670XXXXX.XXXXX cache_text

ちなみにデフォルトのtimeoutは300 second なので、永続化の場合 None を設定する必要があります

参考: TIMEOUT

参考サイト

pythonのunittest.Testcaseでmockする・patchする

テストコードで活躍する mock。 だが毎回 これってどこをmockすればいいんだ…… と必要以上に mock.patch を書いてしまいます。

python の公式ドキュメントや解説記事では、mock 単体の振る舞いについて紹介していることが多く、最初の頃は でも実際どう組み込めばいいの? と悩むことが多かったです。

そんなわけで、実装した python のクラスに対して unittest.Testcase のテストコードで mock を使ってみて、仕様を確認してみます。

実行環境

  • macOS 12.4 (Monterey)
  • python3.8

テストするクラス

テストで mock が必要になるのは、個人的には大きく 2 つあるかなと個人的に思っています。

  • 外部にアクセスする
    • e.g. requests, boto3, データベース
  • テストケース を考えるのが面倒 が複雑
    • e.g. クラスの要素でクラスを持つ, 実行環境(本番/開発)によって動作が異なる

外部アクセスに関してはクラスの内部で requests を呼び出すみたいなこともあるので、後者とも関連してきます。

よってまずは テストケースを考えるのが面倒なのでmockしてしまおう と思えるクラスを作ります。

preprocess.pyRSS feed*1 から情報を抽出するクラス RSSParser を宣言します。 python 標準の ElementTree を使って解析するクラスです。

import xml.etree.ElementTree as ET


class RSSParser:
    def __init__(self, text):
        self.text = text
        self.root = ET.fromstring(text)

    def get_title_name(self):
        title_element = self.root.find("channel").find("title")
        return title_element.text

非常にまどろっこしやり方をしていますが、テストケースを色々試すための実装ということで目をつぶってください。

以下のようなファイル構成でテストコード test_preprocess.py を置きます。

.
└── api
    ├── preprocess.py
    └── tests
        └── test_preprocess.py

unittest の実行は以下。

python -m unittest api/tests/test_preprocess.py

テストを書く

TestRSSParser を実装します。 初期値で渡すテキスト (XML 形式) を考えるのが面倒なので、mock で解決しましょう。

※ ElementTree は正しい仕様のテキストを渡せば正しく返ってくるという強い仮定のもと mock します。

どこにpatchするか

import xml.etree.ElementTree as ET を mock するにはどこに patch すればいいのか。

mock 覚えたての pythonista は絶対つまずくだろう部分です。

from unittest import TestCase
from unittest.mock import Mock, patch

from api.preprocess import RSSParser

# NG pattern
@patch("xml.etree.ElementTree")
class TestRSSParser(TestCase):
    def test_et_mock(self, et):
        p = RSSParser("text")

>> 
Traceback (most recent call last):
  File "/XXX/.pyenv/versions/3.8.0/lib/python3.8/unittest/mock.py", line 1342, in patched
    return func(*newargs, **newkeywargs)
  File "/XXX/api/tests/test_preprocess.py", line 11, in test_et_fail
    p = RSSParser("test_et_mock")
  File "/XXX/api/preprocess.py", line 7, in __init__
    self.root = ET.fromstring(text)
  File "/XXX/.pyenv/versions/3.8.0/lib/python3.8/xml/etree/ElementTree.py", line 1321, in XML
    return parser.close()
  File "<string>", line None
xml.etree.ElementTree.ParseError: syntax error: line 1, column 0

これは patch できません。 xml.etree.ET でも同様。

なぜできないのか。 その理由は どこにパッチするか に書いてあります。

基本的な原則は、オブジェクトが ルックアップ されるところにパッチすることです。

テストしたいクラス RSSParserself.root = ET.fromstring(text) を呼び出す前に、 インポートされた ET を patch します。 つまり正しい patch の場所は api.preprocess.ET です。

@patch("api.preprocess.ET")
class TestMain(TestCase):
    def test_et_mock(self, et):
        p = RSSParser("text")

patchした値、 MagicMock

patch された ET にはデフォルトで MagicMock オブジェクトが入っています。

MagicMock は属性、関数にアクセスすると新しい MagicMock オブジェクトを返します。

mock = MagicMock()

mock.fromstring
<MagicMock name='mock.fromstring' id='4373352064'>

mock.fromstring()
<MagicMock name='mock.fromstring()' id='4373475632'>

mock.fromstring("text")
<MagicMock name='mock.fromstring()' id='4373475632'>

よって、 self.root = ET.fromstring(text) では、self.root に MagicMock が代入されます。

mockの返却値を設定する return_value

MagicMock そのままでは動作が正しいかどうかのテストができないので、return_value を使って値を設定します。

self.root"fromstring_return_value" というテキストが入るようにしてみます。

@patch("api.preprocess.ET")
class TestRSSParser(TestCase):
    def test_et_fromstring(self, et):
        et.fromstring = Mock(return_value="fromstring_return_value")
        # 別の書き方
        # et.fromstring.return_value = "fromstring_return_value"
        p = RSSParser("test_et_mock")
        self.assertEqual(p.root, "fromstring_return_value")

上記は ET を patch しているので et.fromstringreturn_value を設定していますが、 fromstring を直接 patch すればデコレーターの引数でも設定できます。

@patch("api.preprocess.ET.fromstring", return_value="fromstring_return_value")
class TestRSSParser2(TestCase):
    def test_et_mock(self, et):
        p = RSSParser("test_et_mock")
        self.assertEqual(p.root, "fromstring_return_value")

深い場所にある属性をmockする nested mock/ chained call

get_title_name をテストします。 返却値が self.root.find("channel").find("title").text と深い場所にある場合はどうすればいいでしょうか。

まず .text で string を返す mock を用意します。

title_mock = Mock()
title_mock.text = "title_string"

次に、RSSParser に title_mock を組み込みます。 単純に代入しようとすると以下になりますが、これだと "title_string" は返ってきません。

@patch("api.preprocess.ET")
class TestRSSParser(TestCase):
    # NG pattern
    def test_get_title_name(self, et):
        parser = RSSParser("text")
        title_mock = Mock()
        title_mock.text = "title_string"
        parser.root.find.find.return_value = title_mock
        self.assertEqual(parser.get_title_name(), "title_string")

>>
AssertionError: <MagicMock name='ET.fromstring().find().find().text' id='4373959920'> != 'title_string'

chained call をモックする にあるように return_value を活用します。

self.root.find("channel") の返す値が .find("title") を呼び出すので、 parser.root.find.find ではなく parser.root.find.return_value.find. とすべきです。

これを踏まえると以下のように書けます。

@patch("api.preprocess.ET")
class TestRSSParser(TestCase):
    def test_get_title_name(self, et):
        parser = RSSParser("text")
        title_mock = Mock()
        title_mock.text = "title_string"
        parser.root.find.return_value.find.return_value = title_mock
        # configure_mockを使った別の書き方
        # parser.root.configure_mock(**{"find.return_value.find.return_value": title_mock})
        self.assertEqual(parser.get_title_name(), "title_string")

また、title_mock を直接渡すことも可能です。

        parser.root.find.return_value.find.return_value = Mock(text="title_string")

参考資料

細々とした工夫についても別途まとめる予定。

textlintのインストールから新しいルール作成までやってみた

textlintを使ったルールを作成しようと四苦八苦したので作業ログを残しておこうと思います。

javascript初心者 (簡単なチュートリアルを一通り終えた程度) が実装していることを先に述べておきます。

MacBook PromacOS 12.4 Montereyで確認しています。

npmのインストール

  • npmはすでにローカルにインストールされていたのですが、インストールしなおしました。
  • このstackoverflow 記事と同様、パッケージをインストールしようとすると Error: EACCES: permission denied が発生して権限周りがうやむやになっていることがわかったからです。
  • npmの公式ページ (Downloading and installing Node.js and npm | npm Docs) をみてnodeのバージョンコントロールn をインストールしました。
    • nodeのバージョンを指定するとnpmもついてくるので、これで環境は整いました。
% n ls
node/10.19.0
node/14.4.0
node/18.1.0
% node -v
v18.1.0
% npm -v
8.8.0

前述のnpmの公式ページではnodeのバージョンマネージャーでnodeとnpmをインストールすることを推奨していました。(下記) ネット上では色々紹介されていますが公式サイトに従っておくのが無難そうだなと思います。

We strongly recommend using a Node version manager like nvm to install Node.js and npm.

textlintのインストール

textlintのドキュメント Getting Started with textlint · textlint の通り導入していきます。

  • 新しくディレクトtextlint-demo を作って動作確認しました。
  • installのオプションに --save-dev をつけることでローカル textlint-demo/package.json が更新されます。
 {
   "name": "textlint-demo",
   ...,
   "devDependencies": {
     "textlint": "^12.1.1",
     "textlint-rule-no-todo": "^2.0.1"
   },
   ...
 }
 
  • 実行コマンドは ./node_modules/.bin/textlint と書いてありますが、ローカルの実行では npx textlint のように npx を使うこともできました。
    • ドキュメントの他のページでは npx 使われているので多分問題ないはずです。

textlintのルールを新規作成

textlintの実行環境が整ったのでルールをつくります。

オプションの設定

  • options = {} で引数を受け取ることができます
  • optionsを使っているルールをみて、どういう使い方をすればいいかなんとなく理解。
  • ルールを使う際は .textlintrc で指定できます。
 {
   "rules": {
     "max-number-of-lines": {
         "max" : 300
     }
   }
 }

テストの実装

  • create-textlint-rule のテンプレートでルールを作成すると テストコード test/index-test.js が作成されます。以下のようなコードです。
  "use strict";
  import TextLintTester from "textlint-tester";
  const tester = new TextLintTester();
  import rule from "../src/index";
  // ruleName, rule, { valid, invalid }
  tester.run("rule", rule, {
      valid: [ ....

作ったルールをローカルで実行する

npmに登録する前にローカル環境 textlint-demo で実行を確認します。

  • まず、 作成したルール textlint-my-rule 配下において、npm run build でルールをビルドし textlint-my-rule/lib/ 配下にファイルが生成されたことを確認しました。
    • treeで表示すると以下のような感じ。
├── README.md
├── lib
│   ├── index.js
│   └── index.js.map
...
├── package-lock.json
├── package.json
├── src
│   └── index.js
└── test
    └── index-test.js
  • textlint-demo ディレクトリに移動し、textlint-my-rule へのパスでインストールします。
    • 同じ階層にあるなら npm install ../textlint-my-rule --save-dev
    • 以下のようにsymbolic linkが貼られます。
%  ls -al ./node_modules/textlint-rule-my-rule
lrwxr-xr-x  1 name  staff  34  7 24 00:00 ./node_modules/textlint-rule-my-rule -> ../../textlint-rule-my-rule

参考: https://docs.npmjs.com/cli/v8/commands/npm-install

VS Codeでtextlintを実行する

毎回textlintコマンドを叩くのは大変なのでVS Code拡張機能 vscode-textlint をインストールしました。

これで追加したルールが適用されるかすぐ確認できるようになりました。

つづく

次回はnpmでの公開までやっていきたいです。

ちなみにこのブログはscrapboxで下書きした後markdownに変換し、textlint-demo 配下にファイルを置き、チェックして作成しました。

試しにいれてみたルールもあるのですが .textlintrc は以下の通りです。

{
  "filters": {},
  "rules": {
    "ja-no-mixed-period": true,
    "ja-hiragana-keishikimeishi": true,
    "no-doubled-conjunctive-particle-ga": true,
    "no-orthographic-variants": true,
    "max-ten": {
      "max": 3,
      "strict": false,
      "touten": "",
      "kuten": ""
    }
  }
}

NLTKのngram言語モデルを日本語データで使う

以前の記事で、古典的ngram言語モデルについて、NLTKを利用し、英語データセットの結果をまとめました。

eieito.hatenablog.com

単語を分かち書きさえすれば日本語でも実行可能なので、日本語データセットでパープレキシティを算出していきます。

データ

学習データとテストデータの準備。

ngramデータさえあればplain textがなくとも実装できますが、評価のためには学習データと同じ性質のテストデータ(plain text)が必要です。 Wikitext-JAWikipediaデータがtrain/testに分割されているので今回利用します。

データも「秀逸な記事」と「良質な記事」のに分かれているので個々に使う場合とどちらも使う場合を考えます。

前処理

モデルに渡す前の処理を以下の通り統一しました。

  • テキスト処理
    • タイトルおよび項目 (e.g. =ノストラダムス=) の = を削除
    • <block><block> タグは削除
    • *1* のようにシンボル化された文字は元の文字に変換する
      • 分かち書き*/1/* のように細かく分割されてしまい、本来の単語が得られないと考えたため
  • 文分割
  • 分かち書き

NLTKモデル向けの前処理・設定

  • padding
    • nltk.lm.preprocessing.pad_both_ends
  • ngram order
    • bigram (n=2)
  • vocab (Vocabulary)
    • 学習時のcutoff = 2
    • =>1回だけ出現した単語は <UNK>
  • ngram 頻度 (NgramCounter )
    • vocabを使って <UNK> に修正する
    • 具体的には、vocab.lookup(word_list) でvocabにない単語を<UNK>に置き換えた単語リストを生成する

モデル

以前の英語の場合と比較できるので、NLTKのモデル ( https://www.nltk.org/api/nltk.lm.models.html )を利用。

現時点で最新の version 3.7 で実装されているモデルを使います。

  • Lidstone
  • Laplace
  • StupidBackoff
  • AbsoluteDiscountingInterpolated
  • WittenBellInterpolated

KneserNeyInterpolated は以前の記事でも指摘した通り、動作がかなり重いので除外しました。

評価方法

パープレキシティにはnltk.lm.api.LanguageModel.perplexity の実装に従います。

つまり pow(2.0, self.entropy(text_ngrams))

実験

以下の設定で実験。 項目数は Wikitext-JA の統計情報から引用しました。

sentence数は ja_sentence_segmenter での分割結果。タイトルおよび項目(一単語だけの行)も1つのsentenceとみなしている。

学習データ 学習データの項目数 学習データのsentence数 評価データ 評価データの項目数
F (秀逸な記事) Train_Data_F.txt 69 30,975 Test_Data_F.txt 8
G (良質な記事) Train_Data_G.txt 1,139 301,019 Test_Data_G.txt 142
F+G Train_Data_F.txt + Train_Data_G.txt 1,208 331,994 Test_Data_F.txt, Test_Data_G.txt 150

コードはGoogle Colabで実行。実際のNotebook は gistにアップロードした。

モデルごと比較

表にまとめる。小数点3桁で丸めています。

F: Test_Data_F G: Test_Data_G F+G: Test_Data_F F+G: Test_Data_G PentreeBank (参考)
Lidstone 9.643 9.989 9.780 9.978 9.291
Laplace 10.110 10.458 10.260 10.449 9.687
StupidBackoff 7.021 7.285 7.100 7.266 7.416
AbsoluteDiscountingInterpolated 7.267 7.504 7.302 7.480 7.622
WittenBellInterpolated 7.286 7.501 7.305 7.477

スムージング手法 (StupidBackoff, AbsoluteDiscountingInterpolated, WittenBellInterpolated) のパープレキシティが低く、StupidBackoff が一番低い結果となりました。

英語 (Pentreebank) データを使った場合( gist )と同様の傾向といえます。

また、データが増えてもモデルごとのパープレキシティの傾向に影響はないです。

StupidBackoff

紹介していなかったので簡単に説明をします。

StupidBackoff は単純なBackoffモデルです。bigramについて、  w_i, w_{i+1} のスコアは以下で求めます。

学習データにある場合、   count(w_i, w_{i+1}) / N という単純な出現確率です。

  • count() は 学習データにおける 出現回数
  • N は学習データにおける全ての単語ngramの合計

NLTK実装では分母に NgramCounter.N を使っているので、unigramの頻度の合計数 + bigramの頻度の合計数となりそうです。

  w_i, w_{i+1} が未知のペアだった場合  \alpha * freq(w_i)

  •  \alpha は定数
  • freq() は楽手データにおける出現 確率

つまり w_i のみでスコアを算出します。 Backoff という手法の名前は、今見ているものより1つorderが小さいngramを使う、というイメージから来ています。

計算の最適化

今回、F (秀逸な記事) とF+G (秀逸な記事+良質な記事)はパープレキシティの算出にかなり時間がかかりました。特に AbsoluteDiscountingInterpolatedが一番時間がかかっていました。

Pentreebankのテストデータは3761 sentence ( torchtext より)で一瞬だったのであまり意識していませんでしたが、基本的にはスケールしない実装なのでデータ量が増えるほど重くなるといえそうです。

特に入力 ( ngram ) に対して毎回スコアを算出しているのがボトルネックになっています。 スコアはある程度キャッシュしたいです。 また、テストデータのngramは出現順に渡していますが、ngramとその頻度を渡す方が効率的な気もします。

   def entropy(self, text_ngrams):
        """Calculate cross-entropy of model for given evaluation text.

        :param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
        :rtype: float

        """
        return -1 * _mean(
            [self.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
        )

https://www.nltk.org/_modules/nltk/lm/api.html#LanguageModel.entropy

簡単にラッパークラスを作ってみました。 entropy関数で最適化した実装を実行するだけのクラスです。

class LanguageModelWrapper:
  def __init__(self, model):
    self.model = model

  def entropy(self, text_ngrams):
    """Calculate cross-entropy of model for given evaluation text."""
    score_list = []
    text_ngrams_counter = Counter(text_ngrams)
    for ngram, freq in text_ngrams_counter.items():
        score = self.model.logscore(ngram[-1], ngram[:-1])
        score_list.append([score] * freq)
    return -1 * _mean(list(chain.from_iterable(score_list)))

実際に実行してみると、元々の実装を使うよりかなり速くなりました。 以下の表は %%time で測った Wall timeを使ってます。

F G F+G
default setting 1min 31s 59min 19s 1h 8min 42s
LanguageModelWrapper 36.2 s 13min 55s 16min 56s

データ量が多い場合(GとF+G) は時間が 1/6ぐらいに短縮できています。

Gist

gist.github.com

ngram言語モデルについてまとめる (neural language model)

4記事にわたり、複数の古典的ngram言語モデルについて試しに実装してきました。

torchtextのデータセットを使ってきたので、pytorchで簡単な言語モデルを作ってみます。

元となる論文があるわけではないですが、ネット上に多数実装が多数あるので、それらを参考にしました。

実装コードは gist を参照してください。一番下に埋め込んでます。

参考資料について軽くまとめます。

アーキテクチャ

NgramModel(
  (embedding): Embedding(28782, 100)
  (linear1): Linear(in_features=200, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=28782, bias=True)
)
  • 入力単語を語彙数×次元数で表現する Embedding
  • 活性化関数 Relu
  • 線形変換1 Linear
  • 線形変換2 Linear
  • 確率値として出力 log_softmax

という比較的簡単な構造。

Embedding層はいわゆるword2vecとして単語特徴量を学習する。 Mikolovの論文 (https://arxiv.org/abs/1301.3781) におけるSkipgramやCBOWとは厳密には異なる。

アーキテクチャ参考

DatasetとDataloader

上記アーキテクチャにおいて参考にしたコードたちはデータを直接コード上に書いてバッチサイズも設定していないことがある。なので実践向き (大きいデータ向き) ではない。

Dataloaderを設定して、ngramをバッチサイズごとイテレートした。

今回使っている torchtext.datasets.WikiText2 はすでにイテレーターなのでDatasetとDataloaderどっちも使う必要があるのかわからなかったが、仕組みを理解したいので使うことにした。

  • Datasetではvocabの設定をし、単語を数値に変換したngram単語列と正解単語を返す
  • Dataloaderはバッチサイズを設定し、concatしたデータを返す
    • drop_last=True でバッチサイズより小さいデータは学習に使わない

Text classification with the torchtext library — PyTorch Tutorials 1.11.0+cu102 documentation (テキスト分類タスク向け)の実装に従ったので、もしかしたら別の方法もあるかもしれない。

学習・評価

今回10epochで実験。

学習のloss

WikiText2の場合テストデータのエントロピー4.1636 、 PentTreeDatasetの場合は 4.1657 と小さくなった。 今まで実装してきたモデルでエントロピーが5を切っているモデルはなかったのでかなり良いスコアといえる。

さいごに

公式のtutorial (Welcome to PyTorch Tutorials — PyTorch Tutorials 1.11.0+cu102 documentation) で改めてpytorchを勉強したが、モデルの実装よりデータの扱いについての情報が少なくて結構困った。

日本語ngramデータは公開がされているので頻度ベースのモデルは実装できそうだ。今後試したい。

gist.github.com