Simple Network


In [1]:
import torch

In [2]:
N, t_in, t_out = 64, 128, 256

In [3]:
X = torch.randn(N, t_in).cuda()
y = torch.randn(N, t_out).cuda()

In [4]:
model = torch.nn.Linear(in_features=t_in, out_features=t_out).cuda()

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
%time
optimizer.zero_grad()
predicted = model(X)
loss = torch.nn.functional.mse_loss(predicted, y)
loss.backward()
optimizer.step()


CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 6.91 µs
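
Note that `%time` here is a line magic with nothing after it on the same line, so it times an empty statement; CUDA kernels also launch asynchronously, so the microsecond wall times reported in this notebook (here and in the two timed cells below) say little about the real step cost. A more careful measurement could synchronize around a loop of steps with CUDA events; a minimal sketch, with a helper name (`time_step`) of my own:

def time_step(model, X, y, optimizer, iters=100):
    # Warm up so allocator/kernel setup is not measured.
    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(X), y)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()                 # drain queued kernels before starting the clock
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(X), y)
        loss.backward()
        optimizer.step()
    end.record()
    torch.cuda.synchronize()                 # wait until the end event has fired
    return start.elapsed_time(end) / iters   # milliseconds per training step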

Pure Half Precision


In [7]:
X = torch.randn(N, t_in).cuda().half()
y = torch.randn(N, t_out).cuda().half()
model = torch.nn.Linear(in_features=t_in, out_features=t_out).cuda().half()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # rebind: the old optimizer still points at the fp32 model's params

In [8]:
%time
optimizer.zero_grad()
predicted = model(X)
loss = torch.nn.functional.mse_loss(predicted, y)
loss.backward()
optimizer.step()


CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 10.7 µs
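
Pure fp16 is fast but fragile: float16 carries roughly three decimal digits of precision, and small gradient values fall out of its range entirely. A quick illustration of the underflow that the loss scaling below is meant to prevent:

print(torch.finfo(torch.float16).tiny)          # 6.1035e-05, the smallest normal fp16 value
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16): underflowed to zero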

Mixed Precision

Master weights


In [9]:
def prep_param_list(model: torch.nn.Module):
    # fp16 parameters that stay in the model and run the forward pass
    model_params = [p for p in model.parameters() if p.requires_grad]
    # fp32 master copies that the optimizer actually updates
    master_params = [p.detach().clone().float() for p in model_params]
    for p in master_params:
        p.requires_grad = True  # detach().clone() yields leaves with requires_grad=False

    return model_params, master_params
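
The fp32 master copy matters because fp16 cannot absorb a small update into a larger weight: near 1.0 the fp16 grid spacing is about 1e-3, so an update such as lr * grad = 1e-4 rounds away entirely. A tiny demonstration:

w = torch.tensor(1.0, dtype=torch.float16)
w += 1e-4   # below half the fp16 spacing around 1.0
print(w)    # tensor(1., dtype=torch.float16): the update was lost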

Weight sync


In [10]:
def master_param_to_model_param(model_params, master_params):
    # Copy the updated fp32 master weights back into the fp16 model;
    # copy_ performs the float32 -> float16 cast in place.
    for model_p, master_p in zip(model_params, master_params):
        model_p.data.copy_(master_p.data)
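
`Tensor.copy_` converts dtypes in place, which is what lets the fp32 master values land back in the fp16 model weights. For example:

w16 = torch.zeros(3).cuda().half()
w32 = torch.randn(3).cuda()   # stand-in for fp32 master values
w16.data.copy_(w32)           # casts float32 -> float16 during the copy
print(w16.dtype)              # torch.float16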

In [11]:
X = torch.randn(N, t_in).cuda().half()
y = torch.randn(N, t_out).cuda().half()
model = torch.nn.Linear(in_features=t_in, out_features=t_out).cuda().half()

In [12]:
model_params, master_params = prep_param_list(model)

In [13]:
def model_grad_to_master_grad(model_params, master_params):
    # Move the fp16 gradients from the model onto the fp32 master params.
    for model_p, master_p in zip(model_params, master_params):
        if master_p.grad is None:
            master_p.grad = torch.empty_like(master_p)  # allocate an fp32 grad buffer
        master_p.grad.data.copy_(model_p.grad.data)

In [ ]:
model_params, master_params = prep_param_list(model)
optimizer = torch.optim.Adam(master_params, lr=0.001)
scale_factor = 128

In [15]:
%time
optimizer.zero_grad()
predicted = model(X)                                      # fp16 forward pass
loss = torch.nn.functional.mse_loss(predicted, y)
scaled_loss = loss.float() * scale_factor                 # scale up so small fp16 grads survive
scaled_loss.backward()
model_grad_to_master_grad(model_params, master_params)    # fp16 grads -> fp32 master grads
for p in master_params:
    p.grad.data.mul_(1 / scale_factor)                    # undo the loss scaling
optimizer.step()                                          # update the fp32 master weights
master_param_to_model_param(model_params, master_params)  # sync back into the fp16 model


CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 5.72 µs
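
For reference, recent PyTorch packages this whole recipe (fp16 compute, fp32 master weights, dynamic loss scaling) behind `torch.cuda.amp`. A minimal sketch of the same training step with it, keeping the model and data in fp32 and letting autocast pick fp16 where it is safe:

model = torch.nn.Linear(t_in, t_out).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler()

X32, y32 = X.float(), y.float()
optimizer.zero_grad()
with torch.cuda.amp.autocast():            # runs matmuls in fp16, reductions in fp32
    loss = torch.nn.functional.mse_loss(model(X32), y32)
scaler.scale(loss).backward()              # backward on the scaled loss
scaler.step(optimizer)                     # unscales grads, skips the step on inf/nan
scaler.update()                            # adjusts the scale factor for the next step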
