In [250]:
    
# -*- coding: utf-8 -*-
# !/usr/bin/python
# Cross Wavelet Analysis (CWA) based on Maraun and Kurths(2004).
# http://www.nonlin-processes-geophys.net/11/505/2004/npg-11-505-2004.pdf
# author: Mabel Calim Costa
# INPE
# 23/01/2013
# reviewed --- 02/03/2018
"""
Created on Mon Jun 17 2013
@author: Mabel Calim Costa
"""
import numpy as np
import pylab
from pylab import *
import matplotlib.pyplot as plt
import cmath
import pandas as pd
    
In [302]:
    
def _smooth_along_time(spectrum, window):
    """Boxcar-average `spectrum` along its last (time) axis with `window` samples.

    Returns the input unchanged when the window is <= 1 or the array is empty.
    The output has the same shape as the input.
    """
    spectrum = np.asarray(spectrum)
    if spectrum.ndim == 0 or spectrum.size == 0:
        return spectrum
    # np.convolve(mode='same') returns max(len(row), window) samples, so clamp
    # the window to the row length to keep the output shape unchanged.
    window = min(int(window), spectrum.shape[-1])
    if window <= 1:
        return spectrum
    kernel = np.ones(window) / window
    return np.apply_along_axis(np.convolve, -1, spectrum, kernel, mode='same')


def cross_wavelet(wave1, wave2, smooth=None):
    """ Computes the cross wavelet analysis of two wavelet transforms.

    A normalized time and scale resolved measure for the relationship
    between two time series x1(t) and x2(t) is the wavelet coherency (WCO),
    defined as the amplitude of the WCS (wavelet cross spectrum) normalized
    by the two single WPS (wavelet power spectra) (Maraun and Kurths, 2004):

        WCO**2 = |S(WPS12)|**2 / (S(WPS1) * S(WPS2))

    with WPS12 = W1 * conj(W2), WPS1 = |W1|**2, WPS2 = |W2|**2 and S a
    smoothing operator. Without smoothing (S = identity) the ratio is
    identically 1 wherever both transforms are non-zero, so a smoothing
    window is needed for the coherence to carry information.
    _____________________________________________________________________
    Inputs:
    wave1  - wavelet transform of time series x1 (e.g. result['wave'])
    wave2  - wavelet transform of time series x2, same shape as wave1
    smooth - optional boxcar window length (in samples) applied along the
             last (time) axis before forming the coherence; None or <= 1
             disables smoothing. This is a fixed-width simplification of
             the scale-dependent smoothing used by Maraun and Kurths.
    Outputs:
    WPS12       - unsmoothed complex cross wavelet spectrum W1 * conj(W2)
    pot_cohere  - squared coherency WCO**2, clipped to [0, 1] and rounded
                  to 3 decimals; 0 where the power of either series is 0
    phase_angle - phase of WPS12 in radians
    Call function:
    cross_power, coherence, phase_angle = cross_wavelet(wave1, wave2)
    """
    wave1 = np.asarray(wave1)
    wave2 = np.asarray(wave2)
    WPS12 = wave1 * np.conj(wave2)
    WPS1 = np.abs(wave1) ** 2
    WPS2 = np.abs(wave2) ** 2

    if smooth is not None:
        S12 = _smooth_along_time(WPS12, smooth)
        S1 = _smooth_along_time(WPS1, smooth)
        S2 = _smooth_along_time(WPS2, smooth)
    else:
        S12, S1, S2 = WPS12, WPS1, WPS2

    # |S12|**2 (squared magnitude), not S12*S12 (complex square, which would
    # reintroduce a cos(2*phase) factor into the coherence).
    numerator = np.asarray(np.abs(S12), dtype=float) ** 2
    denominator = np.asarray(S1 * S2, dtype=float)
    # Avoid 0/0 -> NaN where either series carries no power.
    coherence = np.divide(numerator, denominator,
                          out=np.zeros_like(numerator),
                          where=denominator > 0)
    # Cauchy-Schwarz bounds WCO**2 by 1; clip only absorbs rounding error.
    pot_cohere = np.around(np.clip(coherence, 0.0, 1.0), 3)
    phase_angle = np.angle(WPS12)
    return WPS12, pot_cohere, phase_angle
    
In [85]:
    
def plot_cross(var, cross_power, phase_angle, time, result, result1,
               hlines=(10.5, 13.3)):
    """ PLOT CROSS POWER
    Two-panel figure: both time series on top, and the magnitude of the
    cross wavelet spectrum with phase arrows below. The figure is saved as
    'Cross Power <result name> vs <result1 name>' in the working directory.
    _____________________________________________________________________
    Inputs:
    var         - title of the top panel
    cross_power - complex cross spectrum (first output of cross_wavelet);
                  rows are periods, columns are time steps
    phase_angle - phase of the cross spectrum in radians (third output of
                  cross_wavelet), same shape as cross_power
    time        - time vector of the series; cast to int64 for the
                  phase-arrow axis (e.g. numeric or datetime64 values)
    result      - dict from the cwt of series 1; uses 'data', 'period',
                  'coi' and 'name'
    result1     - dict from the cwt of series 2; uses 'data' and 'name'
    hlines      - log2(period) positions of horizontal reference lines drawn
                  across the spectrum; pass () to draw none
    Arrows point right when the series are in phase and left when they are
    out of phase.
    """
    import matplotlib.gridspec as gridspec
    from matplotlib.ticker import MaxNLocator

    log_period = np.log2(result['period'])
    log_coi = np.log2(result['coi'])

    fig = plt.figure(figsize=(15, 10), dpi=300)
    # set plot grid: one row for the series, three rows for the spectrum
    gs1 = gridspec.GridSpec(4, 3)
    gs1.update(left=0.05)
    ax1 = plt.subplot(gs1[0, :])
    ax2 = plt.subplot(gs1[1:4, :])

    # plot timeseries (series 2 on a twin y-axis)
    ax1.plot(time, result['data'])
    ax1.set_ylabel('Amplitude', fontsize=13)
    ax1.axis('tight')
    ax3 = ax1.twinx()
    ax3.plot(time, result1['data'], color='c')
    ax3.set_ylabel(result1['name'])
    ax3.axis('tight')
    ax1.set_title('%s' % var, fontsize=15)
    ax1.yaxis.set_major_locator(MaxNLocator(prune='lower'))
    ax1.grid(True)
    ax1.xaxis.set_visible(False)

    # Sub-sample ~20 arrows per axis; step >= 1 so short inputs do not give
    # np.arange a zero step.
    t_step = max(1, int(round(len(time) / 20)))
    p_step = max(1, int(round(len(result['period']) / 20)))
    tidx = np.arange(t_step // 2, len(time), t_step)
    pidx = np.arange(p_step // 2, len(result['period']), p_step)
    time_num = np.asarray(time).astype(np.int64)
    X, Y = np.meshgrid(time_num[tidx], log_period[pidx])

    # unit vectors pointing in the direction of the phase angle
    phase_sub = phase_angle[np.ix_(pidx, tidx)]
    U = np.cos(phase_sub)
    V = np.sin(phase_sub)

    # twin x-axis for the arrows, which live in int64 time units
    ax4 = ax2.twiny()
    ax4.xaxis.set_visible(False)
    # cross_power is complex: plot its magnitude (contourf would otherwise
    # silently discard the imaginary part)
    CS = ax2.contourf(time, log_period, np.abs(cross_power))
    # cone-of-influence, anything "below" is dubious
    ax2.plot(time, log_coi, 'k')
    ax2.fill_between(time, log_coi, int(log_period[-1] + 1),
                     alpha=0.5, hatch='/')
    position = fig.add_axes([0.15, 0.05, 0.6, 0.01])
    cbar = plt.colorbar(CS, cax=position, orientation='horizontal')
    cbar.set_label('Power')
    ax4.quiver(X, Y, U, V, linewidth=0.1)
    # span the full time range so the arrows line up with the spectrum in ax2
    # (a 'tight' fit would only cover the sub-sampled arrow positions)
    ax4.set_xlim(time_num.min(), time_num.max())

    # integer log2 periods as ticks, labelled with the period 2**p
    yt = range(int(log_period[0]), int(log_period[-1] + 1))
    ax2.set_yticks(yt)
    ax2.set_yticklabels([int(2 ** p) for p in yt])
    ax2.set_ylim(log_period[0], log_period[-1])
    ax2.invert_yaxis()
    ax2.set_xlabel('Time', fontsize=13)
    ax2.set_ylabel('Period', fontsize=13)
    for level in hlines or ():
        ax2.axhline(y=level, xmin=0, xmax=1, linewidth=2, color='k')
    ax2.set_title('Cross Power')
    plt.savefig('Cross Power {0} vs {1}'.format(
        result['name'], result1['name']), dpi=300)
    return
    
In [222]:
    
def plot_cohere(var, coherence, time, result, result1, hlines=()):
    """
    PLOT COHERENCE
    Two-panel figure: both time series on top, and the wavelet coherence
    (contoured on 21 levels from 0 to 1) below. The figure is saved as
    'Coherence <result name> vs <result1 name>' in the working directory.
    _____________________________________________________________________
    Inputs:
    var       - title of the top panel
    coherence - coherence from the cross_wavelet function; rows are
                periods, columns are time steps
    time      - time vector from load function
    result    - dict from cwt function for series 1; uses 'data', 'period',
                'coi' and 'name'
    result1   - dict from cwt function for series 2; uses 'data' and 'name'
    hlines    - log2(period) positions of horizontal reference lines drawn
                across the coherence panel; none by default
    """
    import matplotlib.gridspec as gridspec
    from matplotlib.ticker import MaxNLocator

    log_period = np.log2(result['period'])
    log_coi = np.log2(result['coi'])

    fig = plt.figure(figsize=(15, 14), dpi=300)
    # set plot grid: one row for the series, three rows for the coherence
    gs1 = gridspec.GridSpec(4, 3)
    gs1.update(left=0.05)
    ax1 = plt.subplot(gs1[0, :])
    ax2 = plt.subplot(gs1[1:4, :])

    # plot timeseries (series 2 on a twin y-axis)
    ax1.plot(time, result['data'])
    ax1.set_ylabel('Amplitude', fontsize=13)
    ax3 = ax1.twinx()
    ax3.plot(time, result1['data'], color='red')
    ax3.set_ylabel(result1['name'])
    ax1.set_title('%s' % var, fontsize=15)
    ax1.yaxis.set_major_locator(MaxNLocator(prune='lower'))
    ax1.grid(True)
    ax1.xaxis.set_visible(False)
    ax3.axis('tight')
    ax1.axis('tight')

    # fixed 0..1 levels so figures for different pairs share one colour scale
    lev = list(np.linspace(0, 1.0, 21))
    CS = ax2.contourf(time, log_period, coherence, lev)
    # cone-of-influence, anything "below" is dubious
    ax2.plot(time, log_coi, 'k')
    # int(x + 1) is always above x (unlike int(x * 2), e.g. x = 0.3), so the
    # hatch reaches past the largest period; matches plot_cross.
    ax2.fill_between(time, log_coi, int(log_period[-1] + 1),
                     alpha=0.5, hatch='/')
    position = fig.add_axes([0.15, 0.05, 0.6, 0.01])
    cbar = plt.colorbar(CS, cax=position, orientation='horizontal')
    cbar.set_label('coherence')

    # integer log2 periods as ticks, labelled with the period 2**p
    yt = range(int(log_period[0]), int(log_period[-1] + 1))
    ax2.set_yticks(yt)
    ax2.set_yticklabels([int(2 ** p) for p in yt])
    ax2.set_ylim(log_period[0], log_period[-1])
    ax2.invert_yaxis()
    ax2.set_xlabel('Time', fontsize=15)
    ax2.set_ylabel('Period', fontsize=15)
    for level in hlines or ():
        ax2.axhline(y=level, xmin=0, xmax=1, linewidth=2, color='k')
    ax2.set_title('Coherence')
    plt.savefig('Coherence {0} vs {1}'.format(
        result['name'], result1['name']), dpi=300)
    return
    
In [300]:
    
# Example adapted from waipy's example_python3.6.ipynb.
# numpy (np) and pylab are already imported in the first cell.
import waipy

z = np.linspace(0, 2048, 2048)
# NOTE: z has spacing 2048/2047 (not 1), and 50*pi is a multiple of 2*pi, so
# on this grid sin(50*pi*z) is aliased to a slow sinusoid whose period is
# about 2*2047/50 ~= 82 samples.
x = np.sin(50*np.pi*z)
data_norm_x = waipy.normalize(x)
result_x = waipy.cwt(data_norm_x, 1, 1, 0.125, 2, 4/0.125, 0.72, 6,
                     mother='Morlet', name='x')
waipy.wavelet_plot('x', z, x, 0.03125, result_x)
    
    
    
    
In [301]:
    
# A second artificial signal to compare with x. cos is sin shifted by 90
# degrees, so y is in quadrature with x (a quarter cycle apart, not antiphase).
y = np.cos(50*np.pi*z)
# Alternative experiment: a different frequency plus noise
# noise = np.sin(7*np.pi*z)+np.cos(2*np.pi*z)
# y = np.sin(30*np.pi*z)+noise
# Normalize y before the transform (previously commented out, which left
# data_norm_y undefined on a fresh kernel).
data_norm_y = waipy.normalize(y)
# name='y' so plot labels and saved filenames ("... x vs y") identify this series.
result_y = waipy.cwt(data_norm_y, 1, 1, 0.125, 2, 4/0.125, 0.72, 6,
                     mother='Morlet', name='y')
waipy.wavelet_plot('y', z, y, 0.03125, result_y)
    
    
    
    
In [240]:
    
# Cross wavelet spectrum of x and y, drawn with phase arrows:
# arrows point right when the signals are in phase, left when out of phase.
cross_power, coherence, phase_angle = cross_wavelet(
    wave1=result_x['wave'], wave2=result_y['wave'])
plot_cross('signals', cross_power, phase_angle, z, result_x, result_y)
    
    
In [303]:
    
# Wavelet coherence between x and y.
cross_power, coherence, phase_angle = cross_wavelet(
    wave1=result_x['wave'],
    wave2=result_y['wave'],
)
plot_cohere('signals', coherence, z, result_x, result_y)
    
    
In [305]:
    
# Sanity check: the coherence of x with itself should be 1 wherever its
# wavelet transform is non-zero.
cross_power, coherence, phase_angle = cross_wavelet(
    wave1=result_x['wave'],
    wave2=result_x['wave'],
)
plot_cohere('signals', coherence, z, result_x, result_x)
    
    
In [ ]: