In [1]:
import os

import mxnet as mx
from mxnet import gluon
In [2]:
# Device context (CPU) reused below when initializing parameters, creating
# the dummy input, and loading parameters back.
ctx = mx.cpu()
In [3]:
# Two 100x100 example NDArrays: X is all ones, Y is all zeros
X, Y = mx.nd.ones((100, 100)), mx.nd.zeros((100, 100))
In [4]:
# Create a folder mimicking a checkpoint directory.
# exist_ok=True keeps this cell safe to re-run when the folder already exists.
os.makedirs('checkpoints', exist_ok=True)
# Path of the first file written below (a list of NDArrays)
filename = "checkpoints/test1.params"
In [5]:
# Write both arrays to the file as a list; cell 6 unpacks them as (A, B)
mx.nd.save(filename, [X, Y])
In [6]:
# Read the saved list back and unpack its two arrays
A, B = mx.nd.load(filename)
# Print each array on its own line
print(A, B, sep="\n")
In [7]:
# Saving a dictionary of NDArrays instead of a list
filename = "checkpoints/test2.params"
mydict = dict(X=X, Y=Y)
mx.nd.save(filename, mydict)
In [8]:
# Loading the dictionary back from the file written in the previous cell.
# The result is indexed by the original keys ("X", "Y") in the next cells.
C = mx.nd.load(filename)
print(C)
In [9]:
# Display the array stored under key "X" in the loaded dictionary
C['X']
Out[9]:
In [10]:
# Display the array stored under key "Y" in the loaded dictionary
C['Y']
Out[10]:
In [11]:
# Layer sizes for the dummy network: hidden-layer width and output width
num_hidden, num_outputs = 256, 1
In [12]:
# Dummy MLP: two ReLU hidden layers followed by a Dense output layer
# (no activation). Layers are created inside the name scope, as before.
net = gluon.nn.Sequential()
with net.name_scope():
    net.add(
        gluon.nn.Dense(num_hidden, activation="relu"),
        gluon.nn.Dense(num_hidden, activation="relu"),
        gluon.nn.Dense(num_outputs),
    )
In [13]:
# Initialize all of the network's parameters from a normal distribution
# (sigma=1) on the chosen context
net.initialize(mx.init.Normal(sigma=1.), ctx=ctx)
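At this point the parameters are not yet fully initialized: we have neither passed any input through the network nor hard-coded the input shapes, so the parameter shapes are still unknown. They are inferred, and the parameters actually allocated, on the first forward pass.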
In [14]:
# First forward pass, using a dummy all-ones input of shape (1, 100)
dummy_data = mx.nd.ones(shape=(1, 100), ctx=ctx)
net(dummy_data)
Out[14]:
In [15]:
# Saving network parameters to a file.
# Use net.save_parameters(filename); the older net.save_params is deprecated.
filename = "checkpoints/testnet.params"
net.save_parameters(filename=filename)
In [16]:
# Build a second network with the same layer layout as `net`; the saved
# parameters are loaded into it in the next cell.
net2 = gluon.nn.Sequential()
with net2.name_scope():
    net2.add(gluon.nn.Dense(num_hidden, activation="relu"))
    net2.add(gluon.nn.Dense(num_hidden, activation="relu"))
    net2.add(gluon.nn.Dense(num_outputs))
In [17]:
# Load the saved parameters into net2 on the same context.
# Use net2.load_parameters(); the older load_params is deprecated.
net2.load_parameters(filename, ctx=ctx)
In [18]:
# Run the same dummy input through the network that received the loaded
# parameters; if loading worked, this should presumably match net's output.
net2(dummy_data)
Out[18]: