This IPython Notebook is a demonstration of the algorithm presented in:
building on previous work:
In this demo, we present the online GCC-NMF variant for blind speech enhancement. We enhance the noisy speech signal found at data/dev_Sq1_Co_A_mix.wav, taken from the SiSEC 2016 "Two-channel Mixtures of Speech and Real-world Background Noise" dev dataset, and save the result to the data directory. Dictionary pre-learning is performed in an unsupervised fashion on a small subset of isolated speech and noise frames from the CHiME 2016 dataset, preprocessed and stored locally at data/chimeTrainSet.npy.
Performing speech enhancement with GCC-NMF in real time requires three modifications to the offline approach presented previously:
1. the NMF dictionary is pre-learned from a separate dataset rather than from the mixture being enhanced,
2. the target speaker is localized online by accumulating the GCC-PHAT angular spectrum over time, and
3. masking and reconstruction are performed independently for each frame as it arrives, optionally skipping NMF coefficient inference entirely.
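Concretely, for each frame the loop in cell In [10] below computes the GCC-PHAT angular spectrum from the left and right STFT frames $X_L$ and $X_R$, keeps a running maximum over frames to localize the target, and builds a Wiener-like filter from the dictionary atoms consistent with that location (the notation here is mine, summarizing the code rather than quoting the papers):

$$G_t(\tau) = \sum_{\omega} \operatorname{Re}\!\left[\frac{X_L(\omega,t)\,X_R^*(\omega,t)}{|X_L(\omega,t)|\,|X_R(\omega,t)|}\,e^{-j\omega\tau}\right], \qquad \bar{G}_t(\tau) = \max\!\big(\bar{G}_{t-1}(\tau),\,G_t(\tau)\big), \qquad \hat{\tau}_t = \operatorname*{argmax}_{\tau}\,\bar{G}_t(\tau)$$

$$M_t(\omega) = \frac{\sum_k m_k(t)\,W(\omega,k)}{\sum_k W(\omega,k)}$$

where $m_k(t)$ equals 1 if atom $k$'s TDOA estimate lies within targetTDOAEpsilon of $\hat{\tau}_t$, and 0 otherwise. This is the filter used when numInferenceIterations is 0 (the default below); with inference iterations enabled, NMF coefficients are estimated per frame and the filter is built from the masked reconstruction instead.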
In [1]:
from gccNMF.gccNMFFunctions import *
from gccNMF.gccNMFPlotting import *
from gccNMF.wavfile import wavwrite
from numpy import *
from numpy.fft import rfft, irfft
from matplotlib.pylab import cm
from IPython import display
%matplotlib inline
In [2]:
# Preprocessing params
windowSize = 1024
fftSize = windowSize
hopSize = 128
window = hanning(windowSize)
stftGainFactor = hopSize / float(windowSize) * 2
# TDOA params
numTDOAs = 128
targetTDOAEpsilonPercent = 0.05 # controls the TDOA width for GCC-NMF mask generation
targetTDOAEpsilon = targetTDOAEpsilonPercent * numTDOAs
# NMF params
trainingDataFileName = '../data/chimeTrainSet.npy'
dictionarySize = 128
numPreLearningIterations = 100
numInferenceIterations = 0
sparsityAlpha = 0
epsilon = 1e-16
seedValue = 0
# Input params
mixtureFileNamePrefix = '../data/dev_Sq1_Co_A'
microphoneSeparationInMetres = 0.086
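A note on stftGainFactor defined above: with a Hann analysis window of length windowSize and hop hopSize, the overlapped windows sum to roughly windowSize / (2 · hopSize) in the steady state, so multiplying each reconstructed frame by 2 · hopSize / windowSize gives approximately unity-gain overlap-add. A quick, purely illustrative check (not part of the algorithm):
# illustrative check of the overlap-add gain for the Hann window
w = hanning(windowSize)
ola = zeros(windowSize * 4)
for start in range(0, len(ola) - windowSize, hopSize):
    ola[start:start + windowSize] += w
print(ola[windowSize * 2], 1.0 / stftGainFactor)  # both approximately 4.0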
In [3]:
mixtureFileName = getMixtureFileName(mixtureFileNamePrefix)
stereoSamples, sampleRate = loadMixtureSignal(mixtureFileName)
numChannels, numSamples = stereoSamples.shape
durationInSeconds = numSamples / float(sampleRate)
display.display( display.Audio(mixtureFileName) )
In [4]:
describeMixtureSignal(stereoSamples, sampleRate)
figure(figsize=(14, 6))
plotMixtureSignal(stereoSamples, sampleRate)
Pre-learn the NMF dictionary from a small subset of the CHiME dataset. A total of 1024 FFT frames are loaded, divided equally between isolated speech and noise signals.
In [5]:
trainV = load(trainingDataFileName)
numFrequencies, numTrainFrames = trainV.shape
frequenciesInHz = getFrequenciesInHz(sampleRate, numFrequencies)
frequenciesInkHz = frequenciesInHz / 1000.0
In [6]:
figure(figsize=(12, 6))
imshow((trainV / max(trainV)) ** (1/3.0),
extent=[0, trainV.shape[1]-1, frequenciesInkHz[0], frequenciesInkHz[-1]],
cmap=cm.binary)
title('CHiME Training Data')
ylabel('Frequency (kHz)')
xlabel('Frame Index')
show()
In [7]:
W, H = performKLNMF(trainV, dictionarySize, numPreLearningIterations,
sparsityAlpha, epsilon, seedValue)
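The function name suggests KL-divergence NMF with multiplicative updates. As a reference for what performKLNMF is assumed to compute, here is a minimal sketch of the standard KL-NMF updates with an L1 sparsity penalty on the coefficients H (an illustration only, relying on the numpy names imported in cell In [1]; the library's actual implementation and argument handling may differ):
from numpy import random

def performKLNMFSketch(V, dictionarySize, numIterations, sparsityAlpha, epsilon, seedValue):
    # Minimal KL-divergence NMF with multiplicative updates and an L1 sparsity
    # penalty on H; a sketch of what performKLNMF is assumed to do, not its source.
    random.seed(seedValue)
    numFrequencies, numFrames = V.shape
    W = random.rand(numFrequencies, dictionarySize) + epsilon
    H = random.rand(dictionarySize, numFrames) + epsilon
    for _ in range(numIterations):
        WH = dot(W, H) + epsilon
        H *= dot(W.T, V / WH) / (W.sum(axis=0)[:, newaxis] + sparsityAlpha + epsilon)
        WH = dot(W, H) + epsilon
        W *= dot(V / WH, H.T) / (H.sum(axis=1)[newaxis, :] + epsilon)
    return W, H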
In [8]:
figure(figsize=(12, 6))
imshow((W / max(W)) ** (1/3.0),
extent=[0, W.shape[1]-1, frequenciesInkHz[0], frequenciesInkHz[-1]],
cmap=cm.binary)
title('Pre-trained NMF Dictionary')
ylabel('Frequency (kHz)')
xlabel('Atom Index')
show()
In [9]:
numFrames = (numSamples-windowSize) // hopSize
maxSample = numFrames * hopSize + windowSize
tdoasInSeconds = getTDOAsInSeconds(microphoneSeparationInMetres, numTDOAs)
expJOmegaTau = exp( outer(frequenciesInHz, -(2j * pi) * tdoasInSeconds) )
targetEstimateSamplesOLA = zeros_like(stereoSamples)
gccPHATAccumulatedMax = zeros(numTDOAs)
gccPHATAccumulatedMax[:] = -inf
atomMask = zeros(dictionarySize)
targetTDOAs = zeros(numFrames)
targetTDOAs[:] = nan
angularSpectrogram = zeros( (numTDOAs, numFrames) )
atomMasks = zeros( (dictionarySize, numFrames) )
wienerFilters = zeros( (2, numFrequencies, numFrames) )
inputSpectrogram = zeros( (2, numFrequencies, numFrames), 'complex64')
outputSpectrogram = zeros( (2, numFrequencies, numFrames), 'complex64')
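For reference, getTDOAsInSeconds used above spans the range of time differences of arrival that are physically possible for the given microphone separation. A minimal sketch of what it presumably computes, assuming a symmetric grid bounded by separation over the speed of sound (the library's exact speed-of-sound constant may differ):
def getTDOAsInSecondsSketch(microphoneSeparationInMetres, numTDOAs, speedOfSoundInMetresPerSecond=343.0):
    # largest possible |TDOA| occurs for a source on the microphone axis
    maxTDOA = microphoneSeparationInMetres / speedOfSoundInMetresPerSecond
    return linspace(-maxTDOA, maxTDOA, numTDOAs)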
In [10]:
for frameIndex in range(numFrames):
    # compute the STFT of the current stereo frame
    frameStart = frameIndex * hopSize
    frameEnd = frameStart + windowSize
    stereoSTFTFrame = rfft( stereoSamples[:, frameStart:frameEnd] * window )
    inputSpectrogram[..., frameIndex] = stereoSTFTFrame
    # localize the target with accumulated GCC-PHAT
    coherenceV = stereoSTFTFrame[0] * stereoSTFTFrame[1].conj() / abs(stereoSTFTFrame[0]) / abs(stereoSTFTFrame[1])
    gccPHAT = dot(coherenceV, expJOmegaTau).real
    gccPHATAccumulatedMax = max( array([gccPHAT, gccPHATAccumulatedMax]), axis=0 )
    targetTDOAEstimate = argmax(gccPHATAccumulatedMax)
    targetTDOAs[frameIndex] = targetTDOAEstimate
    angularSpectrogram[:, frameIndex] = gccPHAT
    # compute the GCC-NMF atom mask: keep atoms whose TDOA estimate lies near the target TDOA
    gccNMF = dot( (coherenceV[:, newaxis] * expJOmegaTau).real.T, W )
    gccNMFTDOAEstimates = argmax(gccNMF, axis=0)
    atomMask[:] = 0
    atomMask[ abs(gccNMFTDOAEstimates - targetTDOAEstimate) < targetTDOAEpsilon ] = 1
    atomMasks[:, frameIndex] = atomMask
    # construct the Wiener-like filter
    if numInferenceIterations == 0:
        wienerFilter = sum(atomMask * W, axis=1) / sum(W, axis=1)
        wienerFilters[:, :, frameIndex] = wienerFilter
    else:
        stereoH = inferCoefficientsKLNMF( abs(stereoSTFTFrame).T, W, numInferenceIterations,
                                          sparsityAlpha, epsilon, seedValue )
        recV = dot(W, stereoH)
        sourceEstimate = dot(W, stereoH * atomMask[:, newaxis])
        wienerFilter = (sourceEstimate / recV).T
        wienerFilters[:, :, frameIndex] = wienerFilter
    filteredSTFTFrame = wienerFilter * stereoSTFTFrame
    outputSpectrogram[..., frameIndex] = filteredSTFTFrame
    # reconstruct time-domain samples
    recStereoSTFTFrame = irfft(filteredSTFTFrame)
    # overlap-add into the output samples
    targetEstimateSamplesOLA[:, frameStart:frameEnd] += recStereoSTFTFrame * stftGainFactor
In [11]:
figure(figsize=(16, 4))
imshow(angularSpectrogram,
extent=[0, durationInSeconds, tdoasInSeconds[0]*1000.0, tdoasInSeconds[-1]*1000.0], cmap=cm.binary)
targetTDOAsInMilliSeconds = take(tdoasInSeconds, targetTDOAs.astype('int32')) * 1000.0
plot( linspace(0, durationInSeconds, len(targetTDOAs)), targetTDOAsInMilliSeconds, 'r')
ylabel('TDOA (ms)')
xlabel('Time (s)')
show()
In [12]:
figure(figsize=(16, 6))
imshow(atomMasks, cmap=cm.binary,
extent=[0, durationInSeconds, 0, dictionarySize])
ylabel('Atom Index')
xlabel('Time (s)')
show()
In [13]:
figure(figsize=(16, 6))
ax = subplot(121)
imshow(wienerFilters[0], cmap=cm.jet,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]])
ylabel('Frequency (kHz)')
xlabel('Time (s)')
title('Wiener filter (left)')
ax = subplot(122)
imshow(wienerFilters[1], cmap=cm.jet,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]])
xlabel('Time (s)')
title('Wiener filter (right)')
show()
In [14]:
figure(figsize=(16, 8))
ax = subplot(221)
imshow( abs(inputSpectrogram[0]) ** (1/3.0), cmap=cm.binary,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]] )
ax.set_xticklabels([])
ylabel('Frequency (kHz)')
title('Noisy Input (left)')
ax = subplot(222)
imshow( abs(inputSpectrogram[1]) ** (1/3.0), cmap=cm.binary,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]] )
ax.set_xticklabels([])
ax.set_yticklabels([])
title('Noisy Input (right)')
ax = subplot(223)
imshow( abs(outputSpectrogram[0]) ** (1/3.0), cmap=cm.binary,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]] )
ylabel('Frequency (kHz)')
xlabel('Time (s)')
title('Target estimate (left)')
ax = subplot(224)
imshow( abs(outputSpectrogram[1]) ** (1/3.0), cmap=cm.binary,
extent=[0, durationInSeconds, frequenciesInkHz[0], frequenciesInkHz[-1]] )
ax.set_yticklabels([])
xlabel('Time (s)')
title('Target estimate (right)')
show()
In [15]:
figure(figsize=(16, 8))
sampleTimesInSeconds = arange(numSamples) / float(sampleRate)
ax = subplot(221)
plot(sampleTimesInSeconds, stereoSamples[0])
ax.set_xticklabels([])
title('Noisy Input (left)')
ax = subplot(222)
plot(sampleTimesInSeconds, stereoSamples[1])
ax.set_xticklabels([])
ax.set_yticklabels([])
title('Noisy Input (right)')
ax = subplot(223)
plot(sampleTimesInSeconds, targetEstimateSamplesOLA[0])
title('Target estimate (left)')
ax = subplot(224)
plot(sampleTimesInSeconds, targetEstimateSamplesOLA[1])
ax.set_yticklabels([])
title('Target estimate (right)')
show()
In [16]:
targetEstimateFileName = mixtureFileNamePrefix + '_sim_realtime.wav'
wavwrite( targetEstimateSamplesOLA, targetEstimateFileName, sampleRate )
print('Noisy Mixture:')
display.display(display.Audio(mixtureFileName))
print('Target Estimate:')
display.display(display.Audio(targetEstimateFileName))
In [ ]: