機械学習基礎の基礎: 過学習を防ぐ正則化

前回のブログでは2次元平面上の8つの点を10次関数でfittingしました。
すると、fittingする関数の自由度が高いために、8つの点すべてを通るようにうまくパラメータを調節できてしまい、学習データに対しては100%近い正解率がでるのに、新たなデータに対する予言の精度が下がってしまうという現象をみました。
これが過学習というもので、別の言い方をすると汎化性が低いとも言えます。
この過学習を防ぐためには学習データを増やすか、fittingする関数の自由度を下げればいいということがわかります。
ここでは自由度を下げるために、L2正則化と呼ばれる手法を実装してみます。
実装はとても簡単で、最適化する目的関数に、(1/2) theta^2 を追加するだけです。
目的関数を最小化するようなthetaの組を求めるときに、この項の効果でthetaをあまり大きな値にできないという拘束が生まれます。
これにより、自由度が下がって過学習を防ぐことができるという仕組みです。
実際のコードは以下のようになります。

###############
# overfittingと正則化の実装
###############
import numpy as np
import matplotlib.pyplot as plt

# 3次関数を定義
def f(x):
	return 0.2*(x**3 + x **2 + x)

# ノイズを付与する
trainDataX = np.linspace(-1, 1, 8)
trainDataY = f(trainDataX) + 0.05 * np.random.randn(trainDataX.size)

# dataを平均ゼロ、標準偏差1に標準化する
mu = trainDataX.mean()
sigma = trainDataX.std()

def standardization(x):
	return (x - mu)/sigma

trainDataXSTD = standardization(trainDataX)

# 10次の多項式でfittingする
# data が8個しかないのに、自由度が10個ある関数でfitするのでoverfitする

def ToMatrix(x):
	return np.vstack([
		np.ones(x.size),
		x, x**2, x**3, x**4, x**5, x**6, x**7, x**8, x**9, x**10,
		]).T

X = ToMatrix(trainDataXSTD)

# パラメータの初期値設定
theta = np.random.randn(X.shape[1])

# fitting関数
def fitfunc(x):
	return np.dot(x, theta)

# loss 関数(最小2乗誤差)
def E(x, y):
	return 0.5 * np.sum((y - fitfunc(x))**2)

eta = 1e-4 #学習率
l = 5.0 #正則化係数

# loss 関数の変化量
diff = 1
### 学習 ###
error = E(X, trainDataY)
while diff > 1e-8:
	# 正則化項
	r = l * np.hstack([0, theta[1:]]) #バイアスには正則化を足さない
	theta = theta - eta * (np.dot(fitfunc(X) - trainDataY, X ) + r);
	current_error = E(X, trainDataY)
	diff = error - current_error
	error = current_error
# 結果を表示
x = np.linspace(-1, 1, 100)
xSTD = standardization(x)
plt.plot(trainDataXSTD, trainDataY, 'o') #3次関数に誤差を乗せた学習データ
plt.plot(xSTD, fitfunc(ToMatrix(xSTD))) # fittingした10次関数の曲線
#plt.show()
# データをplotする#plt.plot(trainDataX, trainDataY, 'o')
plt.plot(xSTD, f(x))
plt.show()

このコードを実行すると、次のようなグラフが描画されます。
正則化係数lを変更すると正則化項の強さが変わって、過学習の度合いが変化するのがわかります。
f:id:wshinya:20180401231849p:plain
グラフは、緑の線が学習データの基になった3次関数で、赤が正則化後のfittingカーブです。
前回の正則化なしの結果に比べて、データの基になった関数をより近似していることがわかります。