BPSK Demodulation in Nonlinear Channels with Deep Neural Networks

This code is provided as supplementary material for the lecture Machine Learning and Optimization in Communications (MLOC).

This code illustrates:

  • demodulation of BPSK symbols in highly nonlinear channels using an artificial neural network

In [1]:
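# the graph/session code below uses the TensorFlow 1.x API; tf.compat.v1 also provides it in TensorFlow 2.x
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()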
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive
import ipywidgets as widgets
%matplotlib inline

Specify the parameters of the transmission: the fiber length $L$ (in km), the fiber nonlinearity coefficient $\gamma$ (in 1/W/km) and the total noise power $P_n$ (in dBm; the noise is due to amplified spontaneous emission in the amplifiers along the link). We assume a model of a dispersion-less fiber affected by nonlinearity. The model, which is described for instance in [1], is given by an iterative application of the equation $$ x_{k+1} = x_k\exp\left(\jmath\frac{L}{K}\gamma|x_k|^2\right) + n_{k+1},\qquad 0 \leq k < K, $$ where $x_0$ is the channel input (the modulated, complex symbols), $x_K$ is the channel output, and $n_{k+1}$ is circularly symmetric complex Gaussian noise of power $P_n/K$, so that the $K$ steps together add the total noise power $P_n$. $K$ denotes the number of steps taken to simulate the channel; usually $K=50$ gives a good approximation.
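Without noise, every step of the model only changes the phase of $x_k$ (the exponential has unit modulus), so $|x_k| = |x_0|$ and the $K$ steps collapse into a single rotation by $L\gamma|x_0|^2$; it is the noise, which changes $|x_k|$ from step to step, that makes the step-wise simulation necessary. A minimal NumPy check of this, using the parameter values of the next cell and an example input power of 3 dBm (the variable names are illustrative):

```python
import numpy as np

# Parameter values from the notebook (L in km, gamma in 1/W/km); 3 dBm is an example input power
L_km = 5000
gamma = 1.27
K = 50
P_in_W = 10 ** ((3 - 30) / 10)  # 3 dBm converted to W

# the two BPSK points +sqrt(P_in) and -sqrt(P_in)
x0 = np.array([np.sqrt(P_in_W), -np.sqrt(P_in_W)], dtype=complex)

# iterate the step equation with the noise term n_{k+1} set to zero
x = x0.copy()
for _ in range(K):
    x = x * np.exp(1j * (L_km / K) * gamma * np.abs(x) ** 2)

# closed form: |x_k| stays constant, so the K rotations add up to L*gamma*|x_0|^2
total_rotation = L_km * gamma * P_in_W
x_closed = x0 * np.exp(1j * total_rotation)

print(np.allclose(x, x_closed))  # True
print(f"total nonlinear rotation: {total_rotation:.2f} rad")
```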

[1] S. Li, C. Häger, N. Garcia, and H. Wymeersch, "Achievable Information Rates for Nonlinear Fiber Communication via End-to-end Autoencoder Learning," Proc. ECOC, Rome, Sep. 2018


In [2]:
# Length of transmission (in km)
L = 5000

# fiber nonlinearity coefficient (in 1/W/km)
gamma = 1.27

Pn = -21.3 # total noise power (in dBm)

Kstep = 50 # number of steps used in the channel model

def simulate_channel(x, Pin):
    # BPSK modulation: bit 0 -> +sqrt(P_in), bit 1 -> -sqrt(P_in), with P_in converted from dBm to W
    input_power_linear = 10**((Pin-30)/10)
    norm_factor = np.sqrt(input_power_linear)
    bpsk = (1 - 2*x) * norm_factor

    # standard deviation of the real and imaginary noise parts per step (noise power Pn/Kstep per step)
    sigma = np.sqrt((10**((Pn-30)/10)) / Kstep / 2)

    temp = np.array(bpsk, copy=True)
    for i in range(Kstep):
        # nonlinear phase rotation proportional to the instantaneous power
        power = np.absolute(temp)**2
        rotcoff = (L / Kstep) * gamma * power

        # rotate and add complex Gaussian noise
        temp = temp * np.exp(1j*rotcoff) + sigma*(np.random.randn(len(x)) + 1j*np.random.randn(len(x)))
    return temp
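The noise in `simulate_channel` has standard deviation `sigma` per real and imaginary part, i.e., power $2\sigma^2 = P_n/K$ per step, so the $K$ steps add roughly $P_n$ in total (the rotations do not change the magnitude of the noise already accumulated). A sketch that checks this at very low input power, where the nonlinear rotation is small; `simulate_channel_param` is a hypothetical, self-contained variant of the function above with explicit parameters and a seedable random generator:

```python
import numpy as np

def simulate_channel_param(bits, Pin_dBm, L_km=5000, gamma=1.27, Pn_dBm=-21.3, K=50, rng=None):
    """Hypothetical variant of simulate_channel with explicit parameters."""
    rng = np.random.default_rng() if rng is None else rng
    # BPSK as in the notebook: bit 0 -> +sqrt(P_in), bit 1 -> -sqrt(P_in), P_in in W
    x = ((1 - 2 * np.asarray(bits)) * np.sqrt(10 ** ((Pin_dBm - 30) / 10))).astype(complex)
    # per-step standard deviation of the real and imaginary noise parts
    sigma = np.sqrt(10 ** ((Pn_dBm - 30) / 10) / K / 2)
    for _ in range(K):
        x = x * np.exp(1j * (L_km / K) * gamma * np.abs(x) ** 2)
        x = x + sigma * (rng.standard_normal(x.shape) + 1j * rng.standard_normal(x.shape))
    return x

rng = np.random.default_rng(0)
bits = rng.integers(0, 2, size=100_000)

# At -40 dBm input power the nonlinear rotation is small, so r - x_0 is essentially the accumulated noise
Pin_low = -40.0
r = simulate_channel_param(bits, Pin_low, rng=rng)
x0 = (1 - 2 * bits) * np.sqrt(10 ** ((Pin_low - 30) / 10))
noise_power_dBm = 10 * np.log10(np.mean(np.abs(r - x0) ** 2)) + 30
print(f"measured noise power: {noise_power_dBm:.2f} dBm (P_n = -21.3 dBm)")
```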

We consider BPSK transmission over this channel.

The following cell shows the received constellation as a function of the fiber input power $P_{in}$. When the input power is small, the nonlinear phase rotation $\frac{L}{K}\gamma|x_k|^2$ per step is close to zero, so the effect of the nonlinearity is small and the transmission is dominated by the additive noise. As the input power increases, the noise (whose power is constant) becomes less significant relative to the signal, but the nonlinearity rotates the constellation more strongly. Because the noise added in each step changes the amplitude $|x_k|$, the received samples are rotated by different angles and are smeared along curved, spiral-like arcs.
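Since both BPSK points have power $P_{in}$, the noise-free rotation accumulated over the link is $L\gamma P_{in}$ (with $P_{in}$ in W), and the SNR at the receiver is about $P_{in} - P_n$ in dB. The sketch below tabulates both for a few input powers, using the notebook's parameter values, to give "small" and "large" input power a scale. A rotation common to both points could be undone by a fixed derotation; the distortion comes from the spread of rotation angles caused by the noise.

```python
import numpy as np

L_km, gamma, Pn_dBm = 5000, 1.27, -21.3  # values from the notebook

def total_rotation_rad(Pin_dBm):
    """Noise-free nonlinear rotation L*gamma*P_in (in rad) of a symbol with power P_in."""
    return L_km * gamma * 10 ** ((Pin_dBm - 30) / 10)

for Pin in [-20, -10, 0, 3, 10]:
    snr_dB = Pin - Pn_dBm  # signal power over total noise power, in dB
    print(f"Pin = {Pin:4d} dBm: SNR = {snr_dB:5.1f} dB, rotation = {total_rotation_rad(Pin):6.2f} rad")
```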


In [14]:
length = 5000 # number of symbols shown in the scatter plot

def plot_constellation(Pin):
    t = np.random.randint(2,size=length)
    r = simulate_channel(t, Pin)

    plt.figure(figsize=(6,6))
    font = {'size': 14}
    plt.rc('font', **font)
    plt.rc('text', usetex=True)  # LaTeX text rendering (requires a LaTeX installation)
    plt.scatter(np.real(r), np.imag(r), c=t, cmap='coolwarm')
    plt.xlabel(r'$\Re\{r\}$',fontsize=14)
    plt.ylabel(r'$\Im\{r\}$',fontsize=14)
    plt.axis('equal')
    # raw string so that the LaTeX spacing command \, is not treated as a Python escape sequence
    plt.title(r'Received constellation (L = %d km, $P_{in} = %1.2f$\,dBm)' % (L, Pin))
    plt.savefig('bpsk_received_zd_%1.2f.pdf' % Pin,bbox_inches='tight')

interactive_update = interactive(
    plot_constellation,
    Pin=widgets.FloatSlider(min=-10.0, max=10.0, step=0.1, value=1, continuous_update=False,
                            description='Input Power Pin (dBm)',
                            style={'description_width': 'initial'},
                            layout=widgets.Layout(width='50%')))


output = interactive_update.children[-1]
output.layout.height = '500px'
interactive_update


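The trained network outputs a logit $f_{\boldsymbol{\theta}}(\boldsymbol{r})$ for the received sample $\boldsymbol{r} = (\Re\{r\}, \Im\{r\})^\mathrm{T}$, and a bit is decided as "1" if $\sigma(f_{\boldsymbol{\theta}}(\boldsymbol{r})) \geq \frac12$, i.e., if $f_{\boldsymbol{\theta}}(\boldsymbol{r}) \geq 0$, where $\sigma(\cdot)$ denotes the logistic (sigmoid) function. For a linear classifier without bias, $f_{\boldsymbol{\theta}}(\boldsymbol{r}) = \boldsymbol{\theta}^\mathrm{T}\boldsymbol{r}$, the decision boundary would be the straight line $\theta_1\Re\{r\} + \theta_2\Im\{r\} = 0$, i.e., $\Im\{r\} = -\frac{\theta_1}{\theta_2}\Re\{r\}$. Such a line through the origin cannot follow curved clusters produced by the nonlinearity, whereas the neural network can learn a nonlinear decision boundary.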
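Because the sigmoid $\sigma(z) = 1/(1+e^{-z})$ is strictly increasing with $\sigma(0) = \frac12$, thresholding its output at $\frac12$ gives the same decisions as thresholding the logit at 0. A quick NumPy check, together with the linear special case in which points on the line $\Im\{r\} = -\frac{\theta_1}{\theta_2}\Re\{r\}$ have $\boldsymbol{\theta}^\mathrm{T}\boldsymbol{r} = 0$ (the values of `theta` are arbitrary examples):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
logits = rng.normal(scale=5.0, size=10_000)

# thresholding the probability at 1/2 gives the same decisions as thresholding the logit at 0
same_decisions = np.array_equal(sigmoid(logits) >= 0.5, logits >= 0)

# linear special case with example weights theta: points on the decision line give theta^T r = 0
theta = np.array([0.7, -1.3])
re = np.linspace(-1, 1, 5)
im = -theta[0] / theta[1] * re
on_line = np.allclose(theta[0] * re + theta[1] * im, 0)

print(same_decisions, on_line)
```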

Generate the validation data set. Training mini-batches are generated on the fly during training, and no separate test set is used.


In [15]:
# helper function to compute the bit error rate
def BER(predictions, labels):
    decision = predictions >= 0.5
    return np.mean(decision != labels)
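A small usage example of the `BER` helper with made-up probabilities; the helper is repeated so the snippet runs on its own. The boolean decisions can be compared directly with the integer labels because NumPy treats `True`/`False` as 1/0 in the comparison.

```python
import numpy as np

def BER(predictions, labels):
    # same helper as above: decide "1" when the predicted probability is at least 1/2
    decision = predictions >= 0.5
    return np.mean(decision != labels)

probs = np.array([0.9, 0.2, 0.6, 0.4])
labels = np.array([1, 0, 0, 0])
print(BER(probs, labels))  # 0.25 (only the third symbol is decided wrongly)
```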

In [16]:
# set input power (in dBm)
Pin = 3

# size of the validation set. Training examples are generated on the fly
N_valid = 100000

# mini-batch size
batch_size = 1000

# number of neurons in the two hidden layers
hidden_neurons_1 = 8
hidden_neurons_2 = 14


y_valid = np.random.randint(2,size=N_valid)
r = simulate_channel(y_valid, Pin)

# find the extent of the data (for normalization and plotting); the factor 1.2 adds a margin
ext_x = max(abs(np.real(r)))
ext_y = max(abs(np.imag(r)))
ext_max = max(ext_x,ext_y)*1.2

# scale the data so that both components lie within [-1, 1]
X_valid = np.column_stack((np.real(r), np.imag(r))) / ext_max


# meshgrid for plotting (flattened to a list of (Re, Im) points and scaled like the data)
mgx,mgy = np.meshgrid(np.linspace(-ext_max,ext_max,200), np.linspace(-ext_max,ext_max,200))
meshgrid = np.column_stack((np.reshape(mgx,(-1,1)),np.reshape(mgy,(-1,1)))) / ext_max
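The grid points are flattened in row-major order, so network outputs computed point by point can be reshaped back to `mgx.shape` and then line up with `mgx` and `mgy` element by element. A small check with a stand-in function in place of the network, on a non-square grid so that an accidental transposition would be visible (the grid sizes and the stand-in function are arbitrary):

```python
import numpy as np

# small example grid (the notebook uses 200 x 200 points and scales by ext_max)
xs = np.linspace(-1, 1, 4)
ys = np.linspace(-2, 2, 3)
mgx, mgy = np.meshgrid(xs, ys)  # default 'xy' indexing: shape (len(ys), len(xs))
points = np.column_stack((np.reshape(mgx, (-1, 1)), np.reshape(mgy, (-1, 1))))

# stand-in for the network output: any function of (Re, Im) evaluated point by point
values = points[:, 0] + 10 * points[:, 1]
Z = values.reshape(mgx.shape)

# Z[i, j] belongs to the grid point (mgx[i, j], mgy[i, j]), as contourf(mgx, mgy, Z) expects
print(Z.shape, np.allclose(Z, mgx + 10 * mgy))
```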

In [17]:
# generate graph
graph = tf.Graph()

with graph.as_default():    
    # placeholder for training data (passed from externally)
    tf_train_dataset = tf.placeholder(tf.float32, shape=(batch_size,2))
    tf_train_labels = tf.placeholder(tf.float32, shape=(batch_size))
    
    # the validation dataset
    tf_valid_dataset = tf.constant(X_valid, dtype=tf.float32)
    tf_valid_labels = tf.constant(y_valid, dtype=tf.float32)

    # the meshgrid for plotting the decision region
    tf_meshgrid = tf.constant(meshgrid, dtype=tf.float32)
    
    # define the neural network by hand
    # 2 hidden layers with ELU activation functions; the logistic (sigmoid) output function is applied separately
    W1 = tf.Variable(tf.truncated_normal([2,hidden_neurons_1], stddev=0.8)) 
    b1 = tf.Variable(tf.truncated_normal([hidden_neurons_1], stddev=0.8))
    
    W2 = tf.Variable(tf.truncated_normal([hidden_neurons_1,hidden_neurons_2], stddev=0.8))
    b2 = tf.Variable(tf.truncated_normal([hidden_neurons_2], stddev=0.8))
    
    W3 = tf.Variable(tf.truncated_normal([hidden_neurons_2,1], stddev=0.8))
    b3 = tf.Variable(tf.truncated_normal([1], stddev=0.8))
    
    def neural_network(inp):
        temp1 = tf.nn.elu(tf.matmul(inp, W1)+b1)
        temp2 = tf.nn.elu(tf.matmul(temp1, W2)+b2)
        # the sigmoid is not part of the network itself but is applied inside the loss function below,
        # which combines sigmoid and cross-entropy in a numerically stable way
        # squeeze removes the trailing dimension of size 1, so the output has shape (batch,)
        return tf.squeeze(tf.matmul(temp2, W3)+b3)
    
    #output of the neural network
    logits = neural_network(tf_train_dataset)
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))    

    # equivalent (but numerically less stable) formulation without TensorFlow's built-in function
    #loss = tf.reduce_mean(-tf_train_labels * tf.log(tf.sigmoid(logits)) - (1-tf_train_labels)*tf.log(1-tf.sigmoid(logits)))
    
    # use the Adam optimizer with its default parameters
    optimizer = tf.train.AdamOptimizer().minimize(loss)
    
    # predictions (probabilities) for the validation data
    valid_prediction = tf.nn.sigmoid(neural_network(tf_valid_dataset))
        
    # mesh prediction for plotting
    mesh_prediction = tf.nn.sigmoid(neural_network(tf_meshgrid))
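For checking shapes and the loss outside TensorFlow, a NumPy sketch of the same forward pass (two ELU hidden layers with 8 and 14 neurons, linear output producing a logit) and of the numerically stable logistic loss given in the documentation of `tf.nn.sigmoid_cross_entropy_with_logits`. The initialisation here uses plain (not truncated) normal draws, and all function names are hypothetical:

```python
import numpy as np

def elu(z):
    # ELU with alpha = 1: z for z > 0, exp(z) - 1 otherwise
    return np.where(z > 0, z, np.expm1(np.minimum(z, 0)))

def init_params(rng, n1=8, n2=14, std=0.8):
    # hypothetical initialisation: plain normal draws instead of a truncated normal
    shapes = {"W1": (2, n1), "b1": (n1,), "W2": (n1, n2), "b2": (n2,), "W3": (n2, 1), "b3": (1,)}
    return {k: rng.normal(scale=std, size=s) for k, s in shapes.items()}

def forward(params, inp):
    """Logits of the 2 -> 8 -> 14 -> 1 network (no sigmoid at the output)."""
    h1 = elu(inp @ params["W1"] + params["b1"])
    h2 = elu(h1 @ params["W2"] + params["b2"])
    return np.squeeze(h2 @ params["W3"] + params["b3"], axis=-1)

def bce_with_logits(labels, logits):
    # numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|))
    return np.mean(np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits))))

def bce_naive(labels, logits):
    # direct formula with the sigmoid, as in the commented-out TensorFlow line
    p = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-labels * np.log(p) - (1 - labels) * np.log(1 - p))

rng = np.random.default_rng(0)
params = init_params(rng)
x = rng.uniform(-1, 1, size=(1000, 2))  # inputs scaled to [-1, 1] like the notebook's data
y = rng.integers(0, 2, size=1000).astype(float)

z = forward(params, x)
n_params = sum(p.size for p in params.values())
print(z.shape, n_params)
print(bce_with_logits(y, z), bce_naive(y, z))
```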

In [21]:
num_epochs = 300
batches_per_epoch = 150

validation_BERs = np.zeros(num_epochs)
decision_region_evolution = []

with tf.Session(graph=graph) as session:
    # initialize variables
    tf.global_variables_initializer().run()

    print('Initialized')
    for epoch in range(num_epochs):
        for step in range(batches_per_epoch):
            # sample a new mini-batch (fresh channel realizations for every step)
            batch_labels = np.random.randint(2,size=batch_size)
            r = simulate_channel(batch_labels, Pin)
            batch_data = np.column_stack((np.real(r), np.imag(r))) / ext_max
                        
            feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels }
            
            # run an optimization step
            _ = session.run(optimizer, feed_dict=feed_dict)
        
        # compute validation BER
        valid_out = valid_prediction.eval()
        validation_BERs[epoch] = BER(valid_out, y_valid)
        print('Validation BER after epoch %d: %f' % (epoch, validation_BERs[epoch]))                
        
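        # store the decision region after each epoch (e.g., for animating its evolution);
        # the probabilities in [0, 1] are mapped to [0.4, 0.595] so that, with vmin=0.3 and
        # vmax=0.695 in the plot below, only the lighter central part of the colormap is used
        decision_region_evolution.append(0.195*mesh_prediction.eval() + 0.4)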


Initialized
Validation BER after epoch 0: 0.023680
Validation BER after epoch 1: 0.014440
Validation BER after epoch 2: 0.013160
Validation BER after epoch 3: 0.012750
Validation BER after epoch 4: 0.012430
Validation BER after epoch 5: 0.012450
Validation BER after epoch 6: 0.012410
Validation BER after epoch 7: 0.012320
Validation BER after epoch 8: 0.012310
Validation BER after epoch 9: 0.012270
Validation BER after epoch 10: 0.012200
Validation BER after epoch 11: 0.011870
Validation BER after epoch 12: 0.012120
Validation BER after epoch 13: 0.012260
Validation BER after epoch 14: 0.012120
Validation BER after epoch 15: 0.012030
Validation BER after epoch 16: 0.012220
Validation BER after epoch 17: 0.012150
Validation BER after epoch 18: 0.011890
Validation BER after epoch 19: 0.011760
Validation BER after epoch 20: 0.011860
Validation BER after epoch 21: 0.011820
Validation BER after epoch 22: 0.011800
Validation BER after epoch 23: 0.011900
Validation BER after epoch 24: 0.011880
Validation BER after epoch 25: 0.011780
Validation BER after epoch 26: 0.012050
Validation BER after epoch 27: 0.011810
Validation BER after epoch 28: 0.011620
Validation BER after epoch 29: 0.011960
Validation BER after epoch 30: 0.011780
Validation BER after epoch 31: 0.011760
Validation BER after epoch 32: 0.011650
Validation BER after epoch 33: 0.011310
Validation BER after epoch 34: 0.011370
Validation BER after epoch 35: 0.011490
Validation BER after epoch 36: 0.011360
Validation BER after epoch 37: 0.011320
Validation BER after epoch 38: 0.011270
Validation BER after epoch 39: 0.010960
Validation BER after epoch 40: 0.010870
Validation BER after epoch 41: 0.010480
Validation BER after epoch 42: 0.010390
Validation BER after epoch 43: 0.010070
Validation BER after epoch 44: 0.009570
Validation BER after epoch 45: 0.009200
Validation BER after epoch 46: 0.008920
Validation BER after epoch 47: 0.008640
Validation BER after epoch 48: 0.008030
Validation BER after epoch 49: 0.008130
Validation BER after epoch 50: 0.007360
Validation BER after epoch 51: 0.007210
Validation BER after epoch 52: 0.007120
Validation BER after epoch 53: 0.007000
Validation BER after epoch 54: 0.006620
Validation BER after epoch 55: 0.006310
Validation BER after epoch 56: 0.006210
Validation BER after epoch 57: 0.005750
Validation BER after epoch 58: 0.005620
Validation BER after epoch 59: 0.005210
Validation BER after epoch 60: 0.004870
Validation BER after epoch 61: 0.004420
Validation BER after epoch 62: 0.003720
Validation BER after epoch 63: 0.003320
Validation BER after epoch 64: 0.002990
Validation BER after epoch 65: 0.002490
Validation BER after epoch 66: 0.002140
Validation BER after epoch 67: 0.001840
Validation BER after epoch 68: 0.001610
Validation BER after epoch 69: 0.001480
Validation BER after epoch 70: 0.001290
Validation BER after epoch 71: 0.001130
Validation BER after epoch 72: 0.001010
Validation BER after epoch 73: 0.000990
Validation BER after epoch 74: 0.000840
Validation BER after epoch 75: 0.000760
Validation BER after epoch 76: 0.000720
Validation BER after epoch 77: 0.000640
Validation BER after epoch 78: 0.000560
Validation BER after epoch 79: 0.000620
Validation BER after epoch 80: 0.000520
Validation BER after epoch 81: 0.000480
Validation BER after epoch 82: 0.000410
Validation BER after epoch 83: 0.000370
Validation BER after epoch 84: 0.000380
Validation BER after epoch 85: 0.000350
Validation BER after epoch 86: 0.000290
Validation BER after epoch 87: 0.000320
Validation BER after epoch 88: 0.000260
Validation BER after epoch 89: 0.000230
Validation BER after epoch 90: 0.000260
Validation BER after epoch 91: 0.000190
Validation BER after epoch 92: 0.000220
Validation BER after epoch 93: 0.000200
Validation BER after epoch 94: 0.000160
Validation BER after epoch 95: 0.000160
Validation BER after epoch 96: 0.000140
Validation BER after epoch 97: 0.000180
Validation BER after epoch 98: 0.000160
Validation BER after epoch 99: 0.000130
Validation BER after epoch 100: 0.000150
Validation BER after epoch 101: 0.000120
Validation BER after epoch 102: 0.000130
Validation BER after epoch 103: 0.000130
Validation BER after epoch 104: 0.000160
Validation BER after epoch 105: 0.000110
Validation BER after epoch 106: 0.000120
Validation BER after epoch 107: 0.000120
Validation BER after epoch 108: 0.000120
Validation BER after epoch 109: 0.000120
Validation BER after epoch 110: 0.000090
Validation BER after epoch 111: 0.000110
Validation BER after epoch 112: 0.000100
Validation BER after epoch 113: 0.000080
Validation BER after epoch 114: 0.000110
Validation BER after epoch 115: 0.000080
Validation BER after epoch 116: 0.000100
Validation BER after epoch 117: 0.000100
Validation BER after epoch 118: 0.000100
Validation BER after epoch 119: 0.000100
Validation BER after epoch 120: 0.000080
Validation BER after epoch 121: 0.000100
Validation BER after epoch 122: 0.000090
Validation BER after epoch 123: 0.000060
Validation BER after epoch 124: 0.000090
Validation BER after epoch 125: 0.000060
Validation BER after epoch 126: 0.000060
Validation BER after epoch 127: 0.000070
Validation BER after epoch 128: 0.000090
Validation BER after epoch 129: 0.000060
Validation BER after epoch 130: 0.000050
Validation BER after epoch 131: 0.000080
Validation BER after epoch 132: 0.000080
Validation BER after epoch 133: 0.000050
Validation BER after epoch 134: 0.000070
Validation BER after epoch 135: 0.000080
Validation BER after epoch 136: 0.000070
Validation BER after epoch 137: 0.000060
Validation BER after epoch 138: 0.000060
Validation BER after epoch 139: 0.000070
Validation BER after epoch 140: 0.000060
Validation BER after epoch 141: 0.000060
Validation BER after epoch 142: 0.000050
Validation BER after epoch 143: 0.000050
Validation BER after epoch 144: 0.000040
Validation BER after epoch 145: 0.000040
Validation BER after epoch 146: 0.000040
Validation BER after epoch 147: 0.000060
Validation BER after epoch 148: 0.000050
Validation BER after epoch 149: 0.000050
Validation BER after epoch 150: 0.000040
Validation BER after epoch 151: 0.000070
Validation BER after epoch 152: 0.000040
Validation BER after epoch 153: 0.000040
Validation BER after epoch 154: 0.000040
Validation BER after epoch 155: 0.000070
Validation BER after epoch 156: 0.000060
Validation BER after epoch 157: 0.000030
Validation BER after epoch 158: 0.000040
Validation BER after epoch 159: 0.000040
Validation BER after epoch 160: 0.000030
Validation BER after epoch 161: 0.000050
Validation BER after epoch 162: 0.000040
Validation BER after epoch 163: 0.000040
Validation BER after epoch 164: 0.000050
Validation BER after epoch 165: 0.000060
Validation BER after epoch 166: 0.000050
Validation BER after epoch 167: 0.000040
Validation BER after epoch 168: 0.000040
Validation BER after epoch 169: 0.000040
Validation BER after epoch 170: 0.000050
Validation BER after epoch 171: 0.000060
Validation BER after epoch 172: 0.000040
Validation BER after epoch 173: 0.000030
Validation BER after epoch 174: 0.000030
Validation BER after epoch 175: 0.000030
Validation BER after epoch 176: 0.000060
Validation BER after epoch 177: 0.000030
Validation BER after epoch 178: 0.000040
Validation BER after epoch 179: 0.000030
Validation BER after epoch 180: 0.000030
Validation BER after epoch 181: 0.000020
Validation BER after epoch 182: 0.000030
Validation BER after epoch 183: 0.000020
Validation BER after epoch 184: 0.000030
Validation BER after epoch 185: 0.000020
Validation BER after epoch 186: 0.000030
Validation BER after epoch 187: 0.000020
Validation BER after epoch 188: 0.000030
Validation BER after epoch 189: 0.000030
Validation BER after epoch 190: 0.000020
Validation BER after epoch 191: 0.000030
Validation BER after epoch 192: 0.000020
Validation BER after epoch 193: 0.000020
Validation BER after epoch 194: 0.000020
Validation BER after epoch 195: 0.000030
Validation BER after epoch 196: 0.000030
Validation BER after epoch 197: 0.000040
Validation BER after epoch 198: 0.000020
Validation BER after epoch 199: 0.000020
Validation BER after epoch 200: 0.000020
Validation BER after epoch 201: 0.000020
Validation BER after epoch 202: 0.000020
Validation BER after epoch 203: 0.000020
Validation BER after epoch 204: 0.000020
Validation BER after epoch 205: 0.000020
Validation BER after epoch 206: 0.000020
Validation BER after epoch 207: 0.000020
Validation BER after epoch 208: 0.000020
Validation BER after epoch 209: 0.000020
Validation BER after epoch 210: 0.000020
Validation BER after epoch 211: 0.000020
Validation BER after epoch 212: 0.000020
Validation BER after epoch 213: 0.000020
Validation BER after epoch 214: 0.000030
Validation BER after epoch 215: 0.000020
Validation BER after epoch 216: 0.000020
Validation BER after epoch 217: 0.000020
Validation BER after epoch 218: 0.000020
Validation BER after epoch 219: 0.000020
Validation BER after epoch 220: 0.000020
Validation BER after epoch 221: 0.000020
Validation BER after epoch 222: 0.000030
Validation BER after epoch 223: 0.000020
Validation BER after epoch 224: 0.000020
Validation BER after epoch 225: 0.000020
Validation BER after epoch 226: 0.000020
Validation BER after epoch 227: 0.000020
Validation BER after epoch 228: 0.000020
Validation BER after epoch 229: 0.000020
Validation BER after epoch 230: 0.000020
Validation BER after epoch 231: 0.000020
Validation BER after epoch 232: 0.000020
Validation BER after epoch 233: 0.000020
Validation BER after epoch 234: 0.000020
Validation BER after epoch 235: 0.000020
Validation BER after epoch 236: 0.000020
Validation BER after epoch 237: 0.000020
Validation BER after epoch 238: 0.000020
Validation BER after epoch 239: 0.000020
Validation BER after epoch 240: 0.000030
Validation BER after epoch 241: 0.000020
Validation BER after epoch 242: 0.000020
Validation BER after epoch 243: 0.000020
Validation BER after epoch 244: 0.000020
Validation BER after epoch 245: 0.000020
Validation BER after epoch 246: 0.000020
Validation BER after epoch 247: 0.000020
Validation BER after epoch 248: 0.000030
Validation BER after epoch 249: 0.000020
Validation BER after epoch 250: 0.000020
Validation BER after epoch 251: 0.000020
Validation BER after epoch 252: 0.000020
Validation BER after epoch 253: 0.000020
Validation BER after epoch 254: 0.000020
Validation BER after epoch 255: 0.000020
Validation BER after epoch 256: 0.000020
Validation BER after epoch 257: 0.000020
Validation BER after epoch 258: 0.000030
Validation BER after epoch 259: 0.000020
Validation BER after epoch 260: 0.000020
Validation BER after epoch 261: 0.000020
Validation BER after epoch 262: 0.000020
Validation BER after epoch 263: 0.000020
Validation BER after epoch 264: 0.000020
Validation BER after epoch 265: 0.000020
Validation BER after epoch 266: 0.000020
Validation BER after epoch 267: 0.000030
Validation BER after epoch 268: 0.000020
Validation BER after epoch 269: 0.000020
Validation BER after epoch 270: 0.000020
Validation BER after epoch 271: 0.000020
Validation BER after epoch 272: 0.000020
Validation BER after epoch 273: 0.000020
Validation BER after epoch 274: 0.000020
Validation BER after epoch 275: 0.000010
Validation BER after epoch 276: 0.000020
Validation BER after epoch 277: 0.000020
Validation BER after epoch 278: 0.000040
Validation BER after epoch 279: 0.000030
Validation BER after epoch 280: 0.000040
Validation BER after epoch 281: 0.000030
Validation BER after epoch 282: 0.000020
Validation BER after epoch 283: 0.000020
Validation BER after epoch 284: 0.000020
Validation BER after epoch 285: 0.000020
Validation BER after epoch 286: 0.000020
Validation BER after epoch 287: 0.000020
Validation BER after epoch 288: 0.000020
Validation BER after epoch 289: 0.000020
Validation BER after epoch 290: 0.000020
Validation BER after epoch 291: 0.000020
Validation BER after epoch 292: 0.000020
Validation BER after epoch 293: 0.000020
Validation BER after epoch 294: 0.000020
Validation BER after epoch 295: 0.000020
Validation BER after epoch 296: 0.000020
Validation BER after epoch 297: 0.000030
Validation BER after epoch 298: 0.000020
Validation BER after epoch 299: 0.000020
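With N_valid = 100000 validation symbols, the BER can only take multiples of $10^{-5}$, and the late-epoch values of $2\cdot 10^{-5}$ correspond to just two bit errors. Differences of one or two errors between epochs are therefore too small to distinguish the underlying error probabilities reliably. A sketch using the Wilson score interval, a standard binomial confidence interval chosen here as an assumption (the notebook itself does not compute one):

```python
import math

def wilson_interval(errors, n, z=1.96):
    """Approximate 95% Wilson score interval for an error probability estimated from errors/n."""
    p = errors / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z / denom * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return center - half, center + half

N_valid = 100_000
for errors in [1, 2, 3]:
    lo, hi = wilson_interval(errors, N_valid)
    print(f"BER = {errors / N_valid:.0e}: 95% interval [{lo:.1e}, {hi:.1e}]")
```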

In [22]:
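plt.figure(figsize=(8,8))
# background: learned decision region (the stored values were already rescaled for the colormap);
# the predictions were reshaped to mgx.shape, so they line up with mgx and mgy without transposing
plt.contourf(mgx,mgy,decision_region_evolution[-1].reshape(mgy.shape),cmap='coolwarm',vmin=0.3,vmax=0.695)
# foreground: validation samples, scaled back to the original amplitude
plt.scatter(X_valid[:,0]*ext_max, X_valid[:,1] * ext_max, c=y_valid, cmap='coolwarm')
plt.axis('scaled')
plt.xlabel(r'$\Re\{r\}$',fontsize=16)
plt.ylabel(r'$\Im\{r\}$',fontsize=16)
plt.title(r'Learned decision region ($P_{in} = %1.2f$\,dBm)' % Pin, fontsize=16)
#plt.savefig('after_optimization.pdf',bbox_inches='tight')
plt.show()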