TensorFlow の活性化関数

活性化関数(伝達関数)は、入力信号の総和を出力信号に変換する関数のことです。

パーセプトロンの時代ではステップ関数が用いられ、バックプロパゲーションの時代ではシグモイド関数が用いられましたが、最近ではReLU関数が多く用いられます。

ここでは、よく使われる活性化関数の概要を説明し、TensorFlowで使う場合のサンプルコードを紹介したいと思います。

事前準備

まずサンプルコードで使用するパッケージをインポートします。


In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

sess = tf.Session()

活性化関数

ステップ関数

$$ f(x) = \begin{cases} 1 & x \le \theta\\ 0 & x > \theta \end{cases} $$

In [2]:
def step_function(x):
    return np.array(x > 0, dtype=np.int)

X = np.arange(-8., 8., 0.02)
Y = step_function(X)

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-0.5, 1.5)
plt.grid()
plt.show()


ロジスティックシグモイド関数

ロジスティックシグモイド関数はロジスティク関数とも呼ばれていて、範囲は 0 から 1 までの値をとります。

$$ f(x) = \frac{1}{1 + e^{-x}} $$

In [3]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.sigmoid(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-0.5, 1.5)
plt.grid()
plt.show()


ReLU(Rectified Linear Unit)関数

シグモイド関数やハイパボリックタンジェント関数は絶対値が大きいほど勾配がなくなって学習が停滞する問題があります。

$$ f(x) = max(0, x) $$

In [4]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.relu(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-1., 8.)
plt.grid()
plt.show()


ReLU6関数


In [5]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.relu6(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-1., 8.)
plt.grid()
plt.show()


ELU(Exponential Linear Units)関数


In [6]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.elu(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-2., 8.)
plt.grid()
plt.show()


ソフトプラス関数

$$ f(x) = \log{(1 + e^x)} $$

In [7]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.softplus(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-1., 8.)
plt.grid()
plt.show()


ソフトサイン関数

$$ f(x) = \frac{x}{1 + |x|} $$

In [8]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.softsign(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-1.5, 1.5)
plt.grid()
plt.show()


ハイパボリックタンジェント関数

ハイパボリックタンジェント関数は双曲線正接関数とも呼ばれていて、範囲は -1 から 1 までの値をとります。

$$ tanh(x) = \frac{e^x - e^{-u}}{e^x + e^{-u}} $$

とも表せるように、ロジスティック関数と似た性質を持ちます。


In [9]:
X = np.arange(-8., 8., 0.1)
Y = sess.run(tf.nn.tanh(X))

plt.plot(X, Y)
plt.xlim(-8, 8)
plt.ylim(-1.5, 1.5)
plt.grid()
plt.show()