識別器の作成で学んだこと(マルチラベル分類と不均衡データの取り扱い)
概要
機械学習の勉強として画像の識別器を作成しました。
偏りが大きいデータをどのように扱うかの勉強になりました。
イントロダクション
はじめまして。機械学習に興味を持ち、勉強をしています。
この度は勉強の一環(と趣味)として画像の識別器を作成しました。
識別器は機械学習の習作のファーストステップとしてまず考えられる題材ですが、実際に作ってみると想像してたよりも多くの事が勉強ができたので、まとめておこうと思い執筆しました。
ブログ等を書いた経験はないので、読みにくいかと思いますが読んでいただければ嬉しいです。
今回は、東方Projectのキャラクターの画像を与えたときにどのキャラクターが描かれているかを判別する、という問題設定で取り組みました。
手法
データ収集
pixivの画像データを使用しました。
各イラストに付けられているタグを教師ラベルとして利用します。
pixivのデータ収集にはPixivPyを利用しました。
画像はアニメーションでない静止画の一枚絵だけ取得しました。(複数絵投稿はキャラクターのタグとイラストの対応付けができないため)
タグからキャラクターのラベルを設定しますが、1キャラも設定できなかったイラストについては学習時に除いています。
また、複数のキャラクターのタグがつけられているデータもあるので、各データは1個以上のラベルを持つことになります。
集計時期は主に2022年1月頃から約2ヶ月間で、その後新たに投稿された作品については気が向いたときに収集しています。
記事執筆時点で、クラス数は176、ラベル付けが出来た画像は1836432枚になります。
データの収集方法や中身の詳細については後日別に記事を作成します。
学習
ソースはPyTorchの転移学習のチュートリアルを参考にしながら作成しました。
識別器のモデルは、まずは計算量の少ないmobilenetv3_small_050
(Timm)を使います。
入力画像の端ピクセルの平均値を画像の短辺に加えて正方形に整えてから、224*224にリサイズし、torchivision.transforms
のモジュールを使ってコントラスト/彩度/色相の変化、左右ランダムフリップ、アフィン変換をしてAugmentationをしています。
データの性質上、マルチラベル学習となります。
損失関数は二値クロスエントロピー(BCE)を使用しています。
モデル出力層の活性化にロジスティックシグモイド関数を適用して、全クラスに対して[0, 1]の推測値を出力します。
結果 #1
まずはバッチサイズは64でデータ全体を2エポック学習させたモデルを作成しました。
いくつか画像を識別させてみます。
とりあえず学習は出来ているようです。複数キャラの画像も識別できています。
但し学習データにおいてデータ数の多いクラスについては高い精度で予測できる一方、データ数の少ないクラスの識別性能はイマイチなようです。
上の例ではデータ数の少ないキャラよりも特徴が似ていてデータ数の多い(不正解の)キャラの方がスコアが高くなってしまっています。
では、データの数の偏りを見てみます。
教師データ全体の各クラスのデータ数を降順に並べてみます。
最も多いクラスでは127479枚、最も少ないクラスでは26枚であり、約5000倍もの差がありました。
データ数と識別性能の関係を見てみます。
教師データ全体をモデルに推定させて、データ数と正解ラベルのモデルの出力平均の散布図を各クラスでプロットします。(学習データ時に使ったデータそのものを評価しているのでよろしくないですが…)
やはり、データ数が少ないクラスに対しては識別性能が低く、データ数が多いクラスに対しては識別性能が高いようです。
便宜上、データ数900未満のクラス、900以上4000未満のクラス、4000以上のクラスで3分割して考えます。(境界値は見た目から決めました)
ピンクの線より左のクラスをマイナークラス、ピンクより右で橙色の線より左のクラスをミドルクラス、橙色より右のクラスをメジャークラスのように呼ぶことにします。
それぞれのモデル出力平均を求めるとメジャークラスでは0.60, ミドルクラスでは0.37, マイナークラスでは0.18でした。
ミドルクラスやマイナークラスの識別性能はもっと改善できそうに思えたので調査してみました。
改善手法調査
不均衡データ(データ数に偏りがある場合)の対策はいくつかあるようですが、以下の二つの方針で進めます。
- 教師データを整える
- 損失関数を工夫する
教師データを整える
データ数の偏りそのものを改善する方針です。
調べてみたところ、数が少ないデータ群を増やすover-sampling, 数が多いデータ群を減らすunder-sampling の二つの指針あるようです。
どちらも元データ群からデータを選択して不均衡を解消するということは同じなので、手法としてはデータセットから抽出するときの確率を設定してサンプリングすることで両方実現できそうです。(つまり多項分布からのサンプリング)
全てのデータが単一ラベルの場合、各データがもつラベルのクラスのデータ数の逆数比の確率でサンプリングすれば各クラスで同じデータ数が期待出来ます(ややこしい日本語ですが、1000枚のデータ数を持つキャラAと100枚のデータ数のキャラBがいたら、キャラAの全画像を1/1000、キャラBの全画像を1/100の確率でそれぞれ抽出すると期待値はどちらも1枚)。
ただし、複数ラベル含まれているデータがあるため、画像一枚ずつのラベルからその抽出確率を決定する場合は、抽出後のデータ数期待値は各クラスで均等には出来ません。
とはいっても厳密に均等にすることを求めているわけではないので、単純な方法でどの程度改善が見られるか試してみます。
今回は複数ラベルを持つデータについては、最もマイナーなクラスとして扱うこととします。
多項分布のサンプリングをするtorch.multinomial
を使って実装することが出来そうです。
torch.multinomial
は第一引数に確率分布(の割合)、第二引数に抽出後サンプルのサイズを渡すとサンプリングした結果を返します。(ドキュメント)
先程のキャラABの例に適用すると、元データセットが[キャラAの画像, キャラBの画像, キャラAの画像, ..]という順で並んでいたら、torch.multinomial([1/1000, 1/100, 1/1000, ..], 10)
のようにデータ数の逆数比を第一引数、第二引数に抽出後のサイズを10にして渡すと、AとBが同数(5枚ずつ)のサンプリングが期待できます。
これを使ってサンプリングをします。
def sampling(list_labels, num_samples, replacement=True): # 各クラスのデータ数を算出 count_label = np.zeros(num_classes, dtype=int) for label in list_labels: for l in label: count_label[l] += 1 # 各データの抽出確率 dataset_weights = torch.empty(len(list_labels)) # データの中で最もマイナーなクラスのラベルをそのデータの抽出確率に適用する inital_weight = 0.0 is_update = lambda data,max: data > max # 抽出確率を設定 for idx, label in enumerate(list_labels): data_weight = initial_weight for l in label: if is_update(1./count_label[l], data_weight): data_weight = 1./count_label[l] dataset_weights[idx] = data_weight # 多項分布からサンプリング return torch.multinomial(dataset_weights, num_samples, replacement=replacement)
list_labels
はサンプリング対象のデータセットのラベルをリストにしたものです。マルチラベル分類なので、データによってラベル数が異なります。データセットと同じ順序でデータが格納されていることを前提とします。具体的には[[1], [1,4], [2], [2,3,4],...]
のような変数。
torch.multinomial
では抽出確率の合計値は1でなくても良いので各データの抽出確率は 1/(クラスのデータ数) と設定しています。
また、同データが2回以上サンプリングされることを許可するかどうかを引数replacement
で設定できます。データの重複を許可しない場合(False
, デフォルト)は、under-samplingのみ、許可する場合(True
)はunder-samplingとover-samplingを同時に適用していることになります。
この関数ではデータセットのindexを返すので、torch.utils.data.Subset
を使用して抽出後のデータセットを生成します。
indices_sampled = sampling(list_labels_orig, num_samples) dataset_sampled = torch.utils.data.Subset(dataset_orig, indices_sampled)
抽出後の総データ数を10000にしてサンプリングすると、以下のような分布の変化を得ました。
元に比べて均衡が改善されたデータセットを得ることが出来ました。
この時点で十分効果が期待できそうですが、マルチラベルのデータのため、メジャークラスのデータが多く選ばれています。
多項分布によるサンプリング後(20000枚)、各クラスのデータ数に上限値を設けて抽出することで、さらに均衡にしてみます。
上限値を(10000/全クラス数) * (画像1枚あたりの平均ラベル数) としたとき、以下の分布が得られました。
(やりすぎな気もしますが)これで均衡なデータを得ることが出来ました。
損失関数を工夫する
不均衡データに対する損失関数の検証をされている下記ページを参考させて頂きました。
https://qiita.com/tancoro/items/c58cbb33ee1b5971ee3b
この項目はほぼ上記事の内容を私なりにまとめたものとなります。
データ数の偏りが学習結果に影響するということは、学習時の損失関数に工夫をすれば改善できる、という考えです。
最初のモデルではBCEを使用していましたが、損失関数自体を変えたりBCEに手を加えたりします。
以下の方法を考えてみます。
- 学習係数にデータ数の逆数比をかける(Balanced CE)
- Class Balanced Loss
- Focal Loss
一つずつ見ていきましょう。
学習係数にデータ数の逆数比をかける(Balanced CE)
まず、学習係数にデータ数の逆数比をかける方法ですが、マイナーデータに対しての学習時にロスを大きく、メジャーデータに対してはロスを小さくすることで、マイナークラスの学習効率の改善を図ります。重みはデータ数の逆数に比例させます。
今回のデータ分布では最も少ないクラスと最も多いクラスのデータ数に約5000倍もの差があります。
最もデータ数が多いクラスは最も少ないクラスより5000倍多く学習するが学習係数が1/5000なので同じくらい学習が進む…ということを狙います。
ただし、ロスの差がここまで大きいと学習がうまく行くのか不安です。
うまい具合で学習ができる範囲のデータセットであれば単純で良い手法だと思います。
BCEの各クラス毎に重みを変更するのは、PyTorchではBCEWithLogitLoss
の引数 weight
もしくは pos_weight
を利用すると簡単に実装できます。
今回の問題では、正解ラベルのときだけ重み付けをするpos_weight
を利用するのが良い(と思う)ので、pos_weight
にクラスのデータ数の逆数比を与えます。(そのまま逆数で割ってしまうと元の学習と比較しにくいので全ての重み係数に (全データ数)/(全クラス数) を掛けます。)
Class Balanced Loss
次に、Class Balanced Lossです。
https://openaccess.thecvf.com/content_CVPR_2019/html/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.html
先ほどと同様にクラスごとに学習時の重みつけを変える手法です。
本質的なデータの表現は限りがあり、実データ数の差と正比例しないのでロスの重みを工夫する、という考え方だと解釈しています。
クラスの対する学習係数の重みが以下の式で表現されます。
はクラスのデータ数、はハイパーパラメータであり、を取ります。
ロピタルの定理よりであり、をに近づけるとデータ数の逆数に近くなります。
機能的には先ほどのクラスのデータ数の逆数比を掛ける方法の、学習係数の下限を抑え込む、といった所でしょうか。
気にしていたデータ数の差が大きいためロスが安定しないかも、という不安に対しての答えになりそうです。
Focal Loss
Focal Lossは、上二つの手法と異なり係数だけでなく損失関数自体を定義します。
https://arxiv.org/pdf/1708.02002
物体検出において、画像全体のうち背景の割合が多いため、クラス間の不均衡が生じ精度が落ちているという考えから、背景の学習の積み重ねが物体クラスの学習を阻害しないようにする損失関数を提唱しています。
正解に近いロスの値を抑えて、大きく外れていた箇所のロスを大きくして重視する手法です。
論文中の式をそのまま記載すると
となります。
は正解ラベルのクラスの推測結果です。
Cross Entropy誤差にを掛けてるだけですね。 はハイパーパラメータで、のときはCross Entropy誤差と同じになります。
例えば正解時ラベルの推測のときにでよく推論出来ていなかった場合はとなり、で良く推論出来ているときはとなり、推論がよく出来てないときに比べてロスが小さくなります。
PyTorchのLoss Function https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.htmlを参考に実装しました。(引数device
は独自に追加していますが)
# FocalLoss class FocalLossWithLogits(torch.nn.Module): def __init__(self, gamma=1.0, reduction :str = 'mean', weight: Optional[torch.Tensor]=None, pos_weight: Optional[torch.Tensor]=None, device :str = 'cpu') -> None: super(FocalLossWithLogits, self).__init__() self.gamma = gamma self.device = device self.weight = weight if weight is not None else torch.tensor(1.0, device=self.device) self.pos_weight = pos_weight if pos_weight is not None else torch.tensor(1.0, device=self.device) if reduction not in {'mean', 'sum', 'none'}: raise ValueError("{} is not a valid value for reduction".format(reduction)) elif reduction == 'mean': self.reduction = torch.mean elif reduction == 'sum': self.reduction = torch.sum else: self.reduction = torch.nn.Identity() def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # BinaryCrossEntropy bce = torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction='none') pt = torch.exp(-bce) # Focal Loss loss = torch.pow(1.0-pt, self.gamma) * bce # Weight weight = self.weight * torch.where(target > 0, self.pos_weight, torch.tensor(1.0, device=self.device)) loss *= weight # Reduction loss = self.reduction(loss) return loss
Focal Lossは物体検出や画像セグメンテーションなど多くのラベルが0になる場合に有効なようですが、今回の問題でも既に学習が進んでいるメジャークラスの小さなロスの積み重ねがマイナークラスの学習のノイズになっていると考えたためFocal Lossが機能するのではないか思います。
結果 #2
改善手法の実験
さて、これら用意した方法と最初のモデルを比較検証します。
検証データにメジャークラスのみを持つデータを5000枚、マイナークラスのみを持つ画像データを5000枚、メジャークラスとマイナークラスを持つ画像データ(mixedと呼びます)を5000用意し、それぞれに対する識別性能を評価します。
学習に使う教師データは、元のデータセットからランダムに抽出した50000枚(originと呼びます)と、比較用に、全てのクラスでデータ数が等しくなるようにサンプリングしたデータセットからランダム抽出した50000枚(sampledと呼びます)を用意します。
20エポック学習させ、1エポックごとに検証データに対する損失関数と識別性能(ROC-AUCとPR-AUCによる評価)を調べます。
教師データのサンプリング
ランダムに抽出したデータセット(origin)とクラス間データ数の不均衡を解消したデータセット(sampled)の結果の比較です。
まずは、検証データに対する損失関数。
曲線が緩やかですが、元々のデータセットでのメジャークラスに対しての学習が出来ていたことから、サンプリング後のメジャークラス、マイナークラスに対しての学習は出来ていそうですね。
対して学習が進んでいないのは、元データセットでのマイナークラスとミックスのクラス(1枚にメジャークラスとマイナークラスを両方もつ画像)、サンプリング後のミックスクラスはあまり学習が進んでいなさそうです。
識別性能を見てみます。
評価指標はROC-AUCとPR-AUCを用います。こちらの結果でもおおむね同じ事が言えそうで、サンプリングにより、メジャークラスに対しての性能を大きく落とさず、マイナークラスに対する性能は向上したと言えるでしょう。
ミックスクラスがメジャー/マイナークラスよりスコアが低いのは何故でしょうかね?データの性質の違いを考えると、各画像内のラベル数が違うはずで、教師データ全体では1データの平均ラベル数が約1.25であり、メジャークラスやマイナーの検証データでは画像内の平均ラベル数が約1.3であるのに対して、ミックスクラスの検証データでは約3.0であったので、ここの違いが効いているのかもしれません。
これもマルチラベル分類におけるラベル数の不均衡データと言えるのですが、これが要因かは今回は検証しません。今後の課題とします。
損失関数
損失関数の違いによる比較です。
対象はBinary Cross Entropy(bce)、BCEの学習係数にクラスごとのデータ数の逆数比を掛けた損失関数(icf)、Class Balanced Lossの (cbf0.999), (cbf0.9)、Focal Lossの (focal1), (focal3)を比較してみます。
まずは検証データに対する損失関数の値を見ます。Focal Lossは損失関数の値自体異なるので、BCEとデータ数逆数比を学習率に掛けた損失関数(icf)とClass Balanced Loss (cbf)を比較します。
えーっと、一目でBCEより学習出来ていないのがありそうですね…。
データ数の逆数比を学習率に掛けた損失関数(icf)では、クラス間の学習率に5000倍もの差があったため、メジャークラスかマイナークラスどちらかがうまく学習出来ないとやる前から思ってました。
Class Balanced Lossの方ですが、 の方はクラス間の学習率の差は最大1000倍なので、もしかしたら上手く行くかと思いましたがそんなことはなさそうです。
ではクラス間の学習率の差は最大10倍なのでただのBCEよりちょっとだけ改善されることを期待しましたが、右のグラフを見るとマイナークラスとミックスクラスの結果がむしろ改悪してそうです。何故そうなるのかは見当つきません。
AUCによる比較をします。
こちらはBCEとClass Balanced Lossでほとんど差が無いように見えます。マイナークラスがちょっとだけ改善されているかも、という程度。
Class Balanced Lossによる改善はあまり見られませんでした。ただし、ハイパーパラメータを調整することで効果は得られるかもしれません。そこまで検証する計算リソースが無いので今回は保留します。
続いて、Focal LossによるAUCの比較を見てみます。Focal LossはKaggleでもよく使われているので改善に期待したいです。
AUCによる比較ではBCEとFocal Loss(と)による差は無さそうですね。
尤も、Focal Lossによる効果がより顕著なのはもっと学習が進んだときだと思うので、ただ学習不足なだけかもしれません。
差が出るのがどのくらい学習が必要かわからないので、これ以上深く計算しないことにしますが、今後Kaggle等で使う機会があったらもうちょっと深い検証を試してみたいです。
AUC以外に評価指標を集計していて、独自で作ったスコアによる比較結果を以下に貼ります。
スコアの計算式は
です。
損失関数はBCEベースなので、モデルの出力層の活性化でロジスティックシグモイド関数を適用して各クラスに[0,1]の推測値が出力されるので、その中の不正解ラベルの最高値と正解ラベルの最低値を利用します。
例えばスコアが0.9だと、正解ラベルは全部約0.9以上、不正解ラベルは全部約0.1以下という指標になります。
これによるとメジャー/マイナー/ミックス全てにおいてBCEよりFocal Lossのほうが少し良いという結果になっています。
これだけで良しとは言えませんが、ハイパーパラメータ調整をしっかりしなくても最低限酷く悪化するということは無さそうです。
まとめと課題点や感想など
今回初めて自分で問題設定して取り組んで(画像識別器の作成をして)みたことで、いろんなことを学べました。
機械学習的な観点では特にマルチラベル分類と不均衡データにどう対処していくかを少しだけ学べました。
最終的な検証結果として、まず、サンプリングは強い。
これは不均衡データ対処するときはした方が良いと思います。今回検証に使った例では全部のクラスのデータ数を均一にするということをしましたが、そこまでしないで多少不均衡を和らげるだけでも効果はあるんじゃないかと思います。
損失関数についてですが、ちょっと検証に掛けた時間不足が目立ちますね。
もうちょっと色々試行錯誤したいですが、これにばかり時間掛けて他のことが勉強できないというのもあれなので。
ただ、Focal Lossに関しては、積極的に使って行きたいと思うようになりました。
最終的に多項分布サンプリングやFocal Lossを使ってmobilenetv3より大きいモデル(構造似てるEfficientnetなど)を学習させてみようと思っています。
マルチラベル分類の方は記事タイトルに入れたけど記事内容的には不均衡データの方に偏ってしまった。っていう不均衡ギャグ(マルチラベル分類についても習得したことなのでもうちょっと書きたかった。)
記事執筆も初めてだったので内容の薄さに対して冗長だったかもしれません。これも改善していきたい。
細かい感想を言えばmatplotlibの結果をブログに貼るときsvgで貼りたい。
あと、識別器としての結果が出るのは楽しい。