In [95]:
import matplotlib.pyplot as plt

In [101]:
import scipy
import numpy as np

def subtractContinuum(s, n=3, plot=False):
    '''Take a 1D array, use spline to subtract off continuum.
            subtractContinuum(s, n=3)
            required:
            s = the array
            optional:
            n = 3, the number of spline points to use
    '''

    x = np.arange(len(s))
    points = (np.arange(n)+1)*len(s)/(n+1)
    spline = scipy.interpolate.LSQUnivariateSpline(x,s,points)
    if plot:
        plt.ion()
        plt.figure()
        plt.plot(x, s)
        plt.plot(x, spline(x), linewidth=5, alpha=0.5)
    return s - spline(x)


def ccf(f, g, scale=1.0):
    '''Calculate the normalized cross-correlation function of two identically-size arrays.
        [required]:
        f = an N-element array (for example, spectrum of target star)
        g = an N-element array (for example, spectrum of template star)
        scale = a scalar indicating what the indices of f and g (one unit of "lag") correspond to
    '''

    # how long are our arrays
    N = len(f)

    # define the x-axis, if not supplied
    assert(len(f) == len(g))
    x = np.arange(-N+1, N, 1.0)*scale

    # calculation the normalized cross-correlation function
    sigma_f = np.sqrt(np.sum(f**2)/N)
    sigma_g = np.sqrt(np.sum(g**2)/N)
    C_fg = np.correlate(f, g, 'full')/N/sigma_f/sigma_g

    # WILL THIS WORK?
    return scipy.interpolate.interp1d(x,C_fg, fill_value=0.0, bounds_error=False)


def todcor(f, g1, g2, scale=1.0, luminosity_ratio=None):
    '''Calculate the 2D correlation of a 1D array with two template arrays.
    
    '''

    assert(len(f) == len(g1))
    assert(len(f) == len(g2))

    C_1 = ccf(f, g1, scale=scale)
    C_2 = ccf(f, g2, scale=scale)
    C_12 = ccf(g1, g2, scale=scale)

    N = len(f)
    sigma_g1 = np.sqrt(np.sum(g1**2)/N)
    sigma_g2 = np.sqrt(np.sum(g2**2)/N)

    def bestalphaprime(s1, s2):
        return sigma_g1/sigma_g2*(C_1(s1)*C_12(s2 - s1) - C_2(s2))/(C_2(s2)*C_12(s2-s1) - C_1(s1))

    def R(s1, s2):
        #a =
        if luminosity_ratio is None:
            a = np.maximum(np.minimum(bestalphaprime(s1,s2), sigma_g2/sigma_g1), 0.0)
            flexiblecorrelation = (C_1(s1) + a*C_2(s2))/np.sqrt(1.0 + 2*a*C_12(s2 - s1) + a**2)
            ok = np.isfinite(a)
            peak = np.argmax(flexiblecorrelation[ok].flatten())
            a = a[ok].flatten()[peak]
        else:
            a = luminosity_ratio*sigma_g2/sigma_g1
        # print("alpha spans", np.min(a), np.max(a))
        return (C_1(s1) + a*C_2(s2))/np.sqrt(1.0 + 2*a*C_12(s2 - s1) + a**2), a*sigma_g1/sigma_g2

    return R

In [108]:
from astropy.io import fits 
spec1 = fits.getdata("../tests/testdata/lte05200-4.50-0.0.PHOENIX-ACES-AGSS-COND-2011-HiRes.fits")
spec2 = fits.getdata("../tests/testdata/lte03400-4.50-0.0.PHOENIX-ACES-AGSS-COND-2011-HiRes.fits")
wav = fits.getdata("../tests/testdata/WAVE_PHOENIX-ACES-AGSS-COND-2011.fits")
wav = wav/10

from PyAstronomy import pyasl

spec1_shifted, wlprime = pyasl.dopplerShift(wav, spec1, -5)
spec2_shifted, wlprime = pyasl.dopplerShift(wav, spec2, 15)
spec_3 = spec1_shifted + spec2_shifted


spec_1 = spec1[(wav > 2100) * (wav< 2200)]
spec_2 = spec2[(wav > 2100) * (wav< 2200)]
spec_3 = spec_3[(wav > 2100) * (wav< 2200)]
wav_3 = wav[(wav > 2100) * (wav< 2200)]

# need to log resample so that evenly spaced in rv 
def log_resample(x, drv=0.01):
    logx = np.log(x)
    logx = np.arange(logx[0], logx[-1], drv)
    return np.exp(logx)

new_full_wav = log_resample(wav, 0.001)
new_wav = log_resample(wav_3)
new_s1 = np.interp(new_full_wav, wav, spec1)
new_s2 = np.interp(new_full_wav, wav, spec2)
new_s3 = np.interp(new_wav, wav_3, spec_3)

In [107]:
plt.plot(wav, spec1, label="1")
plt.plot(new_full_wav, new_s1, label="1")
plt.show()
plt.plot(wav_3, spec2, label="2")
plt.show()
plt.plot(wav_3, spec_3, label="3")
plt.legend()
plt.show()


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-107-c4bafa38ccab> in <module>()
      2 plt.plot(new_full_wav, new_s1, label="1")
      3 plt.show()
----> 4 plt.plot(wav_3, spec2, label="2")
      5 plt.show()
      6 plt.plot(wav_3, spec_3, label="3")

~\Miniconda3\envs\work\lib\site-packages\matplotlib\pyplot.py in plot(*args, **kwargs)
   3238                       mplDeprecation)
   3239     try:
-> 3240         ret = ax.plot(*args, **kwargs)
   3241     finally:
   3242         ax._hold = washold

~\Miniconda3\envs\work\lib\site-packages\matplotlib\__init__.py in inner(ax, *args, **kwargs)
   1708                     warnings.warn(msg % (label_namer, func.__name__),
   1709                                   RuntimeWarning, stacklevel=2)
-> 1710             return func(ax, *args, **kwargs)
   1711         pre_doc = inner.__doc__
   1712         if pre_doc is None:

~\Miniconda3\envs\work\lib\site-packages\matplotlib\axes\_axes.py in plot(self, *args, **kwargs)
   1435         kwargs = cbook.normalize_kwargs(kwargs, _alias_map)
   1436 
-> 1437         for line in self._get_lines(*args, **kwargs):
   1438             self.add_line(line)
   1439             lines.append(line)

~\Miniconda3\envs\work\lib\site-packages\matplotlib\axes\_base.py in _grab_next_args(self, *args, **kwargs)
    402                 this += args[0],
    403                 args = args[1:]
--> 404             for seg in self._plot_args(this, kwargs):
    405                 yield seg
    406 

~\Miniconda3\envs\work\lib\site-packages\matplotlib\axes\_base.py in _plot_args(self, tup, kwargs)
    382             x, y = index_of(tup[-1])
    383 
--> 384         x, y = self._xy_from_xy(x, y)
    385 
    386         if self.command == 'plot':

~\Miniconda3\envs\work\lib\site-packages\matplotlib\axes\_base.py in _xy_from_xy(self, x, y)
    241         if x.shape[0] != y.shape[0]:
    242             raise ValueError("x and y must have same first dimension, but "
--> 243                              "have shapes {} and {}".format(x.shape, y.shape))
    244         if x.ndim > 2 or y.ndim > 2:
    245             raise ValueError("x and y can be no greater than 2-D, but have "

ValueError: x and y must have same first dimension, but have shapes (24999,) and (1569128,)

In [ ]:
s1 = subtractContinuum(spec1, n=5)
s2 = subtractContinuum(spec2,n=5)
s3 = subtractContinuum(spec_3,n=5)



plt.plot(s1)
plt.plot(s2)
plt.plot(s3)
plt.show()

In [93]:
R = todcor(s3, s1, s2)

In [97]:
rv = np.arange(-50, 50, 0.5)
gamma = np.arange(-50, 50, 1)
RV, GAM = np.meshgrid(rv, gamma, indexing="ij")
c_n1_n2 = np.empty((len(rv), len(gamma)))
a_n1_n2 = np.empty((len(rv), len(gamma)))

for jj, y in enumerate(gamma):
    for ii, x in enumerate(rv):

        res = R(x, y) 
        c_n1_n2[ii, jj] = res[0]
        a_n1_n2[ii, jj] = res[1]

In [ ]:


In [109]:
plt.contourf(RV, GAM, c_n1_n2)

plt.title("correlation value")
plt.show()

a_n1_n2
plt.contourf(RV, GAM, a_n1_n2)
plt.title("flux ratio")
plt.show()



In [ ]:


In [ ]:


In [ ]: