In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed PyTorch's global RNG so the randomly initialised layer weights created
# below, and therefore the printed loss, are the same on every run.
torch.manual_seed(1)

class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input sample.
        hidden: Number of units in the hidden layer.
        output_size: Number of values produced per sample.
    """

    def __init__(self, input_size, hidden, output_size):
        super().__init__()
        # Sub-modules are created in the order they are applied in forward().
        self.layer1 = nn.Linear(in_features=input_size, out_features=hidden)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(in_features=hidden, out_features=output_size)

    def forward(self, input):
        """Return ``layer2(relu(layer1(input)))``."""
        hidden_activation = self.relu(self.layer1(input))
        return self.layer2(hidden_activation)
    
# Build a 3 -> 5 -> 2 network and a cross-entropy criterion for its outputs.
model = SimpleNet(3, 5, 2)
crit = nn.CrossEntropyLoss()

# Two samples with three features each, and one integer class label per sample.
# The input must track gradients because the loss below is differentiated
# with respect to it.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
y = torch.tensor([0, 1])
ypred = model(x)
loss = crit(ypred, y)

# Take d(loss)/d(x) with create_graph=True so the gradient is itself part of
# the autograd graph.  Reading x.grad after a plain backward() yields a
# constant tensor: its norm would change the printed loss value but add
# nothing to the parameter gradients, and the extra backward() call would
# accumulate the cross-entropy gradients a second time.
(input_grad,) = torch.autograd.grad(loss, x, create_graph=True)

# Add the gradient-norm penalty out of place, then backpropagate exactly once
# through both the cross-entropy term and the penalty term.
loss = loss + torch.norm(input_grad)
loss.backward()
print(loss)


tensor(0.7912, grad_fn=<AddBackward0>)