BPSK Demodulation in Nonlinear Channels with Deep Neural Networks

This code is provided as supplementary material of the lecture Machine Learning and Optimization in Communications (MLOC).

This code illustrates:

  • demodulation of BPSK symbols in highly nonlinear channels using an artificial neural network, implemented via PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive
import ipywidgets as widgets
%matplotlib inline 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("We are using the following device for learning:",device)


We are using the following device for learning: cuda

Specify the parameters of the transmission: the fiber length $L$ (in km), the fiber nonlinearity coefficient $\gamma$ (in 1/W/km), and the total noise power $P_n$ (in dBm; the noise is due to amplified spontaneous emission in the amplifiers along the link). We assume a model of a dispersion-less fiber affected by nonlinearity. The model, which is described for instance in [1], is given by iterative application of the equation $$ x_{k+1} = x_k\exp\left(\jmath\frac{L}{K}\gamma|x_k|^2\right) + n_{k+1},\qquad 0 \leq k < K $$ where $x_0$ is the channel input (the modulated, complex symbols) and $x_K$ is the channel output. $K$ denotes the number of steps used to simulate the channel; usually $K=50$ gives a good approximation. Powers specified in dBm are converted to linear scale (in W) via $P_\mathrm{lin} = 10^{(P_\mathrm{dBm}-30)/10}$, as done in the code below.

[1] S. Li, C. Häger, N. Garcia, and H. Wymeersch, "Achievable Information Rates for Nonlinear Fiber Communication via End-to-end Autoencoder Learning," Proc. ECOC, Rome, Sep. 2018


In [25]:
# Length of transmission (in km)
L = 5000

# fiber nonlinearity coefficient (in 1/W/km)
gamma = 1.27

Pn = -21.3 # noise power (in dBm)

Kstep = 50 # number of steps used in the channel model

def simulate_channel(x, Pin):  
    # modulate BPSK: bit 0 -> +1, bit 1 -> -1, scaled to the given input power
    input_power_linear = 10**((Pin-30)/10)
    norm_factor = np.sqrt(input_power_linear)
    bpsk = (1 - 2*x) * norm_factor

    # noise standard deviation per step and per real dimension
    sigma = np.sqrt((10**((Pn-30)/10)) / Kstep / 2)    

    temp = np.array(bpsk, copy=True)
    for i in range(Kstep):
        power = np.absolute(temp)**2
        rotcoff = (L / Kstep) * gamma * power
        # apply nonlinear phase rotation, then add circular Gaussian noise
        temp = temp * np.exp(1j*rotcoff) + sigma*(np.random.randn(len(x)) + 1j*np.random.randn(len(x)))
    return temp
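
As a quick sanity check (purely illustrative, not part of the original notebook), we can pass random bits through the channel and verify that the average received power roughly matches the specified input power, since the nonlinearity only rotates the phase and the noise is comparatively weak:

# illustrative sanity check of the channel model
bits = np.random.randint(2, size=10000)
rx = simulate_channel(bits, Pin=3)

# average received power in dBm; should be close to Pin = 3 dBm
P_rx_dBm = 10*np.log10(np.mean(np.abs(rx)**2)) + 30
print('Average received power: %1.2f dBm' % P_rx_dBm)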

We consider BPSK transmission over this channel.

The following interactive plot shows the received constellation as a function of the fiber input power. When the input power is small, the effect of the nonlinearity is negligible (as $\jmath\frac{L}{K}\gamma|x_k|^2 \approx 0$) and the transmission is dominated by the additive noise. As the input power grows, the noise (whose power is constant) becomes less pronounced relative to the signal, but the constellation is increasingly rotated, since the nonlinear phase shift grows with the instantaneous power.


In [26]:
length = 5000

def plot_constellation(Pin):
    t = np.random.randint(2,size=length)
    r = simulate_channel(t, Pin)

    plt.figure(figsize=(6,6))
    font = {'size'   : 14}
    plt.rc('font', **font)
    plt.rc('text', usetex=True)
    plt.scatter(np.real(r), np.imag(r), c=t, cmap='coolwarm')
    plt.xlabel(r'$\Re\{r\}$',fontsize=14)
    plt.ylabel(r'$\Im\{r\}$',fontsize=14)
    plt.axis('equal')
    plt.title('Received constellation (L = %d km, $P_{in} = %1.2f$\,dBm)' % (L, Pin))    
    #plt.savefig('bpsk_received_zd_%1.2f.pdf' % Pin,bbox_inches='tight')
    
interactive_update = interactive(plot_constellation, Pin = widgets.FloatSlider(min=-10.0,max=10.0,step=0.1,value=1, continuous_update=False, description='Input Power Pin (dBm)', style={'description_width': 'initial'}, layout=widgets.Layout(width='50%')))


output = interactive_update.children[-1]
output.layout.height = '500px'
interactive_update


Helper function to plot the constellation together with the decision region of the receiver. Note that a bit is decided as "1" if $\sigma(f_{\boldsymbol{\theta}}(\boldsymbol{r})) > \frac12$, i.e., if the network output (the logit) satisfies $f_{\boldsymbol{\theta}}(\boldsymbol{r}) > 0$, where $\boldsymbol{r} = (\Re\{r\}, \Im\{r\})^\mathrm{T}$. For a single-neuron (logistic regression) receiver with $f_{\boldsymbol{\theta}}(\boldsymbol{r}) = \boldsymbol{\theta}^\mathrm{T}\boldsymbol{r}$, the decision boundary would be the line $\theta_1\Re\{r\} + \theta_2\Im\{r\} = 0$, i.e., $\Im\{r\} = -\frac{\theta_1}{\theta_2}\Re\{r\}$; the neural network used below can instead learn an arbitrarily shaped decision region.
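
The helper function itself is not included in this extract; the following is a minimal sketch of one possible reconstruction, assuming the meshgrid `mgx`, `mgy` and the scaling factor `ext_max` defined in the cells below, and mirroring the plotting cell at the end of the notebook:

def plot_decision_region(mesh_prediction, r, t):
    # hypothetical reconstruction of the missing helper
    # mesh_prediction: sigmoid outputs of the network on the meshgrid points
    # r: complex received symbols, t: transmitted bits
    plt.figure(figsize=(8,8))
    plt.contourf(mgx, mgy, mesh_prediction.reshape(mgx.shape), cmap='coolwarm', vmin=0.3, vmax=0.7)
    plt.scatter(np.real(r), np.imag(r), c=t, cmap='coolwarm')
    plt.axis('scaled')
    plt.xlabel(r'$\Re\{r\}$', fontsize=16)
    plt.ylabel(r'$\Im\{r\}$', fontsize=16)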

Generate the validation data set (training examples are generated on the fly during training, directly on the GPU).


In [4]:
# helper function to compute the bit error rate
def BER(predictions, labels):
    decision = predictions >= 0.5
    temp = decision != (labels != 0)
    return np.mean(temp)
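
As an illustrative check (not part of the original notebook), random predictions against random labels should give a BER close to 0.5:

# random guessing -> BER ~ 0.5
print(BER(np.random.rand(100000), np.random.randint(2, size=100000)))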

In [5]:
# set input power
Pin = 3

# validation set. Training examples are generated on the fly
N_valid = 100000


# number of neurons in the two hidden layers
hidden_neurons_1 = 8
hidden_neurons_2 = 14


y_valid = np.random.randint(2,size=N_valid)
r = simulate_channel(y_valid, Pin)

# find extension of data (for normalization and plotting)
ext_x = max(abs(np.real(r)))
ext_y = max(abs(np.imag(r)))
ext_max = max(ext_x,ext_y)*1.2

# scale data to lie roughly within [-1, 1] (ext_max includes a 20% margin)
X_valid = torch.from_numpy(np.column_stack((np.real(r), np.imag(r))) / ext_max).float().to(device)


# meshgrid for plotting
mgx,mgy = np.meshgrid(np.linspace(-ext_max,ext_max,200), np.linspace(-ext_max,ext_max,200))
meshgrid = torch.from_numpy(np.column_stack((np.reshape(mgx,(-1,1)),np.reshape(mgy,(-1,1)))) / ext_max).float().to(device)

In [7]:
class Receiver_Network(nn.Module):
    def __init__(self, hidden1_neurons, hidden2_neurons):
        super(Receiver_Network, self).__init__()
        # Linear function, 2 input neurons (real and imaginary part)        
        self.fc1 = nn.Linear(2, hidden1_neurons) 

        # Non-linearity
        self.activation_function = nn.ELU()
       
        # Linear function (hidden layer)
        self.fc2 = nn.Linear(hidden1_neurons, hidden2_neurons)  
        
        # Linear function (output layer); single output, interpreted as a logit
        self.fc3 = nn.Linear(hidden2_neurons, 1)
        

    def forward(self, x):
        # Linear function, first layer
        out = self.fc1(x)

        # Non-linearity, first layer
        out = self.activation_function(out)
        
        # Linear function, second layer
        out = self.fc2(out)
        
        # Non-linearity, second layer
        out = self.activation_function(out)
        
        # Linear function, third layer
        out = self.fc3(out)
              
        return out
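
As a quick, purely illustrative check (not part of the original notebook), instantiating the network confirms the 2 → 8 → 14 → 1 architecture and its parameter count:

tmp_model = Receiver_Network(hidden_neurons_1, hidden_neurons_2)
print(tmp_model)
print('Trainable parameters:', sum(p.numel() for p in tmp_model.parameters()))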

In [29]:
model = Receiver_Network(hidden_neurons_1, hidden_neurons_2)
model.to(device)

sigmoid = nn.Sigmoid()


# channel parameters (same dBm-to-linear conversion as in simulate_channel)
norm_factor = np.sqrt(10**((Pin-30)/10))
sigma = np.sqrt((10**((Pn-30)/10)) / Kstep / 2)

# binary cross-entropy loss; BCEWithLogitsLoss applies the sigmoid internally,
# so the network outputs raw logits
loss_fn = nn.BCEWithLogitsLoss()

# Adam Optimizer
optimizer = optim.Adam(model.parameters())  


# Training parameters
num_epochs = 160
batches_per_epoch = 300

# vary the batch size during training: small batches early for fast exploration,
# large batches later for low-variance gradient estimates
batch_size_per_epoch = np.linspace(100,10000,num=num_epochs)


validation_BERs = np.zeros(num_epochs)
decision_region_evolution = []

for epoch in range(num_epochs):
    batch_labels = torch.empty(int(batch_size_per_epoch[epoch]), device=device)
    noise = torch.empty((int(batch_size_per_epoch[epoch]),2), device=device, requires_grad=False)        

    for step in range(batches_per_epoch):
        # sample a new mini-batch directly on the GPU (if available)        
        batch_labels.random_(2)
        # channel simulation directly on the GPU; complex numbers are
        # represented as 2D real tensors (columns: real and imaginary part)
        bpsk = ((1 - 2*batch_labels) * norm_factor).unsqueeze(-1) * torch.tensor([1.0,0.0],device=device)

        for i in range(Kstep):
            power = torch.norm(bpsk, dim=1) ** 2
            rotcoff = (L / Kstep) * gamma * power
            noise.normal_(mean=0, std=sigma) # sample noise
            
            # phase rotation due to nonlinearity: complex multiplication
            # by exp(j*rotcoff), written out in real arithmetic
            temp1 = bpsk[:,0] * torch.cos(rotcoff) - bpsk[:,1] * torch.sin(rotcoff)            
            temp2 = bpsk[:,0] * torch.sin(rotcoff) + bpsk[:,1] * torch.cos(rotcoff)            
            bpsk = torch.stack([temp1, temp2], dim=1) + noise

        # normalize with the same scaling factor as the validation set
        bpsk = bpsk / ext_max
        outputs = model(bpsk)

        # compute loss
        loss = loss_fn(outputs.squeeze(), batch_labels)
        
        # compute gradients
        loss.backward()
        
        optimizer.step()
        # reset gradients
        optimizer.zero_grad()
        
    # compute validation BER
    out_valid = sigmoid(model(X_valid))
    validation_BERs[epoch] = BER(out_valid.detach().cpu().numpy().squeeze(), y_valid)
    
    print('Validation BER after epoch %d: %f (loss %1.8f)' % (epoch, validation_BERs[epoch], loss.detach().cpu().numpy()))                
        
    # store decision region for generating the animation
    # (the affine map squeezes the predictions into [0.4, 0.595] for softer colormap colors)
    mesh_prediction = sigmoid(model(meshgrid))    
    decision_region_evolution.append(0.195*mesh_prediction.detach().cpu().numpy() + 0.4)


Validation BER after epoch 0: 0.011900 (loss 0.04081550)
Validation BER after epoch 1: 0.011880 (loss 0.06237675)
Validation BER after epoch 2: 0.011880 (loss 0.05860055)
Validation BER after epoch 3: 0.011930 (loss 0.03137892)
Validation BER after epoch 4: 0.011840 (loss 0.02328396)
Validation BER after epoch 5: 0.011940 (loss 0.01498417)
Validation BER after epoch 6: 0.011830 (loss 0.03882520)
Validation BER after epoch 7: 0.011960 (loss 0.03213685)
Validation BER after epoch 8: 0.011820 (loss 0.03174237)
Validation BER after epoch 9: 0.011920 (loss 0.02231092)
Validation BER after epoch 10: 0.011950 (loss 0.02508574)
Validation BER after epoch 11: 0.011820 (loss 0.03134822)
Validation BER after epoch 12: 0.011750 (loss 0.03832325)
Validation BER after epoch 13: 0.011900 (loss 0.02572788)
Validation BER after epoch 14: 0.011720 (loss 0.04637164)
Validation BER after epoch 15: 0.011770 (loss 0.02873223)
Validation BER after epoch 16: 0.011630 (loss 0.03887529)
Validation BER after epoch 17: 0.011840 (loss 0.02333896)
Validation BER after epoch 18: 0.011610 (loss 0.03756538)
Validation BER after epoch 19: 0.011410 (loss 0.03302569)
Validation BER after epoch 20: 0.011310 (loss 0.03760448)
Validation BER after epoch 21: 0.010950 (loss 0.02140117)
Validation BER after epoch 22: 0.010600 (loss 0.03338458)
Validation BER after epoch 23: 0.009930 (loss 0.03605200)
Validation BER after epoch 24: 0.009130 (loss 0.02889551)
Validation BER after epoch 25: 0.008240 (loss 0.02857891)
Validation BER after epoch 26: 0.007150 (loss 0.01911181)
Validation BER after epoch 27: 0.006310 (loss 0.02833870)
Validation BER after epoch 28: 0.005170 (loss 0.02069176)
Validation BER after epoch 29: 0.004250 (loss 0.01209422)
Validation BER after epoch 30: 0.003070 (loss 0.01115703)
Validation BER after epoch 31: 0.002530 (loss 0.00821617)
Validation BER after epoch 32: 0.002070 (loss 0.00696379)
Validation BER after epoch 33: 0.001470 (loss 0.00896005)
Validation BER after epoch 34: 0.001170 (loss 0.01304166)
Validation BER after epoch 35: 0.000950 (loss 0.01505199)
Validation BER after epoch 36: 0.000890 (loss 0.01122363)
Validation BER after epoch 37: 0.000830 (loss 0.00267421)
Validation BER after epoch 38: 0.000700 (loss 0.00351502)
Validation BER after epoch 39: 0.000740 (loss 0.00236395)
Validation BER after epoch 40: 0.000670 (loss 0.00355382)
Validation BER after epoch 41: 0.000650 (loss 0.00271674)
Validation BER after epoch 42: 0.000620 (loss 0.00226128)
Validation BER after epoch 43: 0.000570 (loss 0.00334926)
Validation BER after epoch 44: 0.000470 (loss 0.01259388)
Validation BER after epoch 45: 0.000450 (loss 0.00130147)
Validation BER after epoch 46: 0.000390 (loss 0.00182487)
Validation BER after epoch 47: 0.000350 (loss 0.00414351)
Validation BER after epoch 48: 0.000280 (loss 0.00183806)
Validation BER after epoch 49: 0.000270 (loss 0.00908718)
Validation BER after epoch 50: 0.000250 (loss 0.00579314)
Validation BER after epoch 51: 0.000230 (loss 0.00231978)
Validation BER after epoch 52: 0.000210 (loss 0.00087539)
Validation BER after epoch 53: 0.000180 (loss 0.00404366)
Validation BER after epoch 54: 0.000150 (loss 0.00058311)
Validation BER after epoch 55: 0.000150 (loss 0.00061291)
Validation BER after epoch 56: 0.000130 (loss 0.00044722)
Validation BER after epoch 57: 0.000100 (loss 0.00585709)
Validation BER after epoch 58: 0.000090 (loss 0.00067960)
Validation BER after epoch 59: 0.000100 (loss 0.00044337)
Validation BER after epoch 60: 0.000090 (loss 0.00035334)
Validation BER after epoch 61: 0.000070 (loss 0.00148391)
Validation BER after epoch 62: 0.000070 (loss 0.00037881)
Validation BER after epoch 63: 0.000070 (loss 0.00035939)
Validation BER after epoch 64: 0.000070 (loss 0.00143160)
Validation BER after epoch 65: 0.000070 (loss 0.00417593)
Validation BER after epoch 66: 0.000070 (loss 0.00020876)
Validation BER after epoch 67: 0.000070 (loss 0.00017361)
Validation BER after epoch 68: 0.000070 (loss 0.00029100)
Validation BER after epoch 69: 0.000070 (loss 0.00015851)
Validation BER after epoch 70: 0.000070 (loss 0.00011071)
Validation BER after epoch 71: 0.000070 (loss 0.00153271)
Validation BER after epoch 72: 0.000070 (loss 0.00010566)
Validation BER after epoch 73: 0.000060 (loss 0.00013147)
Validation BER after epoch 74: 0.000070 (loss 0.00009230)
Validation BER after epoch 75: 0.000060 (loss 0.00009838)
Validation BER after epoch 76: 0.000060 (loss 0.00010027)
Validation BER after epoch 77: 0.000060 (loss 0.00009298)
Validation BER after epoch 78: 0.000060 (loss 0.00006186)
Validation BER after epoch 79: 0.000060 (loss 0.00017598)
Validation BER after epoch 80: 0.000060 (loss 0.00005090)
Validation BER after epoch 81: 0.000040 (loss 0.00011124)
Validation BER after epoch 82: 0.000060 (loss 0.00006365)
Validation BER after epoch 83: 0.000060 (loss 0.00005719)
Validation BER after epoch 84: 0.000040 (loss 0.00007591)
Validation BER after epoch 85: 0.000060 (loss 0.00005141)
Validation BER after epoch 86: 0.000040 (loss 0.00005167)
Validation BER after epoch 87: 0.000030 (loss 0.00007741)
Validation BER after epoch 88: 0.000040 (loss 0.00010145)
Validation BER after epoch 89: 0.000030 (loss 0.00039115)
Validation BER after epoch 90: 0.000040 (loss 0.00004980)
Validation BER after epoch 91: 0.000030 (loss 0.00006374)
Validation BER after epoch 92: 0.000030 (loss 0.00019472)
Validation BER after epoch 93: 0.000030 (loss 0.00005862)
Validation BER after epoch 94: 0.000030 (loss 0.00005816)
Validation BER after epoch 95: 0.000030 (loss 0.00006141)
Validation BER after epoch 96: 0.000030 (loss 0.00006563)
Validation BER after epoch 97: 0.000030 (loss 0.00003710)
Validation BER after epoch 98: 0.000030 (loss 0.00003480)
Validation BER after epoch 99: 0.000030 (loss 0.00003560)
Validation BER after epoch 100: 0.000030 (loss 0.00002542)
Validation BER after epoch 101: 0.000030 (loss 0.00002261)
Validation BER after epoch 102: 0.000030 (loss 0.00220603)
Validation BER after epoch 103: 0.000020 (loss 0.00003350)
Validation BER after epoch 104: 0.000020 (loss 0.00002150)
Validation BER after epoch 105: 0.000020 (loss 0.00004899)
Validation BER after epoch 106: 0.000030 (loss 0.00003853)
Validation BER after epoch 107: 0.000030 (loss 0.00022100)
Validation BER after epoch 108: 0.000030 (loss 0.00001859)
Validation BER after epoch 109: 0.000030 (loss 0.00001805)
Validation BER after epoch 110: 0.000020 (loss 0.00001448)
Validation BER after epoch 111: 0.000020 (loss 0.00002760)
Validation BER after epoch 112: 0.000030 (loss 0.00109210)
Validation BER after epoch 113: 0.000020 (loss 0.00001632)
Validation BER after epoch 114: 0.000020 (loss 0.00001698)
Validation BER after epoch 115: 0.000030 (loss 0.00001269)
Validation BER after epoch 116: 0.000020 (loss 0.00002444)
Validation BER after epoch 117: 0.000020 (loss 0.00004510)
Validation BER after epoch 118: 0.000020 (loss 0.00002313)
Validation BER after epoch 119: 0.000020 (loss 0.00002384)
Validation BER after epoch 120: 0.000020 (loss 0.00007444)
Validation BER after epoch 121: 0.000020 (loss 0.00001235)
Validation BER after epoch 122: 0.000020 (loss 0.00001242)
Validation BER after epoch 123: 0.000020 (loss 0.00003285)
Validation BER after epoch 124: 0.000020 (loss 0.00001661)
Validation BER after epoch 125: 0.000020 (loss 0.00002253)
Validation BER after epoch 126: 0.000020 (loss 0.00002228)
Validation BER after epoch 127: 0.000020 (loss 0.00001125)
Validation BER after epoch 128: 0.000020 (loss 0.00003357)
Validation BER after epoch 129: 0.000020 (loss 0.00001068)
Validation BER after epoch 130: 0.000020 (loss 0.00003057)
Validation BER after epoch 131: 0.000020 (loss 0.00005582)
Validation BER after epoch 132: 0.000020 (loss 0.00001659)
Validation BER after epoch 133: 0.000020 (loss 0.00001554)
Validation BER after epoch 134: 0.000020 (loss 0.00002069)
Validation BER after epoch 135: 0.000020 (loss 0.00002881)
Validation BER after epoch 136: 0.000020 (loss 0.00000820)
Validation BER after epoch 137: 0.000020 (loss 0.00001532)
Validation BER after epoch 138: 0.000020 (loss 0.00000651)
Validation BER after epoch 139: 0.000020 (loss 0.00003448)
Validation BER after epoch 140: 0.000020 (loss 0.00139898)
Validation BER after epoch 141: 0.000020 (loss 0.00001830)
Validation BER after epoch 142: 0.000020 (loss 0.00000928)
Validation BER after epoch 143: 0.000020 (loss 0.00000588)
Validation BER after epoch 144: 0.000020 (loss 0.00000676)
Validation BER after epoch 145: 0.000020 (loss 0.00000701)
Validation BER after epoch 146: 0.000020 (loss 0.00000655)
Validation BER after epoch 147: 0.000020 (loss 0.00000904)
Validation BER after epoch 148: 0.000020 (loss 0.00001027)
Validation BER after epoch 149: 0.000020 (loss 0.00001276)
Validation BER after epoch 150: 0.000020 (loss 0.00001847)
Validation BER after epoch 151: 0.000020 (loss 0.00000718)
Validation BER after epoch 152: 0.000020 (loss 0.00001634)
Validation BER after epoch 153: 0.000020 (loss 0.00002710)
Validation BER after epoch 154: 0.000020 (loss 0.00000506)
Validation BER after epoch 155: 0.000020 (loss 0.00000749)
Validation BER after epoch 156: 0.000010 (loss 0.00001412)
Validation BER after epoch 157: 0.000020 (loss 0.00002366)
Validation BER after epoch 158: 0.000020 (loss 0.00000768)
Validation BER after epoch 159: 0.000020 (loss 0.00018771)
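
The per-epoch validation BERs are stored in validation_BERs; the following minimal sketch (not part of the original notebook) visualizes the training progress on a logarithmic scale:

# plot validation BER over the epochs (log scale, zero entries masked)
plt.figure(figsize=(8,5))
plt.semilogy(np.where(validation_BERs > 0, validation_BERs, np.nan))
plt.xlabel('Epoch')
plt.ylabel('Validation BER')
plt.grid(True)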

In [30]:
plt.figure(figsize=(8,8))
plt.contourf(mgx,mgy,decision_region_evolution[-1].reshape(mgy.shape).T,cmap='coolwarm',vmin=0.3,vmax=0.695)
plt.scatter(X_valid[:,0].cpu()*ext_max, X_valid[:,1].cpu() * ext_max, c=y_valid, cmap='coolwarm')
print(Pin)
plt.axis('scaled')
plt.xlabel(r'$\Re\{r\}$',fontsize=16)
plt.ylabel(r'$\Im\{r\}$',fontsize=16)
#plt.title(title,fontsize=16)
#plt.savefig('after_optimization.pdf',bbox_inches='tight')


3
Out[30]:
Text(0, 0.5, '$\\Im\\{r\\}$')
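
The snapshots stored in decision_region_evolution can be turned into an animation showing how the decision region evolves over the epochs. A minimal sketch using matplotlib.animation (assumed setup, not part of the original notebook; it reuses the plotting conventions of the cell above):

from matplotlib import animation

fig = plt.figure(figsize=(8,8))

def animate(k):
    # redraw the decision region after epoch k, overlaid with the validation set
    plt.clf()
    plt.contourf(mgx, mgy, decision_region_evolution[k].reshape(mgy.shape).T, cmap='coolwarm', vmin=0.3, vmax=0.695)
    plt.scatter(X_valid[:,0].cpu()*ext_max, X_valid[:,1].cpu()*ext_max, c=y_valid, cmap='coolwarm')
    plt.axis('scaled')
    plt.title('Decision region after epoch %d' % k)

anim = animation.FuncAnimation(fig, animate, frames=len(decision_region_evolution), interval=200)
# anim.save('decision_region_evolution.gif', writer='pillow')  # optional export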