重みの初期値

重みの初期値

ハイパラメータとして他にも重要なのが、重みの初期値です。ここでは、XavierとHeの初期値を使ってXOR回路のニューラルネットワークを実行したいと思います。

Xavier

定義

ランダムで選んだ初期値に、前層のノードの個数nの平方で割ります。
 {\boldsymbol{W}*\frac{1}{\sqrt{n}}}

コード

def __init__(self, input_size, hidden_size, output_size):
    self.params = {}
    self.params['W1'] = np.random.randn(input_size, hidden_size) / np.sqrt(input_size)
    self.params['b1'] = np.zeros(hidden_size)
    self.params['W2'] = np.random.randn(hidden_size, output_size) / np.sqrt(hidden_size)
    self.params['b2'] = np.zeros(output_size)

He

定義

Xavierの初期値に {\sqrt{2}}を掛けます。
 {\boldsymbol{W}*\sqrt{\frac{2}{n}}}

コード

def __init__(self, input_size, hidden_size, output_size):
    self.params = {}
    self.params['W1'] = np.random.randn(input_size, hidden_size) / np.sqrt(input_size) * np.sqrt{2}
    self.params['b1'] = np.zeros(hidden_size)
    self.params['W2'] = np.random.randn(hidden_size, output_size) / np.sqrt(hidden_size) * np.sqrt{2}
    self.params['b2'] = np.zeros(output_size)

他にも

標準偏差0.01

def __init__(self, input_size, hidden_size, output_size):
    self.params = {}
    self.params['W1'] = 0.01 * np.random.randn(input_size, hidden_size)
    self.params['b1'] = np.zeros(hidden_size)
    self.params['W2'] = 0.01 * np.random.randn(hidden_size, output_size)
    self.params['b2'] = np.zeros(output_size)

間違ったら良くなった初期値

def __init__(self, input_size, hidden_size, output_size):
    self.params = {}
    self.params['W1'] = 1.1 ** np.random.randn(input_size, hidden_size)
    self.params['b1'] = np.zeros(hidden_size)
    self.params['W2'] = 1.1 * np.random.randn(hidden_size, output_size)
    self.params['b2'] = np.zeros(output_size)

出力

Xavier

('ans:', 15, '/ 100')
('ans:', 10, '/ 100')
('ans:', 22, '/ 100')

He

('ans:', 18, '/ 100')
('ans:', 17, '/ 100')
('ans:', 13, '/ 100')

標準偏差0.01

('ans:', 15, '/ 100')
('ans:', 16, '/ 100')
('ans:', 14, '/ 100')

間違ったら良くなった初期値

('ans:', 44, '/ 100')
('ans:', 40, '/ 100')
('ans:', 40, '/ 100')

結論

XavierもHeもあまりよくならなかった。ハイパラメータを自動で決めることで、良い結果が得られるかも。 ちなみに一番良かった初期値(間違ったやつ)にMomentumを使うと、

('ans:', 76, '/ 100')
('ans:', 71, '/ 100')
('ans:', 64, '/ 100')

なかなかすばらしい。