In [1]:
import torch
In [2]:
# Problem size used throughout the notebook:
N = 64       # number of samples in the (single) batch
t_in = 128   # input features of the linear layer
t_out = 256  # output features of the linear layer
In [3]:
# Random fp32 inputs and regression targets, sampled on the CPU and then moved to the GPU.
X = torch.randn(N, t_in).to("cuda")
y = torch.randn(N, t_out).to("cuda")
In [4]:
# Baseline single linear layer in the default dtype, placed on the GPU.
model = torch.nn.Linear(t_in, t_out).to("cuda")
In [5]:
# Adam over the baseline model's parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
In [6]:
%time
optimizer.zero_grad()
predicted = model(X)
loss = torch.nn.functional.mse_loss(predicted, y)
loss.backward()
optimizer.step()
In [7]:
# Naive fp16 baseline: inputs, targets and model weights all in half precision.
X = torch.randn(N, t_in).cuda().half()
y = torch.randn(N, t_out).cuda().half()
model = torch.nn.Linear(in_features=t_in, out_features=t_out).cuda().half()
# Rebuild the optimizer for the new model: the earlier optimizer holds the previous fp32
# model's parameters, so stepping it would never update this fp16 model.
# NOTE(review): with fp16 params, Adam's default eps=1e-8 is below fp16 resolution and may
# round to 0 — watch for NaNs if a gradient element is exactly zero.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
In [8]:
%time
optimizer.zero_grad()
predicted = model(X)
loss = torch.nn.functional.mse_loss(predicted, y)
loss.backward()
optimizer.step()
In [9]:
def pre_param_list(model: torch.nn.Module):
    """Build index-aligned model / fp32 master parameter lists for mixed-precision training.

    Args:
        model: module whose trainable parameters will be mirrored.

    Returns:
        (model_params, master_params): ``model_params`` are the parameters of ``model``
        with ``requires_grad=True``. ``master_params[i]`` is a detached fp32 copy of
        ``model_params[i]`` with ``requires_grad=True``, suitable for handing to an optimizer.
    """
    model_params = [p for p in model.parameters() if p.requires_grad]
    # Clone from model_params (not model.parameters()) so the two lists stay aligned when
    # some parameters are frozen; the grad/param copy helpers pair them up with zip().
    # clone() guarantees a separate buffer even if the parameter is already fp32.
    master_params = [p.detach().clone().float() for p in model_params]
    # The detached copies start without requires_grad; they are the tensors being optimized.
    for p in master_params:
        p.requires_grad = True
    return model_params, master_params
In [10]:
def master_param_to_model_param(model_param, master_param):
    """Overwrite each model parameter in place with the value of its paired master parameter.

    Pairs are matched by position. The copy goes through ``.data`` so autograd does not
    record it, and ``copy_`` converts to the destination tensor's dtype.
    """
    paired = zip(model_param, master_param)
    for target, source in paired:
        target.data.copy_(source.data)
In [11]:
# Fresh half-precision inputs, targets and model for the mixed-precision experiment.
X = torch.randn(N, t_in).to(device="cuda", dtype=torch.float16)
y = torch.randn(N, t_out).to(device="cuda", dtype=torch.float16)
model = torch.nn.Linear(t_in, t_out).to(device="cuda", dtype=torch.float16)
In [12]:
# Pair the model's trainable parameters with detached fp32 master copies.
model_params, master_params = pre_param_list(model=model)
In [13]:
def model_grad_to_master_grad(model_param, master_param):
    """Copy each model parameter's gradient into its paired master parameter's ``.grad``.

    Pairs are matched by position. A master gradient buffer is allocated on first use
    (same shape, dtype and device as the master parameter) and reused afterwards;
    ``copy_`` converts the incoming gradient to the master dtype. If a model parameter
    has no gradient, the paired master gradient is cleared to ``None`` so the optimizer
    skips it instead of applying a stale gradient.
    """
    with torch.no_grad():
        for model_p, master_p in zip(model_param, master_param):
            if model_p.grad is None:
                master_p.grad = None
                continue
            if master_p.grad is None:
                # empty_like preserves shape (including 0-d), dtype and device; it replaces the
                # deprecated Variable(tensor.new(*size)) idiom, which produced a wrong shape
                # for 0-d tensors.
                master_p.grad = torch.empty_like(master_p)
            master_p.grad.copy_(model_p.grad)
In [ ]:
# Mixed-precision setup: the optimizer updates the fp32 master copies, not the model directly.
model_params, master_params = pre_param_list(model)
optimizer = torch.optim.Adam(params=master_params, lr=1e-3)
# Constant loss-scale factor (2**7) multiplied into the loss before backward.
scale_factor = 2 ** 7
In [15]:
%time
optimizer.zero_grad()
predicted = model(X)
loss = torch.nn.functional.mse_loss(predicted, y)
scaled_loss = loss.float() * scale_factor
scaled_loss.backward()
model_grad_to_master_grad(model_params, master_params)
for p in master_params:
p.grad.data.mul_(1/scaled_loss)
optimizer.step()
master_param_to_model_param(model_params, master_params)
In [ ]: