Example: reading a FITS file

Here we will read a FITS file using astropy, a Python package widely used by astronomers.

First we import the Python modules we need.


In [ ]:
from astropy.io import fits
import numpy as np
import math

Now we are ready to open the file. A FITS file consists of one or more header-data units (HDUs), each holding a header and, optionally, a data array.


In [ ]:
filename = 'test3_0652.fits'

The FITS file from the camera has an odd problem with its END card. Stack Overflow suggests passing the ignore_missing_end flag to ignore that error (even though the file does seem to contain an END card).


In [ ]:
%time hdu = fits.open(filename,ignore_missing_end=True)
print("We found ",len(hdu), "HDU")
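To make the idea of header-data units concrete, here is a minimal sketch that builds a two-HDU list in memory instead of reading the camera file. The array shapes and the extension name "DEMO" are made up for illustration; only `fits.PrimaryHDU`, `fits.ImageHDU` and `fits.HDUList` come from astropy.

```python
import numpy as np
from astropy.io import fits

# A primary HDU plus one image extension, built in memory (no file needed).
# Shapes and the name "DEMO" are arbitrary choices for this illustration.
primary = fits.PrimaryHDU(data=np.zeros((4, 5), dtype=np.float32))
extension = fits.ImageHDU(data=np.ones((2, 3), dtype=np.float32), name="DEMO")
hdul = fits.HDUList([primary, extension])

print("Number of HDUs:", len(hdul))
for unit in hdul:
    # every HDU carries its own header and its own data array
    print(unit.name, unit.data.shape)
```

An `HDUList` can be indexed like a Python list, which is why the notebook uses `hdu[0]` to get at the first (primary) unit.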

We now extract the header and the data. It turns out the data from the camera are 16-bit unsigned integers, which will quickly run into overflow problems, so we convert them to a float type.


In [ ]:
head = hdu[0].header
%time data = hdu[0].data.astype(np.float32)
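The overflow problem mentioned above is easy to reproduce without the camera file. The following sketch uses made-up pixel values to show how sums and differences of 16-bit unsigned arrays wrap around silently, and how the same arithmetic behaves after converting to float32.

```python
import numpy as np

# Made-up pixel values stored as 16-bit unsigned integers.
raw = np.array([60000, 62000, 100, 200], dtype=np.uint16)

# uint16 array arithmetic wraps modulo 65536 without raising an error.
wrapped_sum = raw[0:1] + raw[1:2]     # 122000 does not fit in 16 bits
wrapped_diff = raw[2:3] - raw[3:4]    # -100 cannot be represented

# The same operations after conversion to float32.
as_float = raw.astype(np.float32)
float_sum = as_float[0:1] + as_float[1:2]
float_diff = as_float[2:3] - as_float[3:4]

print("uint16 :", wrapped_sum[0], wrapped_diff[0])
print("float32:", float_sum[0], float_diff[0])
```

The wrapped difference matters for the bad-pixel test later on: comparing a pixel to a brighter neighbor in uint16 would give a huge positive number instead of a small negative one.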

In [ ]:
# head is a dictionary-like astropy Header object (not a plain python dict)
print(list(head.keys()))

In [ ]:
print(type(data))
print(data.shape)

In [ ]:
x=757
y=304
#   this is an  example bad pixel
print(filename,x,y,data[y-1,x-1],type(data[0,0]))
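The cell above indexes the array as `data[y-1,x-1]`. This is presumably because x and y are 1-based pixel coordinates as an image viewer would report them, while numpy arrays are 0-based and ordered as [row, column], that is [y, x]. A small sketch of that conversion, with a hypothetical helper `pixel_value` and a toy image:

```python
import numpy as np

def pixel_value(image, x, y):
    """Value at 1-based (x, y); hypothetical helper, for illustration only."""
    return image[y - 1, x - 1]

# 3 rows (y) by 4 columns (x); each value encodes its 1-based position as 10*y + x
image = np.array([[11, 12, 13, 14],
                  [21, 22, 23, 24],
                  [31, 32, 33, 34]], dtype=np.float32)

print(image.shape)                     # (ny, nx)
print(pixel_value(image, x=4, y=2))    # the pixel in column 4, row 2
```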

Now we will fix this bad pixel value by replacing it with the average of the 4 pixels directly below, above, to the right and to the left of it.


In [ ]:
newvalue = (data[y-2,x-1]  + data[y,x-1] + data[y-1,x-2]  +  data[y-1,x])/4.0
print(newvalue)
data[y-1,x-1] = newvalue
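The same 4-neighbor average can be wrapped in a small function so it is reusable for other pixels. This is only a sketch: `neighbor_mean4` is a hypothetical name, it takes 0-based indices, and it assumes the pixel is not on the edge of the image.

```python
import numpy as np

def neighbor_mean4(image, iy, ix):
    """Mean of the pixels above, below, left and right of 0-based (iy, ix).

    Hypothetical helper; assumes (iy, ix) is not on the image edge.
    """
    return (image[iy - 1, ix] + image[iy + 1, ix] +
            image[iy, ix - 1] + image[iy, ix + 1]) / 4.0

# Toy 3x3 image with an obviously bad value in the middle.
demo = np.array([[1, 2, 3],
                 [4, 999, 6],
                 [7, 8, 9]], dtype=np.float32)
demo[1, 1] = neighbor_mean4(demo, 1, 1)
print(demo[1, 1])
```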

Now we will write a function that loops over the whole 2-D data array and patches each pixel that deviates by at least eps from all four of its direct neighbors. Note that eps is compared against the absolute difference in pixel values, not a relative one.


In [ ]:
def patch_badpixels1(data, eps=0.1):
    """find bad pixels by comparing to some neighbor pixel values
    
        this algorithm is slow (5-10 seconds on a 1000x1300 image) because it loops
        over all pixels in python, and explicit looping in python is slow
        v2 with 4 neighbors took 6 sec (bad algorithm)
        v2 with improved 4 neighbors took 2.2 sec
        v2 with 8 neighbors took 9 sec (bad algorithm)
    """
    nx = data.shape[1]
    ny = data.shape[0]
    nbad = 0
    for ix in range(1,nx-1):
        for iy in range(1,ny-1):
            v1 = data[iy,ix]
            #v2 = (data[iy-1,ix] + data[iy+1,ix] + data[iy,ix-1] + data[iy,ix+1])/4.0
            #v2 =  (data[iy-1:iy+2,ix-1:ix+2].sum() - v1)/8.0            
            if abs(v1 - data[iy-1,ix])  < eps: continue
            if abs(v1 - data[iy+1,ix])  < eps: continue
            if abs(v1 - data[iy,ix-1])  < eps: continue
            if abs(v1 - data[iy,ix+1])  < eps: continue
            v2 = (data[iy-1,ix] + data[iy+1,ix] + data[iy,ix-1] + data[iy,ix+1])/4.0      # 4 point
            # note: the abs() tests above flag both positive and negative outliers
            nbad = nbad + 1
            print("Bad pixel",nbad,ix+1,iy+1,v1,v2)
            data[iy,ix] = v2
    return nbad

And now a slightly more elegant version with a helper function (good_pixel), which lets you collect any nearby pixels and replace the bad pixel with their average. This makes it easier to switch between 4-pixel and 8-pixel averages.


In [ ]:
def patch_badpixels2(data, eps=0.1):
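    def good_pixel(v1,v2,dat,eps):
        # v1: pixel under test, dat: one neighbor, v2: list of neighbors seen so far
        # signed difference (no abs), so only pixels brighter than the neighbor count as deviating
        if v1-dat < eps: return True
        v2.append(dat)
        return False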
    
    nx = data.shape[1]
    ny = data.shape[0]
    nbad = 0
    for ix in range(1,nx-1):
        for iy in range(1,ny-1):
            v1 = data[iy,ix]
            v2 = []
            # top/bottom/left/right
            if good_pixel(v1,v2,data[iy-1,ix],eps): continue
            if good_pixel(v1,v2,data[iy+1,ix],eps): continue
            if good_pixel(v1,v2,data[iy,ix-1],eps): continue
            if good_pixel(v1,v2,data[iy,ix+1],eps): continue
            # 4 corners; change True to False below to use only the 4 direct neighbors
            if True:
                if good_pixel(v1,v2,data[iy-1,ix-1],eps): continue
                if good_pixel(v1,v2,data[iy+1,ix-1],eps): continue
                if good_pixel(v1,v2,data[iy+1,ix+1],eps): continue
                if good_pixel(v1,v2,data[iy-1,ix+1],eps): continue
            v2 = np.array(v2).mean()
            nbad = nbad + 1
            print("Bad pixel",nbad,ix+1,iy+1,v1,v2)
            data[iy,ix] = v2
    return nbad

In [ ]:
data1 = data.copy()
%time nbad = patch_badpixels2(data1,5000)
print("Patched ",nbad)
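The docstring of patch_badpixels1 notes that explicit looping in Python is slow. One possible alternative, sketched here and not timed on the camera image, is to express the 4-neighbor test with numpy slices. The name `patch_badpixels_vec` and the synthetic test image are made up for illustration. The sketch assumes a float array and differs from the loop version in one respect: all comparisons and averages use the original pixel values, whereas the loop sees pixels it has already patched.

```python
import numpy as np

def patch_badpixels_vec(data, eps=0.1):
    """Vectorized sketch of the 4-neighbor test (hypothetical name).

    Interior pixels that differ by at least eps from all four direct
    neighbors are replaced, in place, by the mean of those neighbors.
    """
    center = data[1:-1, 1:-1]                       # view on the interior pixels
    up, down = data[:-2, 1:-1], data[2:, 1:-1]
    left, right = data[1:-1, :-2], data[1:-1, 2:]

    bad = ((np.abs(center - up) >= eps) & (np.abs(center - down) >= eps) &
           (np.abs(center - left) >= eps) & (np.abs(center - right) >= eps))
    mean4 = (up + down + left + right) / 4.0        # new array, computed before patching
    center[bad] = mean4[bad]                        # writes through the view into data
    return int(bad.sum())

# Synthetic image: smooth background plus one artificial hot pixel.
rng = np.random.default_rng(0)
image = rng.normal(1000.0, 5.0, size=(50, 60)).astype(np.float32)
image[20, 30] = 60000.0
nbad = patch_badpixels_vec(image, eps=5000)
print("Patched", nbad, "pixel(s); new value", image[20, 30])
```

Because the neighbors of the hot pixel still agree with their other neighbors, only the hot pixel itself is flagged, which matches the "deviates from all neighbors" logic of the loop version.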
