This code is provided as supplementary material for the lecture Machine Learning and Optimization in Communications (MLOC).
This code illustrates the compression of MNIST images with an autoencoder whose bottleneck is quantized to a fixed number of bits; the bits are transmitted over a binary symmetric channel with error probability Pe before the decoder reconstructs the image.
In [ ]:
import tensorflow as tf  # note: this notebook uses the TensorFlow 1.x API (placeholders, sessions)
import numpy as np
from matplotlib import pyplot as plt
Import and load the MNIST dataset
In [ ]:
mnist = tf.keras.datasets.mnist
# only load the images, we are not interested in the labels
(x_train, _),(x_test, _) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
image_size = x_train.shape[1]
# flatten the test images into vectors of length image_size*image_size
x_test_flat = np.array([np.reshape(x_test[k,:,:], image_size*image_size) for k in range(x_test.shape[0])])
In [ ]:
# plot 8 random training images
plt.figure(figsize=(16,2))
for k in range(8):
    plt.subplot(1,8,k+1)
    plt.imshow(x_train[np.random.randint(x_train.shape[0])], interpolation='nearest', cmap='binary')
    plt.xticks(())
    plt.yticks(())
In [ ]:
tf.reset_default_graph()
# target compression rate (number of bits per image)
bit_per_image = tf.placeholder(tf.int32, shape=())
# bit error probability of the binary symmetric channel
Pe = tf.placeholder(tf.float32, shape=())
# Network parameters
hidden_encoder_1 = 500
hidden_encoder_2 = 250
hidden_encoder_3 = 120
hidden_decoder_1 = 120
hidden_decoder_2 = 250
hidden_decoder_3 = 500
training_data = tf.placeholder(tf.float32, [None, image_size*image_size])
valid_data = tf.constant(x_test_flat, dtype=tf.float32)
# validate_shape=False is required for the variables whose size depends on the
# bit_per_image placeholder, since their shape is only fixed when init is run
weights = { 'We1' : tf.Variable(tf.truncated_normal([image_size*image_size, hidden_encoder_1], stddev=0.1)),
            'We2' : tf.Variable(tf.truncated_normal([hidden_encoder_1, hidden_encoder_2], stddev=0.1)),
            'We3' : tf.Variable(tf.truncated_normal([hidden_encoder_2, hidden_encoder_3], stddev=0.1)),
            'We4' : tf.Variable(tf.truncated_normal([hidden_encoder_3, bit_per_image], stddev=0.1), validate_shape=False),
            'Wd1' : tf.Variable(tf.truncated_normal([bit_per_image, hidden_decoder_1], stddev=0.1), validate_shape=False),
            'Wd2' : tf.Variable(tf.truncated_normal([hidden_decoder_1, hidden_decoder_2], stddev=0.1)),
            'Wd3' : tf.Variable(tf.truncated_normal([hidden_decoder_2, hidden_decoder_3], stddev=0.1)),
            'Wd4' : tf.Variable(tf.truncated_normal([hidden_decoder_3, image_size*image_size], stddev=0.1)),
          }
biases = { 'be1' : tf.Variable(tf.truncated_normal([hidden_encoder_1], stddev=0.1)),
           'be2' : tf.Variable(tf.truncated_normal([hidden_encoder_2], stddev=0.1)),
           'be3' : tf.Variable(tf.truncated_normal([hidden_encoder_3], stddev=0.1)),
           'be4' : tf.Variable(tf.truncated_normal([bit_per_image], stddev=0.1), validate_shape=False),
           'bd1' : tf.Variable(tf.truncated_normal([hidden_decoder_1], stddev=0.1)),
           'bd2' : tf.Variable(tf.truncated_normal([hidden_decoder_2], stddev=0.1)),
           'bd3' : tf.Variable(tf.truncated_normal([hidden_decoder_3], stddev=0.1)),
           'bd4' : tf.Variable(tf.truncated_normal([image_size*image_size], stddev=0.1)),
         }
# stochastic binarizer: interprets (input+1)/2 as a probability and draws ±1 samples
def binarizer(input):
    prob = tf.truediv(tf.add(input, 1.0), 2.0)
    bernoulli = tf.distributions.Bernoulli(probs=prob, dtype=tf.float32)
    return 2*bernoulli.sample() - 1
# deterministic binarizer used for the test path
def binarizer_deterministic(input):
    return tf.sign(input)
def encoder(batch):
    temp = tf.nn.elu(tf.matmul(batch, weights['We1']) + biases['be1'])
    temp = tf.nn.elu(tf.matmul(temp, weights['We2']) + biases['be2'])
    temp = tf.nn.elu(tf.matmul(temp, weights['We3']) + biases['be3'])
    output = tf.nn.softsign(tf.matmul(temp, weights['We4']) + biases['be4'])
    return output

def decoder(batch):
    temp = tf.nn.elu(tf.matmul(batch, weights['Wd1']) + biases['bd1'])
    temp = tf.nn.elu(tf.matmul(temp, weights['Wd2']) + biases['bd2'])
    temp = tf.nn.elu(tf.matmul(temp, weights['Wd3']) + biases['bd3'])
    output = tf.nn.sigmoid(tf.matmul(temp, weights['Wd4']) + biases['bd4'])
    return output
encoded = encoder(training_data)
# stochastic binarization during training, implemented as a straight-through estimator:
# the forward pass uses the binarized values, while gradients bypass the binarizer
# (tf.stop_gradient) and flow through the identity instead
ti = tf.identity(encoded)
compressed = ti + tf.stop_gradient(binarizer(encoded) - ti)
# binary symmetric channel: flip each transmitted bit independently with probability Pe
# (a bit flip corresponds to a sign change in the bipolar representation)
error_tensor = tf.distributions.Bernoulli(probs = Pe * tf.ones_like(compressed), dtype=tf.float32).sample()
received = tf.math.multiply(compressed, 1 - 2*error_tensor)
reconstructed = decoder(received)
# test path: deterministic binarization of the validation data, same channel model
encoded_test = encoder(valid_data)
compressed_test = binarizer_deterministic(encoded_test)
error_tensor_test = tf.distributions.Bernoulli(probs = Pe * tf.ones_like(compressed_test), dtype=tf.float32).sample()
received_test = tf.math.multiply( compressed_test, 1 - 2*error_tensor_test )
reconstructed_test = decoder(received_test)
loss_test = tf.reduce_mean(tf.square(valid_data - reconstructed_test))
# reconstruction SNR (in dB) on the test set
signal_test = tf.reduce_sum(tf.square(valid_data))
noise_test = tf.reduce_sum(tf.square(valid_data - reconstructed_test))
SNR = 10.0*(tf.log(signal_test) - tf.log(noise_test))/tf.log(tf.constant(10.0))
loss = tf.losses.mean_squared_error(training_data, reconstructed)
#loss = tf.reduce_mean(tf.square(training_data - reconstructed))
train_step = tf.train.AdamOptimizer().minimize(loss)
init = tf.global_variables_initializer()
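Since hard binarization is not differentiable, training relies on a straight-through estimator: the forward pass applies the (stochastic) binarizer, while the backward pass treats it as the identity so that gradients can reach the encoder. The following standalone sketch (an addition to the original notebook, assuming the TensorFlow 1.x API used throughout) illustrates the trick for a hard sign() quantizer on a small constant vector.
In [ ]:
# sketch of the straight-through estimator (illustrative addition, not part of the model):
# forward pass: y equals sign(x); backward pass: dy/dx equals 1, because the
# non-differentiable part is wrapped in tf.stop_gradient
x = tf.constant([0.3, -0.7, 0.1])
y = x + tf.stop_gradient(tf.sign(x) - x)
grad = tf.gradients(tf.reduce_sum(y), x)[0]
with tf.Session() as demo_session:
    print(demo_session.run(y))     # expected: [ 1. -1.  1.]
    print(demo_session.run(grad))  # expected: [ 1.  1.  1.]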
In [ ]:
def get_batch(x, batch_size):
    idxs = np.random.randint(0, x.shape[0], (batch_size))
    return np.array([np.reshape(x[k,:,:], image_size*image_size) for k in idxs])
Sweep over different numbers of bits per image and different error probabilities. The results are saved to a text file that can be used to plot figures (a plotting sketch follows the sweep below).
In [ ]:
batch_size = 250
Pe_range = np.array([0, 0.01, 0.1, 0.2])
bit_range = np.array([5, 10, 20, 30, 40, 50, 60, 70, 80, 100])
SNR_result = np.zeros( (len(Pe_range), len(bit_range)) )
for i in range(len(Pe_range)):
    for j in range(len(bit_range)):
        best_SNR = -9999
        print('Initializing ....')
        # Create session and initialize all variables
        session = tf.InteractiveSession()
        session.run(init, feed_dict = { bit_per_image : bit_range[j] })
        print('done')
        # Training loop
        for it in range(100000):
            mini_batch = get_batch(x_train, batch_size)
            session.run(train_step, feed_dict = { training_data : mini_batch, bit_per_image : bit_range[j], Pe: Pe_range[i] })
            if it % 500 == 0:
                cur_SNR = SNR.eval(feed_dict = { bit_per_image : bit_range[j], Pe: Pe_range[i] })
                if cur_SNR > best_SNR:
                    best_SNR = cur_SNR
            if it % 10000 == 0:
                print('Pe = %1.2f, bits = %d, It %d: (best SNR: %1.4f dB)' % (Pe_range[i], bit_range[j], it, best_SNR))
        SNR_result[i,j] = best_SNR
        print('Finished learning for Pe = %1.2f, bits = %d. Best SNR: %1.4f dB' % (Pe_range[i], bit_range[j], best_SNR))
        session.close()
np.savetxt('SNR_result.txt', SNR_result, delimiter=',')
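The saved results can be loaded again and plotted as SNR versus the number of bits per image, with one curve per error probability. Below is a minimal plotting sketch (an addition to the original notebook; it assumes that SNR_result.txt contains one row per entry of Pe_range, as written by the sweep above).
In [ ]:
# load the sweep results (rows: error probabilities, columns: bits per image) and plot them
SNR_loaded = np.loadtxt('SNR_result.txt', delimiter=',')
plt.figure(figsize=(6,4))
for i in range(len(Pe_range)):
    plt.plot(bit_range, SNR_loaded[i,:], marker='o', label='Pe = %1.2f' % Pe_range[i])
plt.xlabel('bits per image')
plt.ylabel('best SNR (dB)')
plt.legend()
plt.grid(True)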
In [ ]:
# note: this cell needs a session that is still open and trained for the desired
# configuration (e.g. re-run the sweep for a single setting without closing the session);
# here a model trained with 20 bits per image is assumed and the channel is error-free
valid_images = reconstructed_test.eval(feed_dict = { bit_per_image : 20, Pe: 0.0 })
valid_binary = 0.5*(1 - compressed_test.eval()) # map from bipolar (BPSK) {+1,-1} to binary {0,1}
# show 8 test images and their reconstructed versions
plt.figure(figsize=(16,4))
idxs = np.random.randint(x_test.shape[0], size=8)
for k in range(8):
    plt.subplot(2,8,k+1)
    plt.imshow(np.reshape(x_test_flat[idxs[k]], (image_size,image_size)), interpolation='nearest', cmap='binary')
    plt.xticks(())
    plt.yticks(())
    plt.subplot(2,8,k+1+8)
    plt.imshow(np.reshape(valid_images[idxs[k]], (image_size,image_size)), interpolation='nearest', cmap='binary')
    plt.xticks(())
    plt.yticks(())
# print the binary representation of the images
for k in range(8):
    print('Image %d: ' % (k+1), valid_binary[idxs[k],:])
In [ ]:
session.close()