In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
In [2]:
# テンソルを作成
# Variableはgradプロパティを持ち自動微分の対象になる
# requires_grad=Falseだと微分の対象にならず勾配はNoneが返る
x = Variable(torch.Tensor([1]), requires_grad=True)
w = Variable(torch.Tensor([2]), requires_grad=True)
b = Variable(torch.Tensor([3]), requires_grad=True)
# 計算グラフを構築
# y = 2 * x + 3
y = w * x + b
# 勾配を計算
y.backward()
# 勾配を表示
print(x.grad) # dy/dx = w = 2
print(w.grad) # dy/dw = x = 1
print(b.grad) # dy/db = 1
In [3]:
x = Variable(torch.Tensor([2]), requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)
In [4]:
x = Variable(torch.Tensor([2]), requires_grad=True)
y = torch.exp(x)
y.backward()
print(x.grad)
In [5]:
x = Variable(torch.Tensor([np.pi]), requires_grad=True)
y = torch.sin(x)
y.backward()
print(x.grad)
In [6]:
x = Variable(torch.Tensor([0]), requires_grad=True)
y = (x - 4) * (x ** 2 + 6)
y.backward()
print(x.grad)
In [7]:
x = Variable(torch.Tensor([2]), requires_grad=True)
y = (torch.sqrt(x) + 1) ** 3
y.backward()
print(x.grad)
In [8]:
x = Variable(torch.Tensor([1]), requires_grad=True)
y = Variable(torch.Tensor([2]), requires_grad=True)
z = (x + 2 * y) ** 2
z.backward()
print(x.grad) # dz/dx
print(y.grad) # dz/dy
In [9]:
# バッチサンプル数=5、入力特徴量の次元数=3
x = Variable(torch.randn(5, 3))
# バッチサンプル数=5、出力特徴量の次元数=2
y = Variable(torch.randn(5, 2))
# Linear層を作成
# 3ユニット => 2ユニット
linear = nn.Linear(3, 2)
# Linear層のパラメータ
print('w:', linear.weight)
print('b:', linear.bias)
# lossとoptimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
# forward
pred = linear(x)
# loss = L
loss = criterion(pred, y)
print('loss:', loss)
# backpropagation
loss.backward()
# 勾配を表示
print('dL/dw:', linear.weight.grad)
print('dL/db:', linear.bias.grad)
# 勾配を用いてパラメータを更新
print('*** by hand')
print(linear.weight.sub(0.01 * linear.weight.grad))
print(linear.bias.sub(0.01 * linear.bias.grad))
# 勾配降下法
optimizer.step()
# 1ステップ更新後のパラメータを表示
# 上の式と結果が一致することがわかる
print('*** by optimizer.step()')
print(linear.weight)
print(linear.bias)