This iPython Notebook is a demonstration of the algorithm presented in:
that builds on previous work:
This demo builds on the real-time GCC-NMF speech enhancement IPython notebook, where we enhance the noisy speech signal found at data/dev_Sq1_Co_A_mix.wav, taken from the SiSEC 2016 Two-channel Mixtures of Speech and Real-world Background Noise "dev" dataset. Here, we extend real-time GCC-NMF with an asymmetric STFT windowing strategy that keeps the long analysis windows required by GCC-NMF but uses short synthesis windows, drastically reducing the algorithmic latency. We demonstrate that asymmetric STFT windowing yields speech enhancement results similar to those of the traditional symmetric STFT windowing case. The asymmetric windowing strategy we use was proposed in:
The first three sections, listed in italics, are repeated from the real-time GCC-NMF notebook.
In [1]:
from gccNMF.gccNMFFunctions import *
from gccNMF.gccNMFPlotting import *
from gccNMF.wavfile import wavwrite
from numpy import *
from numpy.fft import rfft, irfft
from matplotlib.pylab import cm
from matplotlib import gridspec
from IPython import display
%matplotlib inline
In [2]:
# Notebook-wide configuration. Every name below is a global read by later cells.
# Preprocessing params
fftSize = 1024  # base frame length in samples; both window sizes below are set from it
# Asymmetric windowing params
analysisWindowSize = fftSize  # long analysis window
synthesisWindowSize = 128  # short synthesis window
asymmetricHopSize = (synthesisWindowSize * 3) // 4  # hop of 3/4 of the synthesis window (96 samples)
m = synthesisWindowSize // 2  # half the synthesis window length; the `m` argument of the window builders
k = analysisWindowSize  # total window length; the `k` argument of the window builders
d = 0  # number of leading zero samples in the analysis window; the `d` argument of the window builders
# Symmetric windowing params
symmetricWindowSize = fftSize
# Alternative hop for the symmetric case, kept for reference (3/4 of the FFT size):
#symmetricHopSize = (fftSize * 3) // 4
symmetricHopSize = asymmetricHopSize # to better compare results
# TDOA params
numTDOAs = 128  # number of candidate TDOAs
targetTDOAEpsilonPercent = 0.05 # controls the TDOA width for GCC-NMF mask generation
# Mask half-width expressed in TDOA indices (0.05 * 128 = 6.4); compared against
# index differences in performOnlineSpeechEnhancement.
targetTDOAEpsilon = targetTDOAEpsilonPercent * numTDOAs
# NMF params
trainingDataFileName = '../data/chimeTrainSet.npy'
dictionarySize = 128  # number of NMF dictionary atoms
numPreLearningIterations = 100  # iterations passed to performKLNMF when pre-learning the dictionary
numInferenceIterations = 0  # 0 selects the dictionary-only Wiener filter branch (no per-frame coefficient inference)
sparsityAlpha = 0  # passed to performKLNMF / inferCoefficientsKLNMF
epsilon = 1e-16  # passed to performKLNMF / inferCoefficientsKLNMF
seedValue = 0  # passed to performKLNMF / inferCoefficientsKLNMF; presumably the RNG seed -- confirm in gccNMFFunctions
# Input params
mixtureFileNamePrefix = '../data/dev_Sq1_Co_A'
microphoneSeparationInMetres = 0.086
In [3]:
# Resolve the mixture file name from its prefix and load the stereo signal.
# getMixtureFileName / loadMixtureSignal come from gccNMF.gccNMFFunctions.
mixtureFileName = getMixtureFileName(mixtureFileNamePrefix)
stereoSamples, sampleRate = loadMixtureSignal(mixtureFileName)
# stereoSamples is indexed (channel, sample).
numChannels, numSamples = stereoSamples.shape
durationInSeconds = numSamples / float(sampleRate)
# Embed an audio player for the noisy mixture in the notebook output.
display.display( display.Audio(mixtureFileName) )
In [4]:
# Describe and plot the input mixture using the helpers imported from the
# gccNMF package.
describeMixtureSignal(stereoSamples, sampleRate)
figure(figsize=(14, 6))
plotMixtureSignal(stereoSamples, sampleRate)
Pre-learn the NMF dictionary from a small subset of the CHiME dataset. A total of 1024 FFT frames are loaded, divided equally between isolated speech and noise signals.
In [5]:
# Load the NMF training matrix from the .npy file; it is indexed
# (frequency, frame).
trainV = load(trainingDataFileName)
numFrequencies, numTrainFrames = trainV.shape
# Frequency axis for the training data's bins, in Hz and in kHz (the latter is
# used for plot axes).
frequenciesInHz = getFrequenciesInHz(sampleRate, numFrequencies)
frequenciesInkHz = frequenciesInHz / 1000.0
In [6]:
# Show the training matrix, normalized by max(trainV) and cube-root compressed
# so that low-energy detail remains visible.
# NOTE(review): max(trainV) must resolve to an array-wide maximum (numpy's max)
# for this to work -- confirm which star import provides it.
figure(figsize=(12, 6))
imshow((trainV / max(trainV)) ** (1/3.0),
extent=[0, trainV.shape[1]-1, frequenciesInkHz[0], frequenciesInkHz[-1]],
cmap=cm.binary)
title('CHiME Training Data')
ylabel('Frequency (kHz)')
xlabel('Frame Index')
show()
In [7]:
# Pre-learn the NMF factorization of the training data with performKLNMF (from
# gccNMF.gccNMFFunctions). W is the dictionary used by the rest of the
# notebook; H is not used again in this notebook.
W, H = performKLNMF(trainV, dictionarySize, numPreLearningIterations,
sparsityAlpha, epsilon, seedValue)
In [8]:
# Show the pre-learned dictionary W (one column per atom), normalized by
# max(W) and cube-root compressed so that low-energy detail remains visible.
figure(figsize=(12, 6))
imshow((W / max(W)) ** (1/3.0),
       extent=[0, W.shape[1]-1, frequenciesInkHz[0], frequenciesInkHz[-1]],
       cmap=cm.binary)
# Title typo fixed: was 'Pre-trinaed NMF Dictionary'.
title('Pre-trained NMF Dictionary')
ylabel('Frequency (kHz)')
xlabel('Atom Index')
show()
In [9]:
def getAsymmetricAnalysisWindow(k, m, d):
    """Build the length-`k` asymmetric analysis window.

    Layout of the returned array:
      * samples [0, d)      -- zeros
      * samples [d, k-m)    -- rising half of a square-root Hann window
      * samples [k-m, k)    -- falling half of a square-root Hann window of
                               length 2*m (starts at its peak value of 1)

    Parameters:
        k: total window length in samples.
        m: length of the falling segment in samples.
        d: number of leading zero samples.

    Returns:
        1-D float array of length `k`.
    """
    riseLength = k - m - d
    window = zeros(k)  # the leading `d` samples simply stay at zero
    # First half of a periodic Hann of length 2*riseLength, square-rooted.
    window[d:k-m] = sqrt(hanning(2*riseLength + 1)[:riseLength])
    # Second half of a periodic Hann of length 2*m, square-rooted.
    window[k-m:] = sqrt(hanning(2*m + 1)[m:2*m])
    return window
def getAsymmetricSynthesisWindow(k, m, d):
    """Build the length-`k` asymmetric synthesis window.

    Only the last 2*m samples are non-zero. They are chosen so that the
    product with the matching analysis window (same k, m, d) is a Hann window
    of length 2*m:
      * samples [k-2m, k-m) -- rising Hann divided by the rising square-root
                               Hann of the analysis window at those samples
      * samples [k-m, k)    -- falling half of a square-root Hann of length 2*m

    Parameters:
        k: total window length in samples.
        m: half the length of the non-zero synthesis segment (must be >= 1).
        d: number of leading zero samples of the matching analysis window.

    Returns:
        1-D float array of length `k`.

    Raises:
        ValueError: if `m < 1`, or if the 2*m-sample synthesis segment does
            not fit inside the non-zero part of the analysis window
            (`k - 2*m - d < 0`).
    """
    if m < 1 or k - 2*m - d < 0:
        raise ValueError(
            'invalid asymmetric window sizes: need m >= 1 and k - 2*m - d >= 0 '
            '(got k=%r, m=%r, d=%r)' % (k, m, d))

    shortHann = hanning(2*m + 1)
    riseLength = k - m - d
    # Analysis window values over the samples where the synthesis window rises.
    risingSqrtHannAnalysis = sqrt(hanning(2*riseLength + 1)[riseLength-m:riseLength])
    # When k - 2*m - d == 0 the first analysis sample is exactly zero (and so
    # is the Hann numerator); emit 0 there instead of 0/0 = NaN.
    risingNormalizedHann = divide(shortHann[:m], risingSqrtHannAnalysis,
                                  out=zeros(m), where=risingSqrtHannAnalysis > 0)

    window = zeros(k)  # everything before the last 2*m samples stays zero
    window[-2*m:-m] = risingNormalizedHann
    window[-m:] = sqrt(shortHann[m:2*m])
    return window
In [10]:
# Asymmetric analysis/synthesis window pair built from the k, m, d parameters.
analysisWindow = getAsymmetricAnalysisWindow(k, m, d)
synthesisWindow = getAsymmetricSynthesisWindow(k, m, d)
# Symmetric baseline: a square-root Hann window, used for both analysis and
# synthesis in the symmetric run.
symmetricWindow = sqrt(hanning(symmetricWindowSize))
In [11]:
# Compare the window shapes on a 3x2 grid: the left column is the symmetric
# case, the right column the asymmetric case; rows show the analysis window,
# the synthesis window, and their product.
figure(figsize=(14, 10))
# Left column, row 1: symmetric analysis window.
ax = subplot(321)
plot(symmetricWindow)
ax.set_xticklabels([])
ylabel('Analysis Window')
title('Symmetric Windowing')
# Left column, row 2: symmetric synthesis window (same window as analysis).
ax = subplot(323)
plot(symmetricWindow)
ax.set_xticklabels([])
ylabel('Synthesis Window')
# Left column, row 3: product of analysis and synthesis windows.
ax = subplot(325)
plot(symmetricWindow * symmetricWindow)
xlabel('Sample Index')
ylabel('Analysis Window *\nSynthesis Window')
# Right column, row 1: asymmetric analysis window.
ax = subplot(322)
plot(analysisWindow)
ax.set_xticklabels([])
title('Asymmetric Windowing')
# Right column, row 2: asymmetric synthesis window (tick labels are cleared
# before plotting here, unlike the other panels).
ax = subplot(324)
ax.set_xticklabels([])
plot(synthesisWindow)
# Right column, row 3: product of the asymmetric analysis and synthesis windows.
ax = subplot(326)
plot(analysisWindow * synthesisWindow)
xlabel('Sample Index')
show()
In [12]:
# Candidate TDOAs (helper from gccNMF.gccNMFFunctions; presumably numTDOAs
# values spanning the delays allowed by the microphone spacing -- confirm).
tdoasInSeconds = getTDOAsInSeconds(microphoneSeparationInMetres, numTDOAs)
# Matrix of exp(-j*2*pi*f*tau) with one row per frequency and one column per
# candidate TDOA; used to evaluate GCC-PHAT and GCC-NMF for every TDOA at once.
expJOmegaTau = exp( outer(frequenciesInHz, -(2j * pi) * tdoasInSeconds) )
In [13]:
def performOnlineSpeechEnhancement(analysisWindow, synthesisWindow, hopSize):
    """Enhance the notebook's stereo mixture frame by frame with GCC-NMF.

    For each STFT frame the input is windowed with `analysisWindow`, the target
    is localized with max-accumulated GCC-PHAT, dictionary atoms whose GCC-NMF
    TDOA estimate lies within `targetTDOAEpsilon` indices of the target are
    selected, a Wiener-like filter is built from them and applied, and the
    result is resynthesized by weighted overlap-add using `synthesisWindow`.

    Reads these notebook-level globals: stereoSamples, numSamples, numTDOAs,
    numFrequencies, dictionarySize, W, expJOmegaTau, targetTDOAEpsilon,
    numInferenceIterations, sparsityAlpha, epsilon, seedValue.

    Parameters:
        analysisWindow: 1-D window applied to each input frame; its length is
            the frame length.
        synthesisWindow: 1-D window applied to each reconstructed frame before
            overlap-add; must have the same length as `analysisWindow`.
        hopSize: frame advance in samples.

    Returns:
        Tuple of (inputSpectrogram, outputSpectrogram, targetEstimateSamplesOLA,
        gccPHATAccumulatedMax, targetTDOAs, angularSpectrogram, atomMasks,
        wienerFilters).
    """
    # Use the length of the window actually passed in rather than the global
    # analysisWindowSize.
    frameSize = len(analysisWindow)

    # Setup variables to save speech enhancement results
    numFrames = (numSamples-len(synthesisWindow)) // hopSize
    # NOTE(review): this gain uses the full (zero-padded) synthesis window
    # length, so the asymmetric run is scaled like a full-length window even
    # though only its tail is non-zero -- confirm whether that is intended.
    gainFactor = hopSize / float(len(synthesisWindow)) * 2
    targetEstimateSamplesOLA = zeros_like(stereoSamples)
    gccPHATAccumulatedMax = zeros(numTDOAs)
    gccPHATAccumulatedMax[:] = -inf
    atomMask = zeros(dictionarySize)
    targetTDOAs = zeros(numFrames)
    targetTDOAs[:] = nan
    angularSpectrogram = zeros( (numTDOAs, numFrames) )
    atomMasks = zeros( (dictionarySize, numFrames) )
    wienerFilters = zeros( (2, numFrequencies, numFrames) )
    inputSpectrogram = zeros( (2, numFrequencies, numFrames), 'complex64')
    outputSpectrogram = zeros( (2, numFrequencies, numFrames), 'complex64')

    # Loop-invariant denominator of the dictionary-only Wiener filter.
    dictionaryRowSums = sum(W, axis=1)

    # For each STFT frame:
    #   1. Compute FFT
    #   2. Localize target with accumulated GCC-PHAT localization
    #   3. Compute GCC-NMF atom mask
    #   4. Construct Wiener-like filter, and filter input
    #   5. Reconstruct time domain samples
    #   6. Overlap-add to output samples
    for frameIndex in range(numFrames):
        # Compute FFT
        frameStart = frameIndex * hopSize
        frameEnd = frameStart + frameSize
        stereoSTFTFrame = rfft( stereoSamples[:, frameStart:frameEnd] * analysisWindow )
        inputSpectrogram[..., frameIndex] = stereoSTFTFrame

        # localize target with accumulated GCC-PHAT
        # PHAT normalization; bins with zero magnitude (e.g. digital silence)
        # get zero coherence instead of 0/0 = NaN, which would otherwise stick
        # in the running maximum for all later frames.
        crossSpectrum = stereoSTFTFrame[0] * stereoSTFTFrame[1].conj()
        crossMagnitude = abs(stereoSTFTFrame[0]) * abs(stereoSTFTFrame[1])
        coherenceV = divide(crossSpectrum, crossMagnitude,
                            out=zeros_like(crossSpectrum), where=crossMagnitude > 0)
        gccPHAT = dot(coherenceV, expJOmegaTau).real
        gccPHATAccumulatedMax = maximum(gccPHAT, gccPHATAccumulatedMax)
        targetTDOAEstimate = argmax(gccPHATAccumulatedMax)
        targetTDOAs[frameIndex] = targetTDOAEstimate
        angularSpectrogram[:, frameIndex] = gccPHAT

        # compute GCC-NMF atom mask
        gccNMF = dot( (coherenceV[:, newaxis] * expJOmegaTau).real.T, W )
        gccNMFTDOAEstimates = argmax(gccNMF, axis=0)
        atomMask[:] = 0
        atomMask[ abs(gccNMFTDOAEstimates - targetTDOAEstimate) < targetTDOAEpsilon ] = 1
        atomMasks[:, frameIndex] = atomMask

        # construct wiener filter
        if numInferenceIterations == 0:
            # Dictionary-only filter: share of each frequency's dictionary
            # energy that belongs to the selected atoms.
            wienerFilter = sum(atomMask * W, axis=1) / dictionaryRowSums
        else:
            stereoH = inferCoefficientsKLNMF( abs(stereoSTFTFrame).T, W, numInferenceIterations,
                                              sparsityAlpha, epsilon, seedValue)
            recV = dot(W, stereoH)
            sourceEstimate = dot(W, stereoH * atomMask[:, newaxis])
            wienerFilter = (sourceEstimate / recV).T
        wienerFilters[:, :, frameIndex] = wienerFilter

        filteredSTFTFrame = wienerFilter * stereoSTFTFrame
        outputSpectrogram[..., frameIndex] = filteredSTFTFrame

        # reconstruct time domain samples
        recStereoSTFTFrame = irfft(filteredSTFTFrame, n=frameSize)

        # overlap-add to output samples, weighted by the synthesis window
        # (previously the synthesis window was never applied)
        targetEstimateSamplesOLA[:, frameStart:frameEnd] += recStereoSTFTFrame * synthesisWindow

    targetEstimateSamplesOLA *= gainFactor
    return inputSpectrogram, outputSpectrogram, targetEstimateSamplesOLA, \
        gccPHATAccumulatedMax, targetTDOAs, angularSpectrogram, atomMasks, wienerFilters
In [14]:
def plotResults(titleString, mainGridSpec, inputSpectrogram, outputSpectrogram, targetEstimateSamplesOLA,
                gccPHATAccumulatedMax, targetTDOAs, angularSpectrogram, atomMasks, wienerFilters):
    """Draw a five-row summary of one enhancement run inside `mainGridSpec`.

    Rows, top to bottom: GCC-PHAT angular spectrogram with the target TDOA
    track overlaid, atom masks, Wiener filter (first channel), cube-root
    compressed output spectrogram (first channel), and the first channel of
    the time-domain target estimate.

    The parameters after `mainGridSpec` follow the order of the tuple returned
    by performOnlineSpeechEnhancement, so a result tuple can be star-unpacked
    into this call. `inputSpectrogram` and `gccPHATAccumulatedMax` are
    accepted for that reason but not plotted.
    """
    rows = gridspec.GridSpecFromSubplotSpec(5, 1, mainGridSpec)
    lowestFrequency, highestFrequency = frequenciesInkHz[0], frequenciesInkHz[-1]

    # Row 1: angular spectrogram with the estimated target TDOA in red.
    subplot(rows[0])
    tdoaExtent = [0, angularSpectrogram.shape[1]-1,
                  tdoasInSeconds[0]*1000.0, tdoasInSeconds[-1]*1000.0]
    imshow(angularSpectrogram, cmap=cm.jet, extent=tdoaExtent)
    targetTDOAIndices = targetTDOAs.astype('int32')
    plot(take(tdoasInSeconds, targetTDOAIndices) * 1000.0, 'r')
    ylabel('TDOA (ms)')
    title(titleString)

    # Row 2: which dictionary atoms were selected in each frame.
    subplot(rows[1])
    imshow(atomMasks, cmap=cm.binary)
    ylabel('Atom Index')

    # Row 3: Wiener filter of the first channel.
    subplot(rows[2])
    filterExtent = [0, wienerFilters.shape[-1]-1, lowestFrequency, highestFrequency]
    imshow(wienerFilters[0], cmap=cm.jet, extent=filterExtent)
    ylabel('Frequency (kHz)')

    # Row 4: magnitude of the filtered spectrogram, first channel.
    subplot(rows[3])
    spectrogramExtent = [0, outputSpectrogram.shape[-1]-1, lowestFrequency, highestFrequency]
    imshow(abs(outputSpectrogram[0]) ** (1/3.0), cmap=cm.jet, extent=spectrogramExtent)
    ylabel('Frequency (kHz)')

    # Row 5: time-domain target estimate, first channel.
    subplot(rows[4])
    plot(targetEstimateSamplesOLA[0])
    xlim((0, targetEstimateSamplesOLA.shape[-1]-1))
    xlabel('Time (samples)')
In [15]:
# Run the enhancement twice with the same hop size: once with the symmetric
# square-root Hann window used for both analysis and synthesis, and once with
# the asymmetric analysis/synthesis pair. Each call returns the result tuple
# of performOnlineSpeechEnhancement; element [2] is the time-domain target
# estimate.
symmetricResults = performOnlineSpeechEnhancement(symmetricWindow, symmetricWindow, symmetricHopSize)
asymmetricResults = performOnlineSpeechEnhancement(analysisWindow, synthesisWindow, asymmetricHopSize)
In [16]:
# Side-by-side result summaries: symmetric run on the left, asymmetric run on
# the right. Each result tuple is star-unpacked into plotResults.
figure(figsize=(16, 16))
mainGridSpec = gridspec.GridSpec(1, 2)
plotResults( 'Symmetric Windowing', mainGridSpec[0], *symmetricResults)
plotResults( 'Asymmetric Windowing', mainGridSpec[1], *asymmetricResults )
show()
In [17]:
# Waveform comparison on a 3x2 grid: rows are the noisy input, the symmetric
# target estimate and the asymmetric target estimate; columns are the left and
# right channels.
figure(figsize=(16, 8))
sampleTimesInSeconds = arange(numSamples) / float(sampleRate)
# (signal, (left title, right title), hide x tick labels on this row?)
waveformRows = [
    (stereoSamples,
     ('Noisy Input (left)', 'Noisy Input (right)'),
     True),
    (symmetricResults[2],
     ('Symmetric windowing target estimate (left)', 'Symmetric windowing target estimate (right)'),
     False),
    (asymmetricResults[2],
     ('Asymmetric windowing target estimate (left)', 'Asymmetric windowing target estimate (right)'),
     False),
]
for rowIndex, (signal, channelTitles, hideXTickLabels) in enumerate(waveformRows):
    for channelIndex, channelTitle in enumerate(channelTitles):
        ax = subplot(3, 2, 2*rowIndex + channelIndex + 1)
        plot(sampleTimesInSeconds, signal[channelIndex])
        if hideXTickLabels:
            ax.set_xticklabels([])
        # The right column shares the left column's y scale visually, so its
        # tick labels are hidden.
        if channelIndex == 1:
            ax.set_yticklabels([])
        title(channelTitle)
show()
In [18]:
# Write both time-domain target estimates (element [2] of each result tuple)
# to wav files next to the input mixture, then embed audio players for the
# noisy mixture and both estimates.
symmetricTargetEstimateFileName = mixtureFileNamePrefix + '_sim_symmetric.wav'
wavwrite(symmetricResults[2], symmetricTargetEstimateFileName, sampleRate)
asymmetricTargetEstimateFileName = mixtureFileNamePrefix + '_sim_asymmetric.wav'
wavwrite(asymmetricResults[2], asymmetricTargetEstimateFileName, sampleRate)
print('Noisy Mixture:')
display.display(display.Audio(mixtureFileName))
print('Symmetric Target Estimate:')
display.display(display.Audio(symmetricTargetEstimateFileName))
print('Asymmetric Target Estimate:')
display.display(display.Audio(asymmetricTargetEstimateFileName))
In [ ]: