In [1]:
import torch
import torch.nn as nn
In [2]:
# Fix PyTorch's global RNG seed so the randomly initialised layer weights
# created below are the same on every run.
torch.manual_seed(seed=1)
class SimpleNet(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input row.
        hidden: Width of the hidden layer.
        output_size: Number of outputs per row (used as class logits by the
            script below, which feeds them to ``nn.CrossEntropyLoss``).
    """

    def __init__(self, input_size: int, hidden: int, output_size: int):
        super().__init__()
        # Attribute names are kept as-is: they become the state_dict keys
        # ("layer1.weight", "layer2.bias", ...).
        self.layer1 = nn.Linear(input_size, hidden)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden, output_size)

    # The parameter is named ``input`` (shadowing the builtin) to stay
    # compatible with callers that pass it by keyword.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Map ``input`` of shape ``(..., input_size)`` to ``(..., output_size)``."""
        out = self.layer1(input)
        out = self.relu(out)
        return self.layer2(out)
model = SimpleNet(3, 5, 2)
crit = nn.CrossEntropyLoss()

# torch.tensor infers dtype from its data, so request float32 explicitly (the
# legacy torch.Tensor constructor always produced float32). requires_grad=True
# makes x a leaf tensor we can differentiate the loss with respect to.
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, requires_grad=True)
# Integer class indices, as CrossEntropyLoss expects for its target.
y = torch.tensor([0, 1], dtype=torch.long)

ypred = model(x)
ce_loss = crit(ypred, y)

# Gradient of the loss w.r.t. the input. create_graph=True keeps this gradient
# part of the autograd graph, so the norm penalty below actually contributes
# to the parameter gradients. (Reading x.grad after a plain backward() yields
# a detached tensor: the penalty would change the loss value but not the
# gradients, and a second backward() would count the cross-entropy
# gradients twice.)
(input_grad,) = torch.autograd.grad(ce_loss, x, create_graph=True)

loss = ce_loss + torch.norm(input_grad)
loss.backward()  # single backward pass: populates .grad on params and x
print(loss)