Learning Modulation and Demodulation in Nonlinear Channels with Deep Neural Networks via Autoencoders and End-to-end Training

This code is provided as supplementary material of the lecture Machine Learning and Optimization in Communications (MLOC).

This code illustrates

  • End-to-end learning of a modulation scheme and demodulator for a simple nonlinear channel model, using a batch size that varies during training

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from ipywidgets import interactive
import ipywidgets as widgets

Specify the parameters of the transmission: the fiber length $L$ (in km), the fiber nonlinearity coefficient $\gamma$ (in 1/W/km) and the total noise power $P_n$ (in dBm; the noise is due to amplified spontaneous emission in the amplifiers along the link). We assume a model of a dispersion-less fiber affected by nonlinearity. The model, which is described for instance in [1], is given by iterative application of the equation $$ x_{k+1} = x_k\exp\left(\jmath\frac{L}{K}\gamma|x_k|^2\right) + n_{k+1},\qquad 0 \leq k < K $$ where $x_0$ is the channel input (the modulated, complex symbols) and $x_K$ is the channel output. $K$ denotes the number of steps used to simulate the channel; usually $K=50$ gives a good approximation.

[1] S. Li, C. Häger, N. Garcia, and H. Wymeersch, "Achievable Information Rates for Nonlinear Fiber Communication via End-to-end Autoencoder Learning," Proc. ECOC, Rome, Sep. 2018
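
As a quick sanity check (not part of the original notebook): without noise, each step preserves the magnitude $|x_k|$, so the $K$-step recursion collapses to a single phase rotation by $L\gamma|x_0|^2$. A minimal sketch with illustrative parameter values:

import numpy as np

# Without noise, |x_k| is constant, so K steps of length L/K accumulate
# exactly the phase L*gamma*|x_0|^2 (illustrative values, not the cell below)
L_demo, gamma_demo, K_demo = 5000, 1.27, 50   # km, 1/W/km, number of steps
x = 0.001 + 0.002j                            # example channel input

temp = x
for _ in range(K_demo):
    temp = temp * np.exp(1j * (L_demo / K_demo) * gamma_demo * np.abs(temp)**2)

closed_form = x * np.exp(1j * L_demo * gamma_demo * np.abs(x)**2)
print(np.isclose(temp, closed_form))          # True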


In [2]:
# Length of transmission (in km)
L = 5000

# fiber nonlinearity coefficient
gamma = 1.27

Pn = -21.3 # noise power (in dBm)

Kstep = 50 # number of steps used in the channel model

# noise standard deviation per step (per real dimension)
sigma_n = np.sqrt((10**((Pn-30)/10)) / Kstep / 2)    


def simulate_channel(x, Pin, constellation):  
    # modulate using the given constellation, normalized to the requested input power
    input_power_linear = 10**((Pin-30)/10)
    norm_factor = 1 / np.sqrt(np.mean(np.abs(constellation)**2)/input_power_linear)
    modulated = constellation[x] * norm_factor
    

    temp = np.array(modulated, copy=True)
    for i in range(Kstep):
        power = np.absolute(temp)**2
        rotcoff = (L / Kstep) * gamma * power
        
        temp = temp * np.exp(1j*rotcoff) + sigma_n*(np.random.randn(len(x)) + 1j*np.random.randn(len(x)))
    return temp
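
A brief usage sketch (hypothetical constellation, not part of the notebook): send a few random symbols from a unit-energy QPSK constellation through the channel. Recall that Pin is given in dBm, so Pin = 4 corresponds to 10**((4-30)/10) ≈ 2.5 mW. This NumPy function serves as a reference model; the TensorFlow graph below re-implements the same channel.

# usage sketch (hypothetical QPSK constellation and symbol count)
qpsk = np.exp(1j * np.pi * (np.arange(4) / 2 + 0.25))  # unit-energy QPSK
tx_indices = np.random.randint(4, size=10)             # random symbol indices
rx = simulate_channel(tx_indices, Pin=4, constellation=qpsk)
print(rx.shape)                                        # (10,)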

Helper function to compute the symbol error rate (SER)


In [3]:
# helper function to compute the symbol error rate
def SER(predictions, labels):
    return (np.sum(np.argmax(predictions, 1) != labels) / predictions.shape[0])
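
As a toy example (illustrative numbers only): with four predictions of which one is misclassified, the SER is 0.25.

# toy check of the SER helper (illustrative values)
toy_predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
toy_labels = np.array([0, 1, 1, 1])
print(SER(toy_predictions, toy_labels))   # 0.25, only the third symbol is wrong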

Here, we define the parameters of the neural network and the training, generate the validation set, and build a meshgrid used to visualize the decision regions.


In [4]:
# set input power
Pin = 4
input_power_linear = 10**((Pin-30)/10)

# number of points in constellation
M = 16


# validation set. Training examples are generated on the fly
N_valid = 100000

# number of neurons in hidden layers at transmitter
hidden_neurons_TX_1 = 50
hidden_neurons_TX_2 = 50
hidden_neurons_TX_3 = 50
hidden_neurons_TX_4 = 50

# number of neurons in hidden layers at receiver
hidden_neurons_RX_1 = 50
hidden_neurons_RX_2 = 50
hidden_neurons_RX_3 = 50
hidden_neurons_RX_4 = 50



y_valid = np.random.randint(M,size=N_valid)
y_valid_onehot = np.eye(M)[y_valid]

# meshgrid for plotting
# Heuristic for the plot range: assume the worst-case constellation places all points
# on a straight line starting at the center, spread equidistantly; a hypothetical
# (M+1)-th point on that line then defines ext_max
ext_max = np.sqrt(M*input_power_linear)
mgx,mgy = np.meshgrid(np.linspace(-ext_max,ext_max,400), np.linspace(-ext_max,ext_max,400))
meshgrid = np.column_stack((np.reshape(mgx,(-1,1)),np.reshape(mgy,(-1,1))))

This is the main TensorFlow code that generates the computation graph. We have a single interface to the outside world: a tf.placeholder that holds the batch size. The idea is to vary the batch size during training: in the first iterations, we start with a small batch size to rapidly get to a working solution; towards the end of the training, we increase the batch size. If the batch size is kept small, a small batch may contain no misclassifications at all, so the training has no incentive to improve. A larger batch will most likely contain errors and hence provide an incentive to keep training and improving.
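
Schematically, the batch size schedule is a simple linear ramp (the actual values are set in cell In [6] below):

# schematic of the batch size schedule used during training (values from In [6])
schedule = np.linspace(100, 10000, num=150)   # grows linearly over the epochs
print(schedule[0], schedule[-1])              # 100.0 10000.0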

Here, the data is generated on the fly inside the graph, using TensorFlow's random number generation. As TensorFlow does not natively support complex numbers (at least in early versions), we replace the complex-valued operations in the channel by a simple rotation matrix, treating real and imaginary parts separately.
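
The equivalence of the two representations is easy to verify; a minimal sketch with illustrative values:

# multiplying by exp(1j*theta) equals applying the rotation matrix
# [[cos, -sin], [sin, cos]] to the (real, imag) pair (illustrative values)
theta = 0.3
z = 1.0 + 2.0j
rotated_complex = z * np.exp(1j * theta)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
rotated_vector = R @ np.array([z.real, z.imag])
print(np.allclose(rotated_vector, [rotated_complex.real, rotated_complex.imag]))  # True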

We use the ELU activation function inside the neural network and employ the Adam optimization algorithm.


In [5]:
# generate graph
graph = tf.Graph()

with graph.as_default():    
    # placeholder for the batch size (passed in from outside)
    tf_batch_size = tf.placeholder(tf.int32, shape=())

    # the validation dataset; we only need the labels (channel outputs are computed inside the graph)
    tf_valid_labels = tf.constant(y_valid_onehot, dtype=tf.float32)

    # temporary identity matrix for one-hot vector conversion
    tf_onehot_conversion = tf.constant(np.eye(M), dtype=tf.float32)
    
    # the meshgrid for plotting the decision regions
    tf_meshgrid = tf.constant(meshgrid, dtype=tf.float32)
    
    # define weights
    Ws = {
        'T1' : tf.Variable(tf.truncated_normal([M,hidden_neurons_TX_1], stddev=0.8)),
        'T2' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_1, hidden_neurons_TX_2], stddev=0.8)),
        'T3' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_2, hidden_neurons_TX_3], stddev=0.8)),
        'T4' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_3, hidden_neurons_TX_4], stddev=0.8)),
        'T5' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_4, 2], stddev=0.8)),
        'R1' : tf.Variable(tf.truncated_normal([2,hidden_neurons_RX_1], stddev=0.8)),
        'R2' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_1, hidden_neurons_RX_2], stddev=0.8)),
        'R3' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_2, hidden_neurons_RX_3], stddev=0.8)),
        'R4' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_3, hidden_neurons_RX_4], stddev=0.8)),
        'R5' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_4, M], stddev=0.8))
        }

    bs = {
        'T1' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_1], stddev=0.8)),
        'T2' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_2], stddev=0.8)),
        'T3' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_3], stddev=0.8)),
        'T4' : tf.Variable(tf.truncated_normal([hidden_neurons_TX_4], stddev=0.8)),
        'T5' : tf.Variable(tf.truncated_normal([2], stddev=0.8)),
        'R1' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_1], stddev=0.8)),
        'R2' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_2], stddev=0.8)),
        'R3' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_3], stddev=0.8)),
        'R4' : tf.Variable(tf.truncated_normal([hidden_neurons_RX_4], stddev=0.8)),
        'R5' : tf.Variable(tf.truncated_normal([M], stddev=0.8))
    }
        
    

    def network_transmitter(batch_labels):
        nn = tf.nn.elu(tf.matmul(batch_labels, Ws['T1'])+bs['T1'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['T2'])+bs['T2'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['T3'])+bs['T3'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['T4'])+bs['T4'])
        nn = tf.matmul(nn, Ws['T5'])+bs['T5']
        return nn

    def network_receiver(inp):
        nn = tf.nn.elu(tf.matmul(inp, Ws['R1'])+bs['R1'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['R2'])+bs['R2'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['R3'])+bs['R3'])
        nn = tf.nn.elu(tf.matmul(nn, Ws['R4'])+bs['R4'])
        logits = tf.matmul(nn, Ws['R5'])+bs['R5']
        return logits
    
    def channel_model(modulated):
        # simulate the channel
        for i in range(Kstep):
            power = tf.norm(modulated, axis=1) ** 2
            rotcoff = (L / Kstep) * gamma * power
            
            # rotation matrix corresponding to exp(1j*rotcoff)        
            temp = tf.stack([modulated[:,0] * tf.cos(rotcoff) - modulated[:,1]*tf.sin(rotcoff), modulated[:,0]*tf.sin(rotcoff)+modulated[:,1]*tf.cos(rotcoff)], axis=1)        
            modulated = temp + tf.random_normal(shape=(tf_batch_size,2), stddev=sigma_n)
        return modulated        
      
    def autoencoder(batch_labels):        
        # compute output
        encoded = network_transmitter(batch_labels)
        
        # compute normalization factor and normalize the channel input (transmit signal)
        norm_factor = tf.sqrt(tf.reduce_mean(tf.square(encoded)) / input_power_linear * 2 )                            
        modulated = encoded / norm_factor    
                
        received = channel_model(modulated)
        
        logits = network_receiver(received)
        return logits
    
    

    # generate random data
    batch_temp = tf.random_uniform(shape=[tf_batch_size], dtype=tf.int32, minval=0, maxval=M)
    # convert to one-hot representation
    batch_labels = tf.nn.embedding_lookup(tf_onehot_conversion, batch_temp)

    logits = autoencoder(batch_labels)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=batch_labels, logits=logits))    

    # use Adam optimizer
    optimizer = tf.train.AdamOptimizer().minimize(loss)
 
    # get constellation     
    constellation_unnormalized = network_transmitter(tf_onehot_conversion)
    norm_factor = tf.sqrt(tf.reduce_mean(tf.square(constellation_unnormalized)) / input_power_linear * 2 )
    constellation = constellation_unnormalized / norm_factor

    # compute channel output and receiver decisions for the validation set
    valid_modulated = network_transmitter( tf_valid_labels) / norm_factor
    valid_received = channel_model(valid_modulated)        
    valid_prediction = tf.nn.softmax(network_receiver(valid_received))
    
        
    # mesh prediction for plotting     
    mesh_prediction = tf.nn.softmax(network_receiver(tf_meshgrid))



Now carry out the actual training: first initialize the variables, then loop through the training steps. The epochs are not defined in the classical way, as there is no fixed training set per se; new data is generated on the fly and never reused.


In [6]:
num_epochs = 150
batches_per_epoch = 350

# increase batch size while learning from 100 up to 10000
batch_size_per_epoch = np.linspace(100,10000,num=num_epochs)

validation_SERs = np.zeros(num_epochs)
validation_received = []
decision_region_evolution = []
constellations = []

with tf.Session(graph=graph) as session:
    # initialize variables
    tf.global_variables_initializer().run()

    print('Initialized')
    for epoch in range(num_epochs):
        for step in range(batches_per_epoch):
            feed_dict = {tf_batch_size : batch_size_per_epoch[epoch] }
            
            # run an optimization step
            _,l = session.run([optimizer, loss], feed_dict=feed_dict)
        
        # compute validation SER
        valid_out = valid_prediction.eval(feed_dict={ tf_batch_size : N_valid })
        validation_SERs[epoch] = SER(valid_out, y_valid)
        print('Validation SER after epoch %d: %f (loss %f)' % (epoch, validation_SERs[epoch], l))                
        
        # store received validation data
        validation_received.append( valid_received.eval(feed_dict={ tf_batch_size : N_valid }) )
        
        # store constellation
        constellations.append(constellation.eval())
        # store decision region for generating the animation
        decision_region_evolution.append(mesh_prediction.eval())


Initialized
Validation SER after epoch 0: 0.716240 (loss 2.516677)
Validation SER after epoch 1: 0.631560 (loss 2.487401)
Validation SER after epoch 2: 0.613660 (loss 1.743685)
Validation SER after epoch 3: 0.621710 (loss 1.830276)
Validation SER after epoch 4: 0.502390 (loss 1.444139)
Validation SER after epoch 5: 0.408320 (loss 1.283352)
Validation SER after epoch 6: 0.378620 (loss 1.040426)
Validation SER after epoch 7: 0.377220 (loss 0.857601)
Validation SER after epoch 8: 0.181330 (loss 0.676158)
Validation SER after epoch 9: 0.135880 (loss 0.607676)
Validation SER after epoch 10: 0.146510 (loss 0.618175)
Validation SER after epoch 11: 0.044190 (loss 0.501843)
Validation SER after epoch 12: 0.044770 (loss 0.304340)
Validation SER after epoch 13: 0.027560 (loss 0.184089)
Validation SER after epoch 14: 0.025630 (loss 0.730079)
Validation SER after epoch 15: 0.019450 (loss 0.341600)
Validation SER after epoch 16: 0.014230 (loss 0.144689)
Validation SER after epoch 17: 0.025040 (loss 0.110057)
Validation SER after epoch 18: 0.012110 (loss 0.073357)
Validation SER after epoch 19: 0.011400 (loss 0.200293)
Validation SER after epoch 20: 0.013340 (loss 0.054190)
Validation SER after epoch 21: 0.006510 (loss 0.051744)
Validation SER after epoch 22: 0.015090 (loss 0.051730)
Validation SER after epoch 23: 0.005770 (loss 0.041049)
Validation SER after epoch 24: 0.006480 (loss 0.063838)
Validation SER after epoch 25: 0.004950 (loss 0.063035)
Validation SER after epoch 26: 0.013100 (loss 0.049497)
Validation SER after epoch 27: 0.007920 (loss 0.028723)
Validation SER after epoch 28: 0.005530 (loss 0.022271)
Validation SER after epoch 29: 0.007060 (loss 0.016713)
Validation SER after epoch 30: 0.004650 (loss 0.031856)
Validation SER after epoch 31: 0.004360 (loss 0.018834)
Validation SER after epoch 32: 0.004050 (loss 0.144191)
Validation SER after epoch 33: 0.003740 (loss 0.021273)
Validation SER after epoch 34: 0.003310 (loss 0.019285)
Validation SER after epoch 35: 0.008870 (loss 0.111871)
Validation SER after epoch 36: 0.003120 (loss 0.014101)
Validation SER after epoch 37: 0.003340 (loss 0.018830)
Validation SER after epoch 38: 0.003340 (loss 0.016047)
Validation SER after epoch 39: 0.002330 (loss 0.013850)
Validation SER after epoch 40: 0.003080 (loss 0.014429)
Validation SER after epoch 41: 0.003330 (loss 0.036381)
Validation SER after epoch 42: 0.002150 (loss 0.011272)
Validation SER after epoch 43: 0.002480 (loss 0.009109)
Validation SER after epoch 44: 0.003730 (loss 0.014735)
Validation SER after epoch 45: 0.002530 (loss 0.008038)
Validation SER after epoch 46: 0.117630 (loss 0.215405)
Validation SER after epoch 47: 0.001910 (loss 0.010444)
Validation SER after epoch 48: 0.001810 (loss 0.007620)
Validation SER after epoch 49: 0.002120 (loss 0.007741)
Validation SER after epoch 50: 0.002080 (loss 0.013705)
Validation SER after epoch 51: 0.004280 (loss 0.016625)
Validation SER after epoch 52: 0.001990 (loss 0.017481)
Validation SER after epoch 53: 0.010570 (loss 0.098580)
Validation SER after epoch 54: 0.001740 (loss 0.008400)
Validation SER after epoch 55: 0.001870 (loss 0.003082)
Validation SER after epoch 56: 0.001890 (loss 0.008551)
Validation SER after epoch 57: 0.001710 (loss 0.006484)
Validation SER after epoch 58: 0.001640 (loss 0.005270)
Validation SER after epoch 59: 0.002180 (loss 0.009786)
Validation SER after epoch 60: 0.001740 (loss 0.004265)
Validation SER after epoch 61: 0.001620 (loss 0.008002)
Validation SER after epoch 62: 0.001610 (loss 0.007249)
Validation SER after epoch 63: 0.001340 (loss 0.005359)
Validation SER after epoch 64: 0.005190 (loss 0.037234)
Validation SER after epoch 65: 0.001700 (loss 0.009232)
Validation SER after epoch 66: 0.001600 (loss 0.009064)
Validation SER after epoch 67: 0.001460 (loss 0.006548)
Validation SER after epoch 68: 0.001320 (loss 0.009909)
Validation SER after epoch 69: 0.001290 (loss 0.004125)
Validation SER after epoch 70: 0.001100 (loss 0.014171)
Validation SER after epoch 71: 0.001260 (loss 0.002350)
Validation SER after epoch 72: 0.001260 (loss 0.004207)
Validation SER after epoch 73: 0.001610 (loss 0.007955)
Validation SER after epoch 74: 0.001030 (loss 0.006916)
Validation SER after epoch 75: 0.001520 (loss 0.006578)
Validation SER after epoch 76: 0.001200 (loss 0.007569)
Validation SER after epoch 77: 0.004190 (loss 0.009838)
Validation SER after epoch 78: 0.002150 (loss 0.011814)
Validation SER after epoch 79: 0.001740 (loss 0.007131)
Validation SER after epoch 80: 0.001370 (loss 0.005167)
Validation SER after epoch 81: 0.001380 (loss 0.015208)
Validation SER after epoch 82: 0.001420 (loss 0.007476)
Validation SER after epoch 83: 0.001930 (loss 0.012504)
Validation SER after epoch 84: 0.001060 (loss 0.009343)
Validation SER after epoch 85: 0.001220 (loss 0.005221)
Validation SER after epoch 86: 0.000980 (loss 0.014299)
Validation SER after epoch 87: 0.001290 (loss 0.003563)
Validation SER after epoch 88: 0.001660 (loss 0.009115)
Validation SER after epoch 89: 0.000850 (loss 0.005638)
Validation SER after epoch 90: 0.000650 (loss 0.001566)
Validation SER after epoch 91: 0.001380 (loss 0.015645)
Validation SER after epoch 92: 0.000790 (loss 0.003700)
Validation SER after epoch 93: 0.001090 (loss 0.007435)
Validation SER after epoch 94: 0.001030 (loss 0.002376)
Validation SER after epoch 95: 0.000870 (loss 0.004296)
Validation SER after epoch 96: 0.000530 (loss 0.006039)
Validation SER after epoch 97: 0.001110 (loss 0.004910)
Validation SER after epoch 98: 0.000600 (loss 0.002956)
Validation SER after epoch 99: 0.000720 (loss 0.004183)
Validation SER after epoch 100: 0.001130 (loss 0.002196)
Validation SER after epoch 101: 0.001130 (loss 0.004830)
Validation SER after epoch 102: 0.000530 (loss 0.001334)
Validation SER after epoch 103: 0.000610 (loss 0.003733)
Validation SER after epoch 104: 0.000880 (loss 0.002156)
Validation SER after epoch 105: 0.001480 (loss 0.010643)
Validation SER after epoch 106: 0.000690 (loss 0.004530)
Validation SER after epoch 107: 0.001170 (loss 0.004154)
Validation SER after epoch 108: 0.000760 (loss 0.001755)
Validation SER after epoch 109: 0.000630 (loss 0.003229)
Validation SER after epoch 110: 0.000610 (loss 0.003633)
Validation SER after epoch 111: 0.000390 (loss 0.003466)
Validation SER after epoch 112: 0.000610 (loss 0.003205)
Validation SER after epoch 113: 0.000500 (loss 0.003290)
Validation SER after epoch 114: 0.000530 (loss 0.000764)
Validation SER after epoch 115: 0.000540 (loss 0.002453)
Validation SER after epoch 116: 0.000540 (loss 0.001602)
Validation SER after epoch 117: 0.000770 (loss 0.007057)
Validation SER after epoch 118: 0.000550 (loss 0.001477)
Validation SER after epoch 119: 0.005200 (loss 0.024063)
Validation SER after epoch 120: 0.000630 (loss 0.002571)
Validation SER after epoch 121: 0.000580 (loss 0.003148)
Validation SER after epoch 122: 0.000560 (loss 0.004809)
Validation SER after epoch 123: 0.000690 (loss 0.003645)
Validation SER after epoch 124: 0.000460 (loss 0.001970)
Validation SER after epoch 125: 0.000540 (loss 0.001669)
Validation SER after epoch 126: 0.000620 (loss 0.003454)
Validation SER after epoch 127: 0.000720 (loss 0.003353)
Validation SER after epoch 128: 0.000490 (loss 0.003121)
Validation SER after epoch 129: 0.000570 (loss 0.002673)
Validation SER after epoch 130: 0.000810 (loss 0.002050)
Validation SER after epoch 131: 0.000920 (loss 0.004932)
Validation SER after epoch 132: 0.000530 (loss 0.002509)
Validation SER after epoch 133: 0.000670 (loss 0.003383)
Validation SER after epoch 134: 0.003560 (loss 0.010201)
Validation SER after epoch 135: 0.000860 (loss 0.002696)
Validation SER after epoch 136: 0.000510 (loss 0.002132)
Validation SER after epoch 137: 0.000500 (loss 0.001366)
Validation SER after epoch 138: 0.000390 (loss 0.002112)
Validation SER after epoch 139: 0.000440 (loss 0.001115)
Validation SER after epoch 140: 0.000480 (loss 0.001403)
Validation SER after epoch 141: 0.000400 (loss 0.002755)
Validation SER after epoch 142: 0.000470 (loss 0.001165)
Validation SER after epoch 143: 0.000610 (loss 0.003438)
Validation SER after epoch 144: 0.005160 (loss 0.057302)
Validation SER after epoch 145: 0.000430 (loss 0.003233)
Validation SER after epoch 146: 0.000320 (loss 0.001603)
Validation SER after epoch 147: 0.000450 (loss 0.001839)
Validation SER after epoch 148: 0.000350 (loss 0.001959)
Validation SER after epoch 149: 0.000330 (loss 0.003944)

Plot the decision regions and a scatter plot of the validation set. Note that the validation set is only used for computing SERs and for plotting; there is no feedback into the training!


In [7]:
cmap = matplotlib.cm.tab20
base = plt.cm.get_cmap(cmap)
color_list = base.colors
new_color_list = [[t/2 + 0.5 for t in color_list[k]] for k in range(len(color_list))]

# find minimum SER from validation set
min_SER_iter = np.argmin(validation_SERs)
print('Minimum SER obtained: %1.5f' % validation_SERs[min_SER_iter])

ext_max_plot = 1.05*max(max(abs(validation_received[min_SER_iter][:,0])), max(abs(validation_received[min_SER_iter][:,1])))


Minimum SER obtained: 0.00032

In [8]:
%matplotlib inline
plt.figure(figsize=(19,6))
font = {'size'   : 14}
plt.rc('font', **font)
plt.rc('text', usetex=True)
    
plt.subplot(131)
plt.scatter(constellations[min_SER_iter][:,0], constellations[min_SER_iter][:,1], c=range(M), cmap='tab20',s=50)
plt.axis('scaled')
plt.xlabel(r'$\Re\{r\}$',fontsize=14)
plt.ylabel(r'$\Im\{r\}$',fontsize=14)
plt.xlim((-ext_max_plot,ext_max_plot))
plt.ylim((-ext_max_plot,ext_max_plot))
plt.grid(which='both')
plt.title('Constellation',fontsize=16)

plt.subplot(132)
#plt.contourf(mgx,mgy,decision_region_evolution[-1].reshape(mgy.shape).T,cmap='coolwarm',vmin=0.3,vmax=0.7)
plt.scatter(validation_received[min_SER_iter][:,0], validation_received[min_SER_iter][:,1], c=y_valid, cmap='tab20',s=4)
plt.axis('scaled')
plt.xlabel(r'$\Re\{r\}$',fontsize=14)
plt.ylabel(r'$\Im\{r\}$',fontsize=14)
plt.xlim((-ext_max_plot,ext_max_plot))
plt.ylim((-ext_max_plot,ext_max_plot))
plt.title('Received',fontsize=16)

plt.subplot(133)
decision_scatter = np.argmax(decision_region_evolution[min_SER_iter], 1)
plt.scatter(meshgrid[:,0], meshgrid[:,1], c=decision_scatter, cmap=matplotlib.colors.ListedColormap(colors=new_color_list),s=4)
plt.scatter(validation_received[min_SER_iter][0:4000,0], validation_received[min_SER_iter][0:4000,1], c=y_valid[0:4000], cmap='tab20',s=4)
plt.axis('scaled')
plt.xlim((-ext_max_plot,ext_max_plot))
plt.ylim((-ext_max_plot,ext_max_plot))
plt.xlabel(r'$\Re\{r\}$',fontsize=14)
plt.ylabel(r'$\Im\{r\}$',fontsize=14)
plt.title('Decision regions',fontsize=16)

#plt.savefig('decision_region_AE_Pin%d.pdf' % Pin,bbox_inches='tight')


Generate an animation and optionally save it as a GIF.


In [9]:
%matplotlib notebook
# Generate animation
from matplotlib import animation, rc
from matplotlib.animation import PillowWriter # Disable if you don't want to save any GIFs.

font = {'size'   : 18}
plt.rc('font', **font)

fig = plt.figure(figsize=(14,6))
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)

ax1.axis('scaled')
ax2.axis('scaled')

def animate(i):
    ax1.clear()
    ax1.scatter(constellations[i][:,0], constellations[i][:,1], c=range(M), cmap='tab20',s=50)

    ax2.clear()
    decision_scatter = np.argmax(decision_region_evolution[i], 1)
    ax2.scatter(meshgrid[:,0], meshgrid[:,1], c=decision_scatter, cmap=matplotlib.colors.ListedColormap(colors=new_color_list),s=4)
    ax2.scatter(validation_received[i][0:4000,0], validation_received[i][0:4000,1], c=y_valid[0:4000], cmap='tab20',s=4)
    
    ax1.set_xlim(( -ext_max_plot, ext_max_plot))
    ax1.set_ylim(( -ext_max_plot, ext_max_plot))
    ax2.set_xlim(( -ext_max_plot, ext_max_plot))
    ax2.set_ylim(( -ext_max_plot, ext_max_plot))
    ax1.set_title('Constellation', fontsize=14)
    ax2.set_title('Decision regions', fontsize=14)
    
    ax1.set_xlabel(r'$\Re\{r\}$',fontsize=14)
    ax1.set_ylabel(r'$\Im\{r\}$',fontsize=14)
    ax2.set_xlabel(r'$\Re\{r\}$',fontsize=14)
    ax2.set_ylabel(r'$\Im\{r\}$',fontsize=14)

    
anim = animation.FuncAnimation(fig, animate, frames=min_SER_iter+1, interval=200, blit=False)
fig.show()
#anim.save('learning_decision_AE_Pin%d_varbatch.gif' % Pin, writer=PillowWriter(fps=5))


