[DL] What is weight decay?

In a deep neural network, the more layers there are, the more expressive the model becomes. However, more layers also mean a higher risk of **overfitting**. That risk can be reduced by limiting the degrees of freedom of the parameters while keeping the model's expressive power, and **weight decay** is one of the methods for doing so.

With weight decay, the weight update rule becomes:

```math
w \leftarrow w - \eta \frac{\partial C(w)}{\partial w} - \eta \lambda w
```
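
Equivalently, the update can be grouped so that the weight is multiplied by a factor slightly smaller than one at every step, which is where the name "decay" comes from:

```math
w \leftarrow (1 - \eta \lambda)\, w - \eta \frac{\partial C(w)}{\partial w}
```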

It is not immediately obvious from this formula what the extra term is doing, but it actually comes from the cost function below.

```math
\tilde C(w) = C(w) + \frac{\lambda}{2} \|w\|^2
```

This is the **cost function** with an **L2 regularization** term added. The term pushes the weight values toward zero, so in an actual implementation you add the **L2 regularization** term to the cost.
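
Differentiating the regularized cost and plugging the gradient into plain gradient descent recovers the update rule above:

```math
\frac{\partial \tilde C(w)}{\partial w} = \frac{\partial C(w)}{\partial w} + \lambda w
\quad\Longrightarrow\quad
w \leftarrow w - \eta \frac{\partial \tilde C(w)}{\partial w}
          = w - \eta \frac{\partial C(w)}{\partial w} - \eta \lambda w
```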

Normally, **L2 regularization** is not applied to the bias. This comes from the different roles the weight and the bias play in a neuron: the weights select the inputs, so their values can shrink as long as their relative importance is preserved, whereas the bias is a threshold and may legitimately have to be large.
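
As a minimal sketch of what this looks like in TensorFlow 1.x (the tensor names here are illustrative; the full script below does the same thing with `w1` and `b1`), the L2 penalty is built from the weight matrix only and the bias is left out:

```python
import tensorflow as tf

# minimal sketch (TF 1.x style): a single softmax layer where only the weight is penalized
x  = tf.placeholder(tf.float32, shape=(None, 784))
y_ = tf.placeholder(tf.float32, shape=(None, 10))
w  = tf.Variable(tf.truncated_normal([784, 10]))
b  = tf.Variable(tf.zeros([10]))

logits = tf.matmul(x, w) + b
beta = 0.001  # weight decay coefficient (the lambda above)

data_loss  = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
total_loss = data_loss + beta * tf.nn.l2_loss(w)  # tf.nn.l2_loss(w) = sum(w**2) / 2; bias b is excluded
```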

If you train once **without weight decay** and once **with weight decay** and look at the weight histograms, you get the figure below. The left panel is without weight decay, the right panel is with it; you can see that the weights with decay are smaller.

![Weight histograms: without weight decay (left) and with weight decay (right)](weightdecay1.png)

The accuracy is shown below. Blue is the result without weight decay, red is the result with weight decay; the dotted lines are accuracy on the training data and the solid lines are accuracy on the validation data.

![Accuracy with and without weight decay](weightdecay2.png)

The full script used for this comparison is as follows.


```python
import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

image_size = 28     # MNIST images are 28x28 pixels
n_labels = 10       # digit classes 0-9
n_batch  = 128      # mini-batch size
n_train  = 10000    # number of training steps
beta = 0.001        # weight decay coefficient (the lambda in the cost above)

def accuracy(y, y_):
    # percentage of samples whose predicted class matches the one-hot label
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]

with tf.variable_scope("slp"):
    x  = tf.placeholder(tf.float32, shape=(n_batch, image_size*image_size))
    y_ = tf.placeholder(tf.float32, shape=(n_batch, n_labels))
    w0 = tf.get_variable("w0", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b0 = tf.get_variable("b0", [n_labels], initializer=tf.constant_initializer(0.0))

    w1 = tf.get_variable("w1", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b1 = tf.get_variable("b1", [n_labels], initializer=tf.constant_initializer(0.0))
    
    y0 = tf.matmul( x, w0 ) + b0
    y1 = tf.matmul( x, w1 ) + b1
    
# validation and test sets as constant tensors, so their predictions can be evaluated directly
valid_data = mnist.validation.images
valid_labels = mnist.validation.labels
test_data = mnist.test.images
test_labels = mnist.test.labels
vx = tf.constant(valid_data)
vy_ = tf.constant(valid_labels)
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

# loss0: plain cross-entropy, no weight decay
loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y0, labels=y_))
optimizer0 = tf.train.GradientDescentOptimizer(0.5).minimize(loss0)

# loss1: cross-entropy plus the L2 penalty on the weights (the bias b1 is deliberately excluded)
loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y1, labels=y_) + beta * tf.nn.l2_loss(w1))
optimizer1 = tf.train.GradientDescentOptimizer(0.5).minimize(loss1)

train_prediction0 = tf.nn.softmax(y0)
valid_prediction0 = tf.nn.softmax(tf.matmul(vx, w0) + b0)
test_prediction0  = tf.nn.softmax(tf.matmul(tx, w0) + b0)

train_prediction1 = tf.nn.softmax(y1)
valid_prediction1 = tf.nn.softmax(tf.matmul(vx, w1) + b1)
test_prediction1  = tf.nn.softmax(tf.matmul(tx, w1) + b1)

sess = tf.InteractiveSession()  # InteractiveSession lets .eval() below run without passing the session explicitly
# sess = tf.Session()

init = tf.initialize_all_variables()
sess.run(init)
result_accuracy = []

for step in range(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    # train both models on the same mini-batch
    _, L0, tp0 = sess.run([optimizer0, loss0, train_prediction0], feed_dict={x: bx, y_: by})
    _, L1, tp1 = sess.run([optimizer1, loss1, train_prediction1], feed_dict={x: bx, y_: by})
    if step % 500 == 0:
        ac_wo_decay_train = accuracy(tp0, by)
        ac_wo_decay_valid = accuracy(valid_prediction0.eval(), valid_labels)
        ac_wt_decay_train = accuracy(tp1, by)
        ac_wt_decay_valid = accuracy(valid_prediction1.eval(), valid_labels)
        ac = {'step' : step, 'wo_decay' : {'training' : ac_wo_decay_train, 'validation' : ac_wo_decay_valid}, 'wt_decay' : {'training' : ac_wt_decay_train, 'validation' : ac_wt_decay_valid}}
        result_accuracy.append(ac)
        print("step = %d, train accuracy0: %.2f, validation accuracy0: %.2f, train accuracy1: %.2f, validation accuracy1: %.2f" % (step, ac_wo_decay_train, ac_wo_decay_valid, ac_wt_decay_train, ac_wt_decay_valid))

print("test accuracy0: %.2f" % accuracy(test_prediction0.eval(), test_labels))
print("test accuracy1: %.2f" % accuracy(test_prediction1.eval(), test_labels))

fig,axes = plt.subplots(ncols=2, figsize=(8,4))
axes[0].hist(w0.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[0].set_title('without weight decay')
axes[0].set_xlabel('weight')
axes[1].hist(w1.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[1].set_title('with weight decay')
axes[1].set_xlabel('weight')
fig.show()

tr_step = [ac['step'] for ac in result_accuracy]
ac_training_wo_decay = [ac['wo_decay']['training'] for ac in result_accuracy]
ac_training_wt_decay = [ac['wt_decay']['training'] for ac in result_accuracy]
ac_validation_wo_decay = [ac['wo_decay']['validation'] for ac in result_accuracy]
ac_validation_wt_decay = [ac['wt_decay']['validation'] for ac in result_accuracy]

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(1,1,1)

ax.plot(tr_step, ac_training_wo_decay, color='blue', linestyle='dashed')
ax.plot(tr_step, ac_training_wt_decay, color='red', linestyle='dashed')
ax.plot(tr_step, ac_validation_wo_decay, color='blue', linestyle='solid')
ax.plot(tr_step, ac_validation_wt_decay, color='red', linestyle='solid')
ax.set_title('accuracy')
ax.set_xlabel('step')
ax.set_ylabel('accuracy')
ax.grid(True)
ax.set_xlim((0, 10000))
ax.set_ylim((0, 100))
fig.show()
```
