Python Keras 使用 class_weight 調整不平衡

說明

這篇是使用 Keras 有一個參數 classweight ,在訓練的時候,有時會遇到不同分類的資料量不平衡,例如醫學資料,異常的資料可能佔少數,正常的資料還是占大多數,那在這樣的情況下,可以使用 classweight 去做資料的平衡,也可以使用實際資料去做資料量擴增,也還有其他多種方法,這裡說明平衡資料的參數使用。

操作流程

引入套件

from sklearn.utils import class_weight

平衡資料

class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(y_train),
                                                 y_train)

還有另外一個方法是直接設定資料平衡的倍數

class_weight = {0: 1.,
                1: 50.,
                2: 2.}

這裡的意思是第0類=>1倍 第1類=>乘以50倍 第2類=>乘以2倍

在訓練的時候

model.fit(X_train, Y_train, nb_epoch=5, batch_size=32, class_weight=class_weight)

fit 最後要下 class_weight 的參數,這樣就可以做到資料調整

參考

如何在Keras中为不平衡的班级设置班级权重?

留言