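The Dirichlet function computed by `scipy.special.diric` is numerically unstable near multiples of $2\pi$. The existing SciPy code special-cases only *exact* multiples of $2\pi$. Just off those points it evaluates $\sin(nx/2) / (n \sin(x/2))$ as a ratio of two tiny, rounding-dominated numbers. This is demonstrated below: for odd $n$, the output is sometimes substantially outside the expected range $[-1, 1]$.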
In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.special as specfun
print("scipy version = {}".format(scipy.__version__))
x = np.linspace(-8*np.pi,8*np.pi,201)
for xtype in [np.float64,]:
x = x.astype(xtype)
plt.figure(figsize=(8,8));
for idx,n in enumerate([2,3,4,9]):
plt.subplot(2,2,idx+1)
plt.plot(x, specfun.diric(x, n))
plt.title('diric, n={}'.format(n))
Zoom in around $x = 2\pi$, a region of numerical instability (double precision first, then single precision):
In [2]:
# Zoom in on x = 2*pi, where sin(x/2) -> 0. The float32 window is much wider
# because float32 spacing near 2*pi (~5e-7) would collapse a 1e-12 window.
for dtype, half_width in [(np.float64, 1e-12), (np.float32, 1e-4)]:
    x_zoom = np.linspace(2*np.pi - half_width, 2*np.pi + half_width, 1001).astype(dtype)
    fig, ax = plt.subplots()
    ax.plot(x_zoom, specfun.diric(x_zoom, 7))
    ax.set(title='scipy diric, n=7, {} near 2*pi'.format(np.dtype(dtype).name), xlabel='x')
    ax.axis('tight')
In [3]:
# scipy's diric also accepts an array of n: here n = 3 on the left half, n = 7
# on the right half, and NaN in a short window to see how scipy treats an
# invalid n.
x = np.linspace(-8*np.pi, 8*np.pi, 201)
n = 7*np.ones_like(x)
n[:101] = 3
n[151:161] = np.nan
fig, ax = plt.subplots()
ax.plot(x, specfun.diric(x, n))
ax.set(title='scipy diric with array-valued n', xlabel='x');
Out[3]:
In [4]:
from numpy import (pi, asarray, floor, isscalar, sin, place, issubdtype, extract,
inexact, nan, zeros)
def new_diric(x, n):
    """Return the periodic sinc function, also called the Dirichlet function.

    The Dirichlet function is defined as::

        diric(x, n) = sin(x * n/2) / (n * sin(x / 2)),

    where n is a positive integer.

    Near multiples of 2*pi both numerator and denominator tend to zero and the
    direct ratio loses accuracy. Wherever ``|sin(x/2)|`` falls below a
    dtype-dependent threshold, the limiting value ``(-1)**(k*(n-1))`` for
    ``x ~= 2*pi*k`` is returned instead.

    Parameters
    ----------
    x : array_like
        Input data. Inexact (floating-point or complex) inputs keep their
        dtype; all other inputs are evaluated as ``float``.
    n : int
        Positive integer defining the periodicity.

    Returns
    -------
    diric : ndarray
        Array with the same shape as `x`.

    Raises
    ------
    ValueError
        If `n` is not a finite, positive, integer-valued scalar.
    """
    # n == 0 would give 0/0 = nan away from the poles, and inf passes the
    # floor() test but yields nan from sin(inf); reject both up front.
    if not isscalar(n) or not np.isfinite(n) or n <= 0 or n != floor(n):
        raise ValueError("n must be a positive integer.")

    x = asarray(x)
    ytype = x.dtype if issubdtype(x.dtype, inexact) else float
    y = zeros(x.shape, ytype)

    # Empirical threshold below which sin(x/2) is treated as zero: tight for
    # double (or wider) precision, much looser for single/half precision.
    if np.finfo(ytype).eps < 1e-15:
        minval = 1e-9
    else:
        minval = 1e-3

    half_x = x / 2.0
    denom = sin(half_x)
    near_pole = abs(denom) < minval

    # At x = 2*pi*k the limit is (-1)**(k*(n-1)); k is recovered by rounding
    # (x/2)/pi to the nearest integer.
    k = np.round(extract(near_pole, half_x) / pi)
    place(y, near_pole, pow(-1, k * (n - 1)))

    # Everywhere else the direct formula is well conditioned.
    regular = ~near_pole
    place(y, regular,
          sin(n * extract(regular, half_x)) / (n * extract(regular, denom)))
    return y
In [5]:
# Same grid as the scipy demonstration, evaluated with new_diric in both
# double and single precision.
x = np.linspace(-8*np.pi, 8*np.pi, 201)
for dtype in [np.float64, np.float32]:
    # Cast explicitly; without this every figure would be computed in float64.
    x_typed = x.astype(dtype)
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    for ax, n in zip(axes.flat, [2, 3, 5, 9]):
        y_typed = new_diric(x_typed, n)
        # The whole point of new_diric: no excursions outside [-1, 1].
        assert np.all(np.abs(y_typed) <= 1 + 4*np.finfo(dtype).eps), (dtype, n)
        ax.plot(x_typed, y_typed)
        for bound in (-1, 1):
            ax.axhline(bound, color='gray', linestyle='--', linewidth=0.5)
        ax.set(title='diric, n={} ({})'.format(n, np.dtype(dtype).name), xlabel='x')
In [6]:
# Zoom in on x = 2*pi with new_diric. The float32 window is much wider because
# float32 spacing near 2*pi (~5e-7) would collapse a 1e-12 window.
for dtype, half_width in [(np.float64, 1e-12), (np.float32, 1e-4)]:
    x_zoom = np.linspace(2*np.pi - half_width, 2*np.pi + half_width, 10001).astype(dtype)
    y_zoom = new_diric(x_zoom, 7)
    assert np.all(np.abs(y_zoom) <= 1), dtype
    fig, ax = plt.subplots()
    ax.plot(x_zoom, y_zoom)
    ax.set(title='new_diric, n=7, {} near 2*pi'.format(np.dtype(dtype).name), xlabel='x')
    ax.axis('tight')
In [7]:
from numpy.testing import assert_raises

# n must be a scalar integer. Use a dedicated probe array so this cell does
# not depend on whichever ``x`` an earlier cell left in the namespace.
x_probe = np.linspace(0, 2*np.pi, 5)
assert_raises(ValueError, new_diric, x_probe, 0.1)                    # non-integer n
assert_raises(ValueError, new_diric, x_probe, -1)                     # negative n
assert_raises(ValueError, new_diric, x_probe, np.ones_like(x_probe))  # array-valued n
In [7]: