The Dirichlet function calculation (`scipy.special.diric`) is numerically unstable near multiples of 2*pi. The existing SciPy code special-cases only exact multiples of 2*pi. This is demonstrated below: for odd n, the output sometimes falls substantially outside the expected range [-1, 1].


In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.special as specfun

# Record the SciPy release under test so the plots can be tied to it.
print("scipy version = {}".format(scipy.__version__))

# Grid over [-8*pi, 8*pi] with step 0.08*pi: every 25th sample lands (up to
# rounding) on a multiple of 2*pi, where the instability shows up.
x = np.linspace(-8*np.pi, 8*np.pi, 201)
for xtype in (np.float64,):
    x = x.astype(xtype)
    plt.figure(figsize=(8, 8))
    # One subplot per n, laid out row-major in a 2x2 grid.
    for panel, n in enumerate((2, 3, 4, 9), start=1):
        plt.subplot(2, 2, panel)
        plt.plot(x, specfun.diric(x, n))
        plt.title('diric, n={}'.format(n))


scipy version = 0.14.0

Zoom in on a region of numerical instability:


In [2]:
# Zoom in on x = 2*pi with the original scipy diric (n = 7): a +/-1e-12 window
# in double precision, then a wider +/-1e-4 window in single precision, whose
# resolution near 2*pi is far coarser.
for half_width, float_type in ((1e-12, np.float64), (1e-4, np.float32)):
    x = np.linspace(2*np.pi - half_width, 2*np.pi + half_width, 1001).astype(float_type)
    plt.figure()
    plt.plot(x, specfun.diric(x, 7))
    plt.axis('tight')


The existing `diric` also accepts array input for n, which makes odd cases such as the following possible (n varies per sample and is even NaN in places):


In [3]:
# Per-sample n: 3 for the first 101 samples, 7 afterwards, and NaN over
# samples 151-160, to show how the original diric handles array-valued n.
x = np.linspace(-8*np.pi, 8*np.pi, 201)
n = np.full_like(x, 7)
n[:101] = 3
n[151:161] = np.nan
plt.figure()
plt.plot(x, specfun.diric(x, n))


/home/lee8rx/anaconda/lib/python2.7/site-packages/scipy/special/basic.py:60: RuntimeWarning: invalid value encountered in less_equal
  mask1 = (n <= 0) | (n != floor(n))
Out[3]:
[<matplotlib.lines.Line2D at 0x7f95cfba2210>]

A modified version of `diric` is defined below. It fixes the instability and rejects any n that is not a single positive integer.


In [4]:
from numpy import (pi, asarray, floor, isscalar, sin, place, issubdtype, extract,
                   inexact, nan, zeros)

def new_diric(x, n):
    """Return the periodic sinc function, also called the Dirichlet function.

    The Dirichlet function is defined as::

        diric(x) = sin(x * n/2) / (n * sin(x / 2)),

    where n is a positive integer.

    Near x = 2*pi*k both the numerator and the denominator vanish, and the
    direct quotient loses precision.  Wherever ``|sin(x/2)|`` falls below a
    dtype-dependent threshold, the exact limit value ``(-1)**(k*(n-1))`` is
    returned instead.

    Parameters
    ----------
    x : array_like
        Input data.  Floating-point (inexact) dtypes are preserved; any other
        dtype is computed as Python ``float`` (double precision).
    n : int
        Positive integer defining the periodicity.  Must be a scalar.

    Returns
    -------
    diric : ndarray
        Array with the same shape as `x`.

    Raises
    ------
    ValueError
        If `n` is not a scalar, finite, positive integer.
    """
    # n = 0 would divide by zero below, and inf passes the floor() test
    # (floor(inf) == inf), so both must be rejected explicitly.
    if (not isscalar(n) or not np.isfinite(n) or n <= 0
            or n != floor(n)):
        raise ValueError("n must be a positive integer.")
    x = asarray(x)
    if issubdtype(x.dtype, inexact):
        ytype = x.dtype
    else:
        ytype = float
    y = zeros(x.shape, ytype)

    # Threshold on |sin(x/2)| below which the limit value is used: tight for
    # double precision (eps < 1e-15), much looser for lower precisions.
    if np.finfo(ytype).eps < 1e-15:
        minval = 1e-9
    else:
        minval = 1e-3

    x = x / 2.0
    denom = sin(x)

    # Close to x = 2*pi*k (i.e. x/2 close to pi*k): use the exact limit
    # (-1)**(k*(n-1)), with k recovered as round((x/2)/pi).
    tiny_denom = abs(denom) < minval
    k = np.round(extract(tiny_denom, x) / pi)
    place(y, tiny_denom, np.power(-1.0, k * (n - 1)))

    # Everywhere else the direct quotient is well conditioned.
    regular = np.logical_not(tiny_denom)
    xsub = extract(regular, x)
    dsub = extract(regular, denom)
    place(y, regular, sin(n * xsub) / (n * dsub))

    return y

In [5]:
x = np.linspace(-8*np.pi, 8*np.pi, 201)

# Exercise both double and single precision.  The sample grid must actually
# be cast to each dtype; otherwise both figures would show float64 results.
for dtype in [np.float64, np.float32]:
    x_typed = x.astype(dtype)
    plt.figure(figsize=(8, 8))
    plt.suptitle('x dtype = {}'.format(np.dtype(dtype).name))
    for idx, n in enumerate([2, 3, 5, 9]):
        plt.subplot(2, 2, idx + 1)
        plt.plot(x_typed, new_diric(x_typed, n))
        plt.title('diric, n={}'.format(n))



In [6]:
# Double precision: +/-1e-12 window around 2*pi, sampled at 10001 points.
x_double = np.linspace(2*np.pi - 1e-12, 2*np.pi + 1e-12, 10001)
plt.figure()
plt.plot(x_double, new_diric(x_double, 7))
plt.axis('tight')

# Single precision: +/-1e-4 window around 2*pi.  `x` is left holding this
# float32 grid, as later cells reuse it.
x = np.linspace(2*np.pi - 1e-4, 2*np.pi + 1e-4, 10001).astype(np.float32)
plt.figure()
plt.plot(x, new_diric(x, 7))
plt.axis('tight');



In [7]:
from numpy.testing import assert_raises

# Each of these n must be refused: non-integer, negative, and array-valued.
for bad_n in (0.1, -1, np.ones_like(x)):
    assert_raises(ValueError, new_diric, x, bad_n)

In [7]: