しんさんのブログ

科学や技術のこと読書のことなど

機械学習基礎の基礎: 過学習

今回は過学習についてみて見ます。
まずは訓練データを作成します。
訓練データは3次関数にノイズを加えた8つの点(x, y)構成することとします。
プログラムの初めに8点生成して、それらの点を10次関数でfittingすることにします。
データの個数がfitting関数の自由度よりも小さいのでoverfittingすることが予想できます。
最適化の目的関数は最小2乗誤差です。
損失関数の変化が小さくなるまで学習を繰り返し、fittingする10次関数の係数thetaを決めます。
コードにまとめると,

###############
# overfittingと正則化の実装
###############
import numpy as np
import matplotlib.pyplot as plt

# 3次関数を定義
def f(x):
	return 0.2*(x**3 + x **2 + x)

# ノイズを付与する
trainDataX = np.linspace(-1, 1, 8)
trainDataY = f(trainDataX) + 0.05 * np.random.randn(trainDataX.size)

# dataを平均ゼロ、標準偏差1に標準化する
mu = trainDataX.mean()
sigma = trainDataX.std()

def standardization(x):
	return (x - mu)/sigma

trainDataXSTD = standardization(trainDataX)

# 10次の多項式でfittingする
# data が8個しかないのに、自由度が10個ある関数でfitするのでoverfitする

def ToMatrix(x):
	return np.vstack([
		np.ones(x.size),
		x, x**2, x**3, x**4, x**5, x**6, x**7, x**8, x**9, x**10,
		]).T

X = ToMatrix(trainDataXSTD)

# パラメータの初期値設定
theta = np.random.randn(X.shape[1])

# fitting関数
def fitfunc(x):
	return np.dot(x, theta)

# loss 関数(最小2乗誤差)
def E(x, y):
	return 0.5 * np.sum((y - fitfunc(x))**2)

eta = 1e-4 #学習率

# loss 関数の変化量
diff = 1

### 学習 ###
error = E(X, trainDataY)
while diff > 1e-8:
	theta = theta - eta * np.dot(fitfunc(X) - trainDataY, X)
	current_error = E(X, trainDataY)
	diff = error - current_error
	error = current_error

# 結果を表示
x = np.linspace(-1, 1, 100)
xSTD = standardization(x)
plt.plot(trainDataXSTD, trainDataY, 'o') #3次関数に誤差を乗せた学習データ
plt.plot(xSTD, fitfunc(ToMatrix(xSTD))) # fittingした10次関数の曲線
#plt.show()
# データをplotする#plt.plot(trainDataX, trainDataY, 'o')
plt.plot(xSTD, f(x))
plt.show()

f:id:wshinya:20180327000706p:plain
緑の線は学習データを作成するために使った3次関数のグラフ。
点は3次関数にノイズを加えて作成した学習用データ。
さらにオレンジの線は10次元関数でfittingした結果のグラフ。
つまり学習結果です。
オレンジのグラフは元の3次関数のグラフとはかなりずれた結果になっていますので、この結果を使用して新しい入力xに対して出力yを求めると期待しているものとは異なった結果となりそうです。
与えられた8つの学習データに対してはそれなりにいい近似になっているように見えますが、未知の入力に対しての予想精度が悪くなっている、これが過学習です。