"""Source code for spaceKLIP.imagetools: the spaceKLIP image manipulation tools."""
from __future__ import division
import matplotlib
# =============================================================================
# IMPORTS
# =============================================================================
# general imports
import os
import re
import logging
import sys
import json
import types
import copy
import lmfit
import astrofix
import numpy as np
from copy import deepcopy
from tqdm.auto import trange
# astropy imports
import pysiaf
import astropy.stats
import astropy.io.fits as pyfits
from astropy.io import fits
# plotting imports
import matplotlib.pyplot as plt
import matplotlib as mpl
from skimage.registration import phase_cross_correlation
from sklearn.cluster import KMeans
# scipy imports
import scipy.ndimage
from scipy.ndimage import gaussian_filter, median_filter
from scipy.ndimage import shift as spline_shift
from scipy.ndimage import label, convolve, binary_dilation
from scipy.optimize import leastsq, minimize
from scipy.interpolate import griddata
# webbpsf_ext imports
import webbpsf_ext
from webbpsf_ext import robust
from webbpsf_ext.coords import dist_image
from webbpsf_ext.webbpsf_ext_core import _transmission_map
from stpsf.constants import JWST_CIRCUMSCRIBED_DIAMETER
# spaceKLIP imports
from spaceKLIP import utils as ut
from spaceKLIP.psf import JWST_PSF
from spaceKLIP.xara import core
from spaceKLIP.utils import gaussian_kernel
from spaceKLIP.psf import get_offsetpsf
from spaceKLIP.pyklippipeline import get_pyklip_filepaths
from spaceKLIP import mcmc_tools
from spaceKLIP.target_acq_tools import ta_analysis
from spaceKLIP.starphot import get_stellar_magnitudes, read_spec_file
from spaceKLIP.plotting import load_plt_style
# pyklip imports
import pyklip.fakes as fakes
from pyklip import parallelized
from pyklip.instruments.JWST import JWSTData
# jwst imports
import jwst.datamodels
from jwst.datamodels import dqflags
from stdatamodels.jwst import datamodels
from jwst.datamodels import ModelContainer, ModelLibrary
from jwst.resample import resample_step
# Set up log.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# =============================================================================
# MAIN
# =============================================================================
# Load NIRCam true mask centers and filter-dependent shifts from Jarron.
# The JSON resource files live in the ``resources`` directory next to this
# module. Context managers guarantee the files are closed even if parsing
# fails; JSON is UTF-8 by specification, so decode it explicitly as such.
path = 'resources/crpix_jarron.json'
path = os.path.join(os.path.split(os.path.abspath(__file__))[0], path)
with open(path, 'r', encoding='utf-8') as file:
    crpix_jarron = json.load(file)
path = 'resources/filter_shifts_jarron.json'
path = os.path.join(os.path.split(os.path.abspath(__file__))[0], path)
with open(path, 'r', encoding='utf-8') as file:
    filter_shifts_jarron = json.load(file)
# [docs]
class ImageTools():
"""
The spaceKLIP image manipulation tools class.
"""
def __init__(self,
             database):
    """
    Set up the spaceKLIP image manipulation tools.

    Parameters
    ----------
    database : spaceKLIP.Database
        The spaceKLIP database whose observations the image manipulation
        steps will operate on.

    Returns
    -------
    None.
    """
    # Store a reference (not a copy) so every step reads and updates the
    # caller's database object directly.
    self.database = database
def _get_output_dir(self,
                    subdir):
    """
    Return the full path of an output subdirectory, creating it if needed.

    Parameters
    ----------
    subdir : str
        Name of the subdirectory inside ``self.database.output_dir``.

    Returns
    -------
    output_dir : str
        ``os.path.join(self.database.output_dir, subdir)``, which is
        guaranteed to exist as a directory on return.
    """
    output_dir = os.path.join(self.database.output_dir, subdir)
    # exist_ok=True avoids the race between an existence check and the
    # directory creation (e.g. when several processes share an output dir).
    os.makedirs(output_dir, exist_ok=True)
    return output_dir
def _iterate_function_over_files(self,
                                 types,
                                 file_transformation_function,
                                 restrict_to=None):
    """
    Apply a per-file transformation to every matching file in the database.

    Many image processing steps share this loop, so it is factored out
    here. For every concatenation in ``self.database.obs`` (optionally only
    those whose key contains ``restrict_to``), each file whose ``TYPE`` is
    in ``types`` is passed to ``file_transformation_function``, and the
    database entry is updated with the filename that function returns.

    Parameters
    ----------
    types : list of str
        File types (values of the ``TYPE`` column) to process. Files of
        other types are skipped and their database entries left unchanged.
    file_transformation_function : callable
        Called as ``file_transformation_function(filename)``. It should
        perform the transformation, write the result to a new path, and
        return that output filename. Bind any other arguments beforehand,
        e.g. with ``functools.partial``.
    restrict_to : str, optional
        If given, only concatenations whose key contains this substring are
        processed. The default is None (process every concatenation).

    Returns
    -------
    None.
    """
    for concat_key in self.database.obs.keys():
        # Optionally limit processing to concatenations matching a pattern.
        if restrict_to is not None and restrict_to not in concat_key:
            continue
        log.info('--> Concatenation ' + concat_key)
        n_files = len(self.database.obs[concat_key])
        for row in range(n_files):
            source_file = self.database.obs[concat_key]['FITSFILE'][row]
            if self.database.obs[concat_key]['TYPE'][row] not in types:
                continue
            result_file = file_transformation_function(source_file)
            # Point the database entry at the newly written file.
            self.database.update_obs(concat_key, row, result_file)
# [docs]
def remove_frames(self,
                  index=[0],
                  types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                  subdir='removed'):
    """
    Remove individual frames from the data.

    Every file in the database is re-written to ``subdir`` (together with
    its PSF mask and NaN mask) and the database is updated to point to the
    new files; frames are only removed from files whose type is listed in
    ``types``.

    Parameters
    ----------
    index : int or list of int or dict of list of list of int, optional
        Indices (0-indexed) of the frames to be removed. If int, then only
        a single frame will be removed from each observation. If list of
        int, then multiple frames can be removed from each observation. If
        dict of list of list of int, then the dictionary keys must match
        the keys of the observations database, and the number of entries in
        the lists must match the number of observations in the
        corresponding concatenation. Then, a different list of int can be
        used for each individual observation to remove different frames.
        The default is [0].
    types : list of str, optional
        List of data types from which the frames shall be removed. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'removed'.

    Returns
    -------
    None.

    Raises
    ------
    KeyError
        If ``index`` is a dict without an entry for a concatenation that
        contains files to be processed.
    IndexError
        If a per-concatenation list in a dict ``index`` has fewer entries
        than the concatenation has files.
    """
    # Check input. A single (Python or NumPy) integer removes one frame.
    if isinstance(index, (int, np.integer)):
        index = [index]

    # Set output directory.
    output_dir = self._get_output_dir(subdir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            nints = self.database.obs[key]['NINTS'][j]

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:

                # Remove frames.
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame removal: ' + tail)
                # A dict gives one list of indices per file of each
                # concatenation; any other sequence applies to all files.
                # np.delete does not modify its index argument, so no copy
                # is needed.
                if isinstance(index, dict):
                    index_temp = index[key][j]
                else:
                    index_temp = index
                log.info(' --> Frame removal: removing frame(s) ' + str(index_temp))
                data = np.delete(data, index_temp, axis=0)
                erro = np.delete(erro, index_temp, axis=0)
                pxdq = np.delete(pxdq, index_temp, axis=0)
                if align_shift is not None:
                    align_shift = np.delete(align_shift, index_temp, axis=0)
                if center_shift is not None:
                    center_shift = np.delete(center_shift, index_temp, axis=0)
                # align_mask and center_mask hold a single value per file,
                # so they are left untouched. This should be revisited once
                # the PSF mask is handled at the integration level.
                if maskoffs is not None:
                    maskoffs = np.delete(maskoffs, index_temp, axis=0)
                nints = data.shape[0]

            # Write FITS file, PSF mask and NaN mask.
            head_pri['NINTS'] = nints
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nints=nints, nanmaskfile=nanmaskfile)
# [docs]
def crop_frames(self,
                npix=1,
                types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                subdir='cropped'):
    """
    Crop all frames.

    Every file in the database is re-written to ``subdir`` and the database
    is updated to point to the new files. Only files whose type is listed in
    ``types`` are cropped. For those files, CRPIX1/2, STARCENX/Y and
    MASKCENX/Y are shifted by the number of pixels removed from the left and
    bottom. That shift is also recorded as CROP_SHIFTX/Y in the SCI header
    and in the database.

    Parameters
    ----------
    npix : int or list of four int, optional
        Number of pixels to be cropped from the frames. If int, the same
        number of pixels will be cropped on each side. If list of four int,
        a different number of pixels can be cropped from the [left, right,
        bottom, top] of the frames; 0 leaves that side untouched. The
        default is 1.
    types : list of str, optional
        List of data types from which the frames shall be cropped. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'cropped'.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If ``npix`` is neither an int nor a list of four int.
    """
    # Check input.
    if isinstance(npix, int):
        npix = [npix, npix, npix, npix]  # left, right, bottom, top
    if len(npix) != 4:
        raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)')
    left, right, bottom, top = npix

    def crop(arr):
        """Crop the last two (y, x) axes of ``arr`` by npix."""
        ny, nx = arr.shape[-2:]
        # Explicit stop indices: a negative stop of -0 would select nothing
        # whenever 0 pixels are to be cropped on the right/top.
        return arr[..., bottom:max(ny - top, 0), left:max(nx - right, 0)]

    # Set output directory.
    output_dir = self._get_output_dir(subdir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            crpix1 = self.database.obs[key]['CRPIX1'][j]
            crpix2 = self.database.obs[key]['CRPIX2'][j]
            starcenx = self.database.obs[key]['STARCENX'][j]
            starceny = self.database.obs[key]['STARCENY'][j]
            maskcenx = self.database.obs[key]['MASKCENX'][j]
            maskceny = self.database.obs[key]['MASKCENY'][j]

            # Crop shift of this file; stays empty for files that are not
            # cropped so that no (stale or undefined) shift is recorded.
            crop_shift = {}

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:

                # Crop frames.
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame cropping: ' + tail)
                sh = data.shape
                data = crop(data)
                erro = crop(erro)
                pxdq = crop(pxdq)
                if mask is not None:
                    mask = crop(mask)
                if nanmask is not None:
                    nanmask = crop(nanmask)
                crop_shiftx = left
                crop_shifty = bottom
                crpix1 -= crop_shiftx
                crpix2 -= crop_shifty
                starcenx -= crop_shiftx
                starceny -= crop_shifty
                maskcenx -= crop_shiftx
                maskceny -= crop_shifty
                crop_shift = {'crop_shiftx': crop_shiftx, 'crop_shifty': crop_shifty}
                log.info(' --> Frame cropping: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:]))

            # Write FITS file, PSF mask and NaN mask.
            head_sci['CRPIX1'] = crpix1
            head_sci['CRPIX2'] = crpix2
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            head_sci['MASKCENX'] = maskcenx
            head_sci['MASKCENY'] = maskceny
            if crop_shift:
                # Store crop shift.
                head_sci['CROP_SHIFTX'] = crop_shift['crop_shiftx']
                head_sci['CROP_SHIFTY'] = crop_shift['crop_shifty']
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile,
                                     crpix1=crpix1, crpix2=crpix2,
                                     starcenx=starcenx, starceny=starceny,
                                     maskcenx=maskcenx, maskceny=maskceny,
                                     nanmaskfile=nanmaskfile, **crop_shift)
# [docs]
def pad_frames(self,
               npix=1,
               tshape=None,
               cval=np.nan,
               types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
               subdir='padded'):
    """
    Pad all frames.

    Every file in the database is re-written to ``subdir`` and the database
    is updated to point to the new files. Only files whose type is listed in
    ``types`` are padded. For those files, CRPIX1/2, STARCENX/Y and
    MASKCENX/Y are shifted by the number of pixels added on the left and
    bottom.

    Parameters
    ----------
    npix : int or list of four int, optional
        Number of pixels to be padded around the frames. If int, the same
        number of pixels will be padded on each side. If list of four int,
        a different number of pixels can be padded on the [left, right,
        bottom, top] of the frames. Ignored if ``tshape`` is given. The
        default is 1.
    tshape : int or list of two int, optional
        Target frame shape in pixels. If int, frames are padded to a square
        of that size. If list of two int, frames are padded to shape
        [y, x]. On each axis the padding is split evenly between both
        sides, with the extra pixel of an odd difference going to the
        right/top. If None, ``npix`` is used instead. The default is None.
    cval : float, optional
        Fill value for the padded pixels. The default is nan.
    types : list of str, optional
        List of data types from which the frames shall be padded. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'padded'.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If ``tshape`` or ``npix`` does not have the expected length.
    ValueError
        If ``tshape`` is smaller than a frame that is to be padded.
    """
    # Check input.
    if tshape is not None:
        if isinstance(tshape, int):
            tshape = [tshape, tshape]  # y, x
        if len(tshape) != 2:
            raise UserWarning('Parameter tshape must either be an int or a list of 2 int (y,x)')
    else:
        if isinstance(npix, int):
            npix = [npix, npix, npix, npix]  # left, right, bottom, top
        if len(npix) != 4:
            raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)')

    # Set output directory.
    output_dir = self._get_output_dir(subdir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            crpix1 = self.database.obs[key]['CRPIX1'][j]
            crpix2 = self.database.obs[key]['CRPIX2'][j]
            starcenx = self.database.obs[key]['STARCENX'][j]
            starceny = self.database.obs[key]['STARCENY'][j]
            maskcenx = self.database.obs[key]['MASKCENX'][j]
            maskceny = self.database.obs[key]['MASKCENY'][j]

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:

                # Pad frames.
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame padding: ' + tail)
                sh = data.shape
                if tshape is not None:
                    # Derive [left, right, bottom, top] from the target
                    # shape; an odd remainder goes to the right/top.
                    dy = tshape[0] - sh[1]
                    dx = tshape[1] - sh[2]
                    if dy < 0 or dx < 0:
                        raise ValueError('Target shape ' + str(tuple(tshape)) + ' is smaller than frame shape ' + str(sh[1:]) + ' of ' + tail)
                    npix = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
                pad_yx = ((npix[2], npix[3]), (npix[0], npix[1]))
                data = np.pad(data, ((0, 0),) + pad_yx, mode='constant', constant_values=cval)
                erro = np.pad(erro, ((0, 0),) + pad_yx, mode='constant', constant_values=cval)
                pxdq = np.pad(pxdq, ((0, 0),) + pad_yx, mode='constant', constant_values=0)
                if mask is not None:
                    mask = np.pad(mask, pad_yx, mode='constant', constant_values=np.nan)
                if nanmask is not None:
                    nanmask = np.pad(nanmask, pad_yx, mode='constant', constant_values=1)
                crpix1 += npix[0]
                crpix2 += npix[2]
                starcenx += npix[0]
                starceny += npix[2]
                maskcenx += npix[0]
                maskceny += npix[2]
                log.info(' --> Frame padding: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:]) + ', fill value = %.2f' % cval)

            # Write FITS file, PSF mask and NaN mask.
            head_sci['CRPIX1'] = crpix1
            head_sci['CRPIX2'] = crpix2
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            head_sci['MASKCENX'] = maskcenx
            head_sci['MASKCENY'] = maskceny
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny, nanmaskfile=nanmaskfile)
# NOTE: a first definition of ``mask_NDsquares`` used to be here. A class
# body keeps only the last definition of a name, so that copy was shadowed
# by the ``mask_NDsquares`` defined further down and could never be called.
# The dead duplicate has been removed; the later definition is the live one.
# [docs]
def mask_NDsquares(self,
                   npix=1,
                   cval=np.nan,
                   minval=0.1,
                   types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                   subdir='ndmasked'):
    """
    Mask the ND squares in the frames by setting the pixel values to NaN.

    For each file whose type is listed in ``types``, ND-square pixels are
    pixels that meet both of these conditions:

    - the PSF mask value is below ``minval``;
    - the column lies left of (first finite column + 30) or right of
      (last finite column - 50), where finite columns are those containing
      at least one finite pixel in the first frame.

    This selection is grown by ``npix`` pixels in every direction and set to
    ``cval`` in all frames. Every file in the database, processed or not, is
    re-written to ``subdir`` together with its PSF mask and NaN mask.

    Parameters
    ----------
    npix : int, optional
        Number of pixels by which the ND square masks are grown in every
        direction. The default is 1.
    cval : float, optional
        Fill value for the masked pixels. The default is nan.
    minval : float, optional
        PSF mask values below this threshold are considered part of an ND
        square. The default is 0.1.
    types : list of str, optional
        List of data types in which the ND squares shall be masked. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'ndmasked'.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If a file to be processed has no PSF mask.
    """
    def dilate_squares(mask, n):
        """
        Grow a boolean (or 0/1) 2D mask by ``n`` pixels in every direction,
        diagonals included, and return it with the input dtype.
        """
        if n <= 0:
            return mask.copy()
        struct = np.ones((2 * n + 1, 2 * n + 1), dtype=bool)
        out = scipy.ndimage.binary_dilation(mask.astype(bool), structure=struct)
        return out.astype(mask.dtype)

    # Column margins (pixels), measured inward from the first/last column
    # with finite data, of the region searched for ND squares.
    left_margin = 30
    right_margin = 50

    # Set output directory.
    output_dir = self._get_output_dir(subdir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            crpix1 = self.database.obs[key]['CRPIX1'][j]
            crpix2 = self.database.obs[key]['CRPIX2'][j]
            starcenx = self.database.obs[key]['STARCENX'][j]
            starceny = self.database.obs[key]['STARCENY'][j]
            maskcenx = self.database.obs[key]['MASKCENX'][j]
            maskceny = self.database.obs[key]['MASKCENY'][j]

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame ND square masking: ' + tail)
                if mask is None:
                    raise UserWarning('Cannot mask ND squares without a PSF mask: ' + tail)
                # Columns containing at least one finite pixel in frame 0.
                finite_cols = np.where(np.isfinite(data[0]).any(axis=0))[0]
                x_first, x_last = finite_cols.min(), finite_cols.max()
                xx = np.arange(mask.shape[1])[np.newaxis, :]
                # Only mask where the PSF mask indicates low values.
                NDmask = (mask < minval) & ((xx < x_first + left_margin) | (xx > x_last - right_margin))
                # Apply to every frame (assumes data is (n_frames, ny, nx)).
                data[:, dilate_squares(NDmask, n=npix)] = cval

            # Write new FITS file, PSF mask and NaN mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx,
                                     starceny=starceny, maskcenx=maskcenx, maskceny=maskceny, nanmaskfile=nanmaskfile)
# [docs]
def coadd_frames(self,
                 nframes=None,
                 types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                 subdir='coadded'):
    """
    Coadd frames.

    Frames are combined with a NaN-ignoring median. Uncertainties are
    combined as sqrt(nansum(err**2)) / N, where N is the number of non-NaN
    samples, and DQ flags are OR-ed together. Only the first
    ``nframes * ncoadds`` frames are used, where ``ncoadds`` is the number
    of frames divided (integer division) by ``nframes``. Input frame ``k``
    contributes to output frame ``k % ncoadds``. Every file in the database
    is re-written to ``subdir`` with its PSF mask and NaN mask. Only files
    whose type is listed in ``types`` are coadded.

    Parameters
    ----------
    nframes : int, optional
        Number of frames to be coadded. Modulo frames will be removed. If
        None, will coadd all frames in an observation. The default is None.
    types : list of str, optional
        List of data types from which the frames shall be coadded. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'coadded'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If a file to be coadded has fewer than ``nframes`` frames.
    """
    # Set output directory.
    output_dir = self._get_output_dir(subdir)

    # The requested number of frames per coadd (None = all frames).
    nframes0 = nframes

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            nints = self.database.obs[key]['NINTS'][j]
            effinttm = self.database.obs[key]['EFFINTTM'][j]

            # If nframes is not provided, collapse everything.
            if nframes0 is None:
                nframes = nints

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:

                # Coadd frames.
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame coadding: ' + tail)
                ncoadds = data.shape[0] // nframes
                if ncoadds == 0:
                    raise ValueError('Cannot coadd %d frame(s): %s only has %d frame(s).' % (nframes, tail, data.shape[0]))
                # Arrange the kept frames as (nframes, ncoadds, ny, nx) and
                # combine along axis 0, i.e. frames k, k + ncoadds, ...
                nkeep = nframes * ncoadds
                data = np.nanmedian(data[:nkeep].reshape((nframes, ncoadds, data.shape[-2], data.shape[-1])), axis=0)
                erro_reshape = erro[:nkeep].reshape((nframes, ncoadds, erro.shape[-2], erro.shape[-1]))
                nsample = np.sum(~np.isnan(erro_reshape), axis=0)
                erro = np.true_divide(np.sqrt(np.nansum(erro_reshape**2, axis=0)), nsample)
                pxdq = np.bitwise_or.reduce(pxdq[:nkeep].reshape((nframes, ncoadds, pxdq.shape[-2], pxdq.shape[-1])), axis=0)
                if align_shift is not None:
                    align_shift = np.mean(align_shift[:nkeep].reshape((nframes, ncoadds, align_shift.shape[-1])), axis=0)
                if center_shift is not None:
                    center_shift = np.mean(center_shift[:nkeep].reshape((nframes, ncoadds, center_shift.shape[-1])), axis=0)
                # align_mask and center_mask hold a single value per file,
                # so they are left untouched. This should be revisited once
                # the PSF mask is handled at the integration level.
                if maskoffs is not None:
                    maskoffs = np.mean(maskoffs[:nkeep].reshape((nframes, ncoadds, maskoffs.shape[-1])), axis=0)
                nints = data.shape[0]
                effinttm *= nframes
                log.info(' --> Frame coadding: %.0f coadd(s) of %.0f frame(s)' % (ncoadds, nframes))

            # Write FITS file, PSF mask and NaN mask.
            head_pri['NINTS'] = nints
            head_pri['EFFINTTM'] = effinttm
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            # The NaN mask is copied from its own file (not from the science file).
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nints=nints, effinttm=effinttm, nanmaskfile=nanmaskfile)
# [docs]
def subtract_median(self,
                    types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'],
                    method='border',
                    sigma=3.0,
                    borderwidth=32,
                    subdir='medsub'):
    """
    Subtract the median background from each frame.

    Pixels flagged DO_NOT_USE in the DQ array are ignored when estimating
    the background. Every file in the database is re-written to ``subdir``
    (with its PSF mask and NaN mask); the background is only subtracted
    from files whose type is listed in ``types``.

    Parameters
    ----------
    types : list of str, optional
        List of data types for which the median shall be subtracted. The
        default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'].
    method : str, optional
        - 'robust' for a median computed after clipping pixels more than
          5 x ``robust.medabsdev`` above the frame median (e.g. bright
          stars),
        - 'sigma_clipped' for another version of robust median using astropy
          sigma_clipped_stats on the whole image,
        - 'border' for a sigma-clipped median of the outer border region
          only, to ignore the bright stellar PSF in the center,
        - 'simple' for a simple np.nanmedian.
        Any other value falls back to 'simple' and logs a warning. The
        default is 'border'.
    sigma : float, optional
        Number of standard deviations to use for the clipping limit in
        sigma_clipped_stats, for the 'sigma_clipped' and 'border' methods.
        The default is 3.0.
    borderwidth : int, optional
        Width in pixels of the outer border region used by the 'border'
        method. Default is to use the outermost 32 pixels around all sides
        of the image.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'medsub'.

    Returns
    -------
    None.

    Raises
    ------
    NotImplementedError
        If the 'sigma_clipped' or 'border' method is used on data that is
        neither 2D nor 3D.
    """
    # Set output directory.
    output_dir = self._get_output_dir(subdir)
    log.info(f'Median subtraction using method={method}')
    if method not in ('robust', 'sigma_clipped', 'border', 'simple'):
        log.warning(f"Unknown median subtraction method '{method}', falling back to 'simple'")

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):

            # Read FITS file, PSF mask and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)
            pxmask_donotuse = ut.get_dqmask(pxdq, 'DO_NOT_USE', return_bool=True)

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:

                # Subtract median.
                head, tail = os.path.split(fitsfile)
                log.info(' --> Median subtraction: ' + tail)
                # Copy in which DO_NOT_USE pixels are NaN, so that every
                # background estimator below ignores them.
                data_temp = data.copy()
                data_temp[pxmask_donotuse] = np.nan
                if method == 'robust':
                    # Robust median, using a method by Jens.
                    bg_med = np.nanmedian(data_temp, axis=(1, 2), keepdims=True)
                    bg_std = robust.medabsdev(data_temp, axis=(1, 2), keepdims=True)
                    bg_ind = data_temp > (bg_med + 5. * bg_std)  # clip bright PSFs for final calculation
                    data_temp[bg_ind] = np.nan
                    bg_median = np.nanmedian(data_temp, axis=(1, 2), keepdims=True)
                elif method == 'sigma_clipped':
                    # Robust median using astropy.stats.sigma_clipped_stats.
                    if data.ndim == 2:
                        mean, bg_median, stddev = astropy.stats.sigma_clipped_stats(data_temp, sigma=sigma)
                    elif data.ndim == 3:
                        bg_median = np.zeros([data.shape[0], 1, 1])
                        for iint in range(data.shape[0]):
                            mean_i, median_i, stddev_i = astropy.stats.sigma_clipped_stats(data_temp[iint], sigma=sigma)
                            bg_median[iint] = median_i
                    else:
                        raise NotImplementedError("data must be 2d or 3d for this method")
                elif method == 'border':
                    # Use only the outer border region of the image, near
                    # the edges of the FOV.
                    if data.ndim not in (2, 3):
                        raise NotImplementedError("data must be 2d or 3d for this method")
                    ny, nx = data.shape[-2:]
                    y, x = np.indices((ny, nx))
                    # Exactly borderwidth rows/columns on every side.
                    bordermask = (x < borderwidth) | (x >= nx - borderwidth) | (y < borderwidth) | (y >= ny - borderwidth)
                    if data.ndim == 2:
                        mean, bg_median, stddev = astropy.stats.sigma_clipped_stats(data_temp[bordermask], sigma=sigma)
                    else:
                        # Robust stats on the border region of each integration.
                        bg_median = np.zeros([data.shape[0], 1, 1])
                        for iint in range(data.shape[0]):
                            mean_i, median_i, stddev_i = astropy.stats.sigma_clipped_stats(data_temp[iint][bordermask], sigma=sigma)
                            bg_median[iint] = median_i
                else:
                    # Plain vanilla median of the image.
                    bg_median = np.nanmedian(data_temp, axis=(1, 2), keepdims=True)
                data -= bg_median
                log.info(' --> Median subtraction: mean of frame median = %.2f' % np.mean(bg_median))

            # Write FITS file, PSF mask and NaN mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            # The NaN mask is copied from its own file (not from the science file).
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nanmaskfile=nanmaskfile)
# [docs]
def subtract_background_godoy(self,
                              types=('SCI', 'REF'),
                              subdir='bgsub'):
    """
    Subtract the corresponding background observations from the SCI and REF
    data in the spaceKLIP database using a method developed by Nico Godoy.

    For each requested file type, all matching background files (SCI_BG for
    SCI, REF_BG for REF) of a concatenation are median-combined and
    subtracted from every file. Then, for each integration, the
    median-subtracted background frame is rescaled by a factor fitted with
    ``scipy.optimize.minimize`` on ``ut.bg_minimize`` (using a
    filter-specific background mask) and subtracted from the
    median-subtracted data. Finally, the median of the residual frame is
    removed.

    Parameters
    ----------
    types : list of str, optional
        File types to run the subtraction over. Only 'SCI' and 'REF' are
        supported. The default is ('SCI', 'REF').
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bgsub'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If ``types`` contains a file type other than 'SCI' or 'REF'.
    UserWarning
        If no background files exist for a requested file type.
    NotImplementedError
        If a file uses a filter other than F1065C, F1140C or F1550C, for
        which no Godoy background mask is available.
    """
    # Background file type that belongs to each supported file type.
    bg_types = {'SCI': 'SCI_BG', 'REF': 'REF_BG'}
    # Validate up front: an unknown type would otherwise leave the file
    # indices undefined or silently reuse those of the previous type.
    unsupported = [typ for typ in types if typ not in bg_types]
    if unsupported:
        raise ValueError('Godoy background subtraction only supports the file '
                         'types {}, got {}'.format(list(bg_types), unsupported))
    supported_filters = ['F1065C', 'F1140C', 'F1550C']

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Directory that contains the Godoy background masks.
    bgmaskbase = os.path.split(os.path.abspath(__file__))[0]

    # Loop through concatenations.
    for key in self.database.obs.keys():
        # File indices per type, determined before any file is updated.
        obs_types = self.database.obs[key]['TYPE']
        indices = {typ: np.where(obs_types == typ)[0]
                   for typ in ('SCI', 'REF', 'SCI_BG', 'REF_BG')}

        # Loop over science and reference files.
        for typ in types:
            ww, ww_bg = indices[typ], indices[bg_types[typ]]
            if len(ww_bg) == 0:
                raise UserWarning('Could not find any background files.')

            # Gather the background data. Background errors and DQ arrays are
            # not propagated by this method, so only the data are kept.
            bg_data = np.array([ut.read_obs(self.database.obs[key]['FITSFILE'][jj])[0]
                                for jj in ww_bg])
            # If multiple files, take the median. Otherwise, carry on.
            if bg_data.ndim == 4:
                bg_data = np.nanmedian(bg_data, axis=0)

            # Loop over individual files.
            for j in ww:
                fitsfile = self.database.obs[key]['FITSFILE'][j]
                data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
                head, tail = os.path.split(fitsfile)
                log.info(' --> Background subtraction: ' + tail)

                # The background masks only exist for selected MIRI filters.
                # Check once per file rather than once per integration.
                filt = self.database.obs[key]['FILTER'][j]
                if filt not in supported_filters:
                    raise NotImplementedError('Godoy subtraction is only supported for MIRI FQPMs at this time!')
                bgmaskfile = os.path.join(bgmaskbase, 'resources/miri_bg_masks/godoy_mask_{}.fits'.format(filt.lower()))

                # First-order subtraction of the median background.
                data -= bg_data

                # Refine the subtraction integration by integration.
                data_bg_sub = np.empty_like(data)
                for k in range(data.shape[0]):
                    # Median-subtracted background frame and median-subtracted
                    # (already background-subtracted) data frame.
                    bg_submed = bg_data[k] - np.nanmedian(bg_data[k])
                    data_submed = data[k] - np.nanmedian(data[k])

                    # Initial guess for the residual background scale: median
                    # data/background ratio in two hard-coded image sections.
                    sect1 = data_submed[112:118, 4:10] / bg_submed[112:118, 4:10]
                    sect2 = data_submed[95:101, 207:212] / bg_submed[95:101, 207:212]
                    cte = np.nanmedian(np.concatenate((sect1.ravel(), sect2.ravel())))

                    # Fit the residual background scale. The fitted parameter
                    # is 100 times the scale factor (see x0 and the division
                    # below).
                    res = minimize(ut.bg_minimize,
                                   x0=cte*100,
                                   args=(data_submed, bg_submed, bgmaskfile),
                                   method='L-BFGS-B',
                                   tol=1e-7)
                    scale = res.x[0] / 100

                    # Subtract the scaled background, then remove any residual
                    # median offset.
                    data_improved_bgsub = data_submed - bg_submed*scale
                    data_bg_sub[k] = data_improved_bgsub - np.nanmedian(data_improved_bgsub)

                # Write FITS file.
                fitsfile = ut.write_obs(fitsfile, output_dir, data_bg_sub, erro, pxdq, head_pri, head_sci, is2d,
                                        align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                        center_mask=center_mask, maskoffs=maskoffs)

                # Update spaceKLIP database.
                self.database.update_obs(key, j, fitsfile)
[docs]
def subtract_background(self,
                        nints_per_med=None,
                        subdir='bgsub'):
    """
    Median subtract the corresponding background observations from the SCI and REF
    data in the spaceKLIP database.

    SCI files are corrected with the SCI_BG backgrounds and REF files with
    the REF_BG backgrounds of the same concatenation. If one kind of
    background is missing, the other kind is used instead and a warning is
    logged. Background uncertainties are added in quadrature to the data
    uncertainties. Pixels flagged DO_NOT_USE in any contributing background
    frame are flagged DO_NOT_USE in the output.

    Parameters
    ----------
    nints_per_med : int, optional
        Number of integrations per median. For example, if you have a target
        + background dataset with 20 integrations each and nints_per_med is
        set to 5, a median of every 5 background images will be subtracted from
        the corresponding 5 target images. Left-over integrations that do not
        fill a complete chunk are merged into the last chunk. If there are
        several background files, chunk k of every background file enters
        the median for chunk k. The default is None (i.e. a median
        across all images).
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bgsub'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If nints_per_med is not a positive integer, or if a file needs more
        integration chunks than the background provides.
    UserWarning
        If a concatenation has neither SCI_BG nor REF_BG files.
    """
    # Validate input (0 would otherwise raise a ZeroDivisionError).
    if nints_per_med is not None and (nints_per_med < 1 or int(nints_per_med) != nints_per_med):
        raise ValueError('nints_per_med must be a positive integer or None')
    do_not_use = dqflags.pixel['DO_NOT_USE']

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Find science, reference, and background files.
        obs_types = self.database.obs[key]['TYPE']
        ww = np.where((obs_types == 'SCI') | (obs_types == 'REF'))[0]

        # Per-chunk median backgrounds (None if there are no such files).
        sci_bg = self._bgsub_median_chunks(key, np.where(obs_types == 'SCI_BG')[0], nints_per_med)
        ref_bg = self._bgsub_median_chunks(key, np.where(obs_types == 'REF_BG')[0], nints_per_med)
        if sci_bg is None and ref_bg is None:
            raise UserWarning('Could not find any background files')

        # Loop through science and reference files.
        for j in ww:
            # Read FITS file and PSF mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)

            head, tail = os.path.split(fitsfile)
            log.info(' --> Background subtraction: ' + tail)

            # Prefer the matching background kind, fall back to the other one.
            if self.database.obs[key]['TYPE'][j] == 'SCI':
                bg = sci_bg
                if bg is None:
                    log.warning(' --> Could not find science background, attempting to use reference background')
                    bg = ref_bg
            else:
                bg = ref_bg
                if bg is None:
                    log.warning(' --> Could not find reference background, attempting to use science background')
                    bg = sci_bg
            bg_data, bg_erro, bg_bad = bg

            # Split the file into the same integration chunks as the background.
            split_inds = self._bgsub_split_indices(data.shape[0], nints_per_med)
            if len(split_inds) + 1 > len(bg_data):
                raise ValueError('Background provides only {} integration chunk(s), but {} needs {}'.format(
                    len(bg_data), tail, len(split_inds) + 1))
            data_split = np.array_split(data, split_inds, axis=0)
            erro_split = np.array_split(erro, split_inds, axis=0)
            pxdq_split = np.array_split(pxdq, split_inds, axis=0)

            # Subtract the background chunk by chunk. The DQ chunks are views
            # into pxdq, so flagging them updates pxdq as well.
            for k in range(len(data_split)):
                pxmask_split_donotuse = ut.get_dqmask(pxdq_split[k], 'DO_NOT_USE', return_bool=True)
                data_split[k] = data_split[k] - bg_data[k]
                erro_split[k] = np.sqrt(erro_split[k]**2 + bg_erro[k]**2)
                pxdq_split[k][(~pxmask_split_donotuse) & bg_bad[k]] |= do_not_use
            data = np.concatenate(data_split, axis=0)
            erro = np.concatenate(erro_split, axis=0)
            pxdq = np.concatenate(pxdq_split, axis=0)

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile)

@staticmethod
def _bgsub_split_indices(nints, nints_per_med):
    """
    Return the indices at which ``nints`` integrations are split into chunks
    of ``nints_per_med`` integrations for ``subtract_background``.

    The split points are the multiples of ``nints_per_med`` up to
    ``nints - nints_per_med``, so left-over integrations are merged into the
    last chunk. ``None`` yields no split points, i.e. a single chunk.
    """
    if nints_per_med is None:
        return []
    step = int(nints_per_med)
    return list(range(step, nints - step + 1, step))

def _bgsub_median_chunks(self, key, indices, nints_per_med):
    """
    Median-combine background files chunk by chunk for ``subtract_background``.

    Each background file ``indices`` of concatenation ``key`` is split into
    chunks of ``nints_per_med`` integrations (see ``_bgsub_split_indices``).
    Chunk k of the result combines chunk k of every file that has one:
    the data are NaN-ignoring medians; the errors are the square root of the
    NaN-ignoring sum of squared errors divided by the number of non-NaN
    error samples; the bad-pixel map is True where any contributing frame
    is flagged DO_NOT_USE.

    Returns
    -------
    tuple of three lists (data, erro, bad) with one 2D array per chunk, or
    None if ``indices`` is empty.
    """
    if len(indices) == 0:
        return None

    # Split every background file into its own integration chunks.
    per_file = []
    for j in indices:
        data, erro, pxdq = ut.read_obs(self.database.obs[key]['FITSFILE'][j])[:3]
        split_inds = self._bgsub_split_indices(data.shape[0], nints_per_med)
        per_file.append((np.array_split(data, split_inds, axis=0),
                         np.array_split(erro, split_inds, axis=0),
                         np.array_split(pxdq, split_inds, axis=0)))

    bg_data, bg_erro, bg_bad = [], [], []
    nchunks = max(len(chunks[0]) for chunks in per_file)
    for k in range(nchunks):
        # Stack chunk k of every background file that has one.
        contributing = [chunks for chunks in per_file if k < len(chunks[0])]
        data_k = np.concatenate([chunks[0][k] for chunks in contributing], axis=0)
        erro_k = np.concatenate([chunks[1][k] for chunks in contributing], axis=0)
        pxdq_k = np.concatenate([chunks[2][k] for chunks in contributing], axis=0)

        bg_data.append(np.nanmedian(data_k, axis=0))
        nsample = np.sum(~np.isnan(erro_k), axis=0)
        bg_erro.append(np.true_divide(np.sqrt(np.nansum(erro_k**2, axis=0)), nsample))
        donotuse = ut.get_dqmask(pxdq_k, 'DO_NOT_USE', return_bool=True)
        bg_bad.append(np.any(donotuse, axis=0))
    return bg_data, bg_erro, bg_bad
[docs]
def find_bad_pixels(self,
                    method='dqarr',
                    set_dq_zero=True,
                    dqarr_kwargs=None,
                    sigclip_kwargs=None,
                    custom_kwargs=None,
                    timeints_kwargs=None,
                    gradient_kwargs=None,
                    types=('SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'),
                    subdir='bpfound',
                    restrict_to=None,
                    min_nancluster=None):
    """
    Identify bad pixels for cleaning.

    Only files whose type is listed in ``types`` are processed; they are
    written to ``subdir`` together with their PSF mask and NaN mask, and the
    spaceKLIP database is updated to point to the new files.

    Parameters
    ----------
    method : str, optional
        Sequence of bad pixel identification methods to be run on the data.
        Different methods must be joined by a '+' sign without
        whitespace. Available methods are:
        - dqarr: Flags NaN and DO_NOT_USE pixels that are not NON_SCIENCE.
        - timeints: To identify bad pixels in time.
        - sigclip: Uses sigma clipping to identify additional bad pixels.
        - custom: Uses a custom bad pixel map (skipped for TA files).
        - gradient: Runs ``find_bad_pixels_gradient``.
        Unknown methods are skipped. The default is 'dqarr'.
    set_dq_zero : bool, optional
        Toggle to start a new empty DQ array, or build upon the existing
        array. If True, the output DQ array only contains the flags found
        here; otherwise the flags are OR-ed into the original DQ array. The
        NON_SCIENCE and DO_NOT_USE masks used by the methods are derived
        from this starting array, so they are empty if True. The default is
        True.
    dqarr_kwargs : dict, optional
        Keyword arguments for the 'dqarr' identification method. Available keywords are:
        - flag_neighbors : bool, optional
            If True, the 4-connected neighbors (up, down, left, right) of each DO_NOT_USE pixel
            are evaluated and flagged if they are elevated relative to the local background, which
            is estimated from the surrounding diagonal pixels. The default is False.
        - sigma : int, optional
            Flag neighboring pixels with values greater than diag_med + sigma * diag_std,
            where diag_std is the standard deviation of the diagonal pixels.
            The default is 5.
        The default is {}.
    sigclip_kwargs : dict, optional
        Keyword arguments for the 'sigclip' identification methods. Available keywords are:
        - sigma : float, optional
            Sigma clipping threshold. The default is 5.
        - neg_sigma : float, optional
            Sigma clipping threshold for negative outliers. The default is 1.
        - shift_x : list of int, optional
            Pixels in x-direction to which each pixel shall be compared to.
            The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction to which each pixel shall be compared to.
            The default is [-1, 0, 1].
        - diagonal_only : bool, optional
            Only compare to diagonal neighbors? The default is False.
        - threshold_metric : str, optional
            Whether to use standard deviation or MAD.
        - max_cluster_size : int, optional
            Maximum size of bad pixel clusters to be flagged. If None, no limit is applied.
        - cluster_dilate_radius : int, optional
            Radius for dilating bad pixels before checking clustering. The default is 6 pixels.
        - mode : str, optional
            Sigma-clipping strategy used to identify bad pixels. Available options are:
            'local' : Flags pixels that deviate from the median of their neighboring pixels.
                      Large negative outliers are also identified by comparing to background estimate.
            'local_weighted' : Same as 'local', but includes the pixel uncertainty when computing the clipping threshold.
        - mask_psf : bool, optional
            Restrict bad pixel flagging inside the PSF? Always set to False
            for MIR_4QPM and MIR_LYOT exposures.
        - crpix1/crpix2: float/float, optional
            The center of the PSF. Always overwritten per file with the
            database CRPIX1/CRPIX2 values minus 1.
        The default is {}.
    custom_kwargs : dict, optional
        Keyword arguments for the 'custom' method. The dictionary keys must
        match the keys of the observations database and the dictionary
        content must be binary bad pixel maps (1 = bad, 0 = good) with the
        same shape as the corresponding data. The default is {}.
    timeints_kwargs : dict, optional
        Keyword arguments for the 'timeints' method. Available keywords are:
        - sigma : float, optional
            Sigma clipping threshold. The default is 5.
        - mode : str, optional
            Mode for detecting bad pixels. The default is 'per_pixel'.
            'per_pixel' : Computes the variation across integrations independently for each pixel.
            'group_pixels' : Groups pixels by similar flux, computes the variation across integrations for each pixel, and compares each pixel's variation to that of its corresponding flux group.
        - n_groups : int, optional
            The number of groups if mode == 'group_pixels'. The default is 25.
        - diagnostic_plots : bool, optional
            Plot diagnostics?
        The default is {}.
    gradient_kwargs : dict, optional
        Keyword arguments passed to ``find_bad_pixels_gradient`` by the
        'gradient' method. The default is {}.
    types : list of str, optional
        List of data types for which bad pixels shall be identified.
        The default is ('SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG').
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bpfound'.
    restrict_to : str, optional
        If not None, only concatenations whose key contains this string are
        processed. The default is None.
    min_nancluster : int, optional
        Minimum number of pixels required to flag a cluster of NaN pixels.
        If given, a new NaN mask is built from the data (see
        ``nan_clusters_mask``); if None, the NaN mask provided in the
        database is used. The default is None.

    Returns
    -------
    None.
    """
    def nan_clusters_mask(image, min_nancluster):
        """
        Return a 2D map of how consistently pixels belong to large NaN clusters.

        In each 2D slice of the 3D ``image`` (slices are treated
        independently, i.e. no connections along the first axis), clusters
        of NaN pixels found with ``scipy.ndimage.label`` (default 4-connected
        structure) that contain at least ``min_nancluster`` pixels are
        marked. The per-slice masks are median-combined along the first
        axis, so the result is a 2D float array with values between 0 and 1.
        """
        output_mask = np.zeros(image.shape, dtype=bool)
        for z in range(image.shape[0]):
            labeled_array, num_features = scipy.ndimage.label(np.isnan(image[z]))
            # Pixel count of every label; label 0 is the non-NaN background.
            sizes = np.bincount(labeled_array.ravel(), minlength=1)
            keep = sizes >= min_nancluster
            keep[0] = False
            output_mask[z] = keep[labeled_array]
        return np.nanmedian(output_mask.astype(int), axis=0)

    def flag_bright_neighbors(frame, donotuse, nonsci, sigma):
        """
        Flag 4-connected neighbors of bad pixels that are bright compared to
        the bad pixel's diagonal neighbors.

        For every pixel that is True in ``donotuse``, the median and
        standard deviation of its in-bounds, non-NaN diagonal neighbors set
        the threshold ``median + sigma * std``. Up/down/left/right neighbors
        above this threshold that are not NaN, not non-science and not yet
        flagged are set to True in ``donotuse`` (modified in place). Only
        pixels flagged before the call are used as seeds (no cascading).

        Returns the number of newly flagged pixels.
        """
        nrows, ncols = frame.shape
        added = 0
        ys, xs = np.where(donotuse)
        for y, x in zip(ys, xs):
            # Local background from the diagonal pixels.
            diag_vals = [frame[y + dy, x + dx]
                         for dy, dx in ((-1, -1), (-1, 1), (1, -1), (1, 1))
                         if 0 <= y + dy < nrows and 0 <= x + dx < ncols
                         and not np.isnan(frame[y + dy, x + dx])]
            if not diag_vals:
                continue
            thresh = np.nanmedian(diag_vals) + sigma * np.nanstd(diag_vals)
            # Check the 4-connected neighbors.
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                if not (0 <= ny < nrows and 0 <= nx < ncols):
                    continue
                val = frame[ny, nx]
                if np.isnan(val) or nonsci[ny, nx] or donotuse[ny, nx]:
                    continue
                if val > thresh:
                    donotuse[ny, nx] = True
                    added += 1
        return added

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Protect against mutability of default and caller-supplied arguments.
    dqarr_kwargs = {} if dqarr_kwargs is None else dqarr_kwargs.copy()
    sigclip_kwargs = {} if sigclip_kwargs is None else sigclip_kwargs.copy()
    custom_kwargs = {} if custom_kwargs is None else custom_kwargs.copy()
    timeints_kwargs = {} if timeints_kwargs is None else timeints_kwargs.copy()
    gradient_kwargs = {} if gradient_kwargs is None else gradient_kwargs.copy()

    method_split = method.split('+')
    flag_neighbors = dqarr_kwargs.get('flag_neighbors', False)
    neighbor_sigma = dqarr_kwargs.get('sigma', 5)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        # If we limit to only processing some concatenations,
        # check whether this concatenation matches the pattern.
        if (restrict_to is not None) and (restrict_to not in key):
            continue
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] not in types:
                continue

            # Read FITS file, PSF mask, and NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)
            if min_nancluster is not None:
                nanmask = nan_clusters_mask(data, min_nancluster)
            else:
                nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
                nanmask = ut.read_msk(nanmaskfile)
            head, tail = os.path.split(fitsfile)

            # Working DQ array: all zeros (all good pixels) or a copy of the original.
            pxdq_temp = np.zeros_like(pxdq) if set_dq_zero else pxdq.copy()
            # NON_SCIENCE and DO_NOT_USE pixels of the working DQ array
            # (both empty when set_dq_zero is True).
            pxmask_nonsci = ut.get_dqmask(pxdq_temp, 'NON_SCIENCE', return_bool=True)
            pxmask_donotuse = ut.get_dqmask(pxdq_temp, 'DO_NOT_USE', return_bool=True)
            if method_split[0] != 'dqarr' and not set_dq_zero:
                # If the first method is not dqarr and you are not using a boolean mask for pxdq_temp,
                # convert pxdq_temp to a boolean mask or some of the next steps won't work.
                # This is just a place holder. We need to think about how this mask will look like.
                pxdq_temp = (pxdq_temp > 0)

            # Call bad pixel identification routines.
            for k, step in enumerate(method_split):
                if step == 'dqarr':
                    log.info(' --> Method ' + step + ': ' + tail)
                    total_added = 0
                    for iint in range(data.shape[0]):
                        nonsci_i = pxmask_nonsci[iint]
                        donotuse_i = np.isnan(data[iint]) | pxmask_donotuse[iint]
                        if flag_neighbors:
                            total_added += flag_bright_neighbors(data[iint], donotuse_i, nonsci_i, neighbor_sigma)
                        flags = donotuse_i & (~nonsci_i)
                        # Write into the working DQ array (not the original),
                        # so that the flags end up in the output DQ array.
                        # When dqarr is not the first method, keep the pixels
                        # flagged by the previous methods.
                        if k == 0:
                            pxdq_temp[iint] = flags
                        else:
                            pxdq_temp[iint] = flags | (pxdq_temp[iint] != 0)
                    if flag_neighbors:
                        log.info(f" --> Total neighbor pixels flagged: {total_added}")
                elif step == 'sigclip':
                    log.info(' --> Method ' + step + ': ' + tail)
                    # Per-file copy so that file-specific settings (PSF
                    # center, MIRI mask_psf override) do not leak into later files.
                    file_sigclip_kwargs = dict(sigclip_kwargs)
                    file_sigclip_kwargs['crpix1'] = self.database.obs[key]['CRPIX1'][j] - 1
                    file_sigclip_kwargs['crpix2'] = self.database.obs[key]['CRPIX2'][j] - 1
                    if self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']:
                        file_sigclip_kwargs['mask_psf'] = False
                    self.find_bad_pixels_sigclip(data, erro, pxdq_temp, pxmask_nonsci, file_sigclip_kwargs)
                elif step == 'custom':
                    log.info(' --> Method ' + step + ': ' + tail)
                    if self.database.obs[key]['TYPE'][j] not in ['SCI_TA', 'REF_TA']:
                        self.find_bad_pixels_custom(data, erro, pxdq_temp, key, custom_kwargs)
                    else:
                        log.info(' --> Method ' + step + ': skipped because TA file')
                elif step == 'timeints':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.find_bad_pixels_timeints(data, erro, pxdq_temp, key, timeints_kwargs)
                elif step == 'gradient':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.find_bad_pixels_gradient(data, erro, pxdq_temp, key, gradient_kwargs)
                else:
                    log.info(' --> Unknown method ' + step + ': skipped')

            if set_dq_zero:
                # The new DQ is just the working array.
                new_dq = pxdq_temp.astype(np.uint32)
            else:
                # The new DQ is the original DQ with the newly flagged pixels added.
                new_dq = np.bitwise_or(pxdq.copy(), pxdq_temp).astype(np.uint32)

            # Write FITS file, PSF mask, and NaN mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(fitsfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nanmaskfile=nanmaskfile)
[docs]
def fix_bad_pixels(self,
                   method='timemed+localmed+medfilt',
                   sigclip_kwargs=None,
                   custom_kwargs=None,
                   timemed_kwargs=None,
                   localmed_kwargs=None,
                   medfilt_kwargs=None,
                   types=('SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'),
                   subdir='bpcleaned',
                   restrict_to=None):
    """
    TO BE DEPRECATED BY FIND_BAD_PIXELS() AND CLEAN_BAD_PIXELS()

    Identify and fix bad pixels.

    The working bad pixel mask starts as all NaN or DO_NOT_USE pixels that
    are not NON_SCIENCE. It is then updated by the requested methods. In the
    output DQ array, the DO_NOT_USE bit is replaced by this mask and all
    other bits are retained. Only files whose type is listed in ``types``
    are processed and written to ``subdir``.

    Parameters
    ----------
    method : str, optional
        Sequence of bad pixel identification and cleaning methods to be run
        on the data. Different methods must be joined by a '+' sign without
        whitespace. Available methods are:
        - sigclip: use sigma clipping to identify additional bad pixels.
        - custom: use a custom bad pixel map (skipped for TA files).
        - timemed: replace pixels which are only bad in some frames with
                   their median value from the good frames.
        - localmed: replace bad pixels with the median value of their
                    surrounding good pixels.
        - medfilt: replace bad pixels with an image plane median filter.
        Unknown methods are skipped. The default is 'timemed+localmed+medfilt'.
    sigclip_kwargs : dict, optional
        Keyword arguments passed to ``find_bad_pixels_sigclip`` by the
        'sigclip' method. Available keywords are:
        - sigclip : float, optional
            Sigma clipping threshold. The default is 5.
        - shift_x : list of int, optional
            Pixels in x-direction to which each pixel shall be compared to.
            The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction to which each pixel shall be compared to.
            The default is [-1, 0, 1].
        NOTE(review): find_bad_pixels() documents the threshold keyword of
        the same routine as 'sigma'; confirm which key is read.
        The default is {}.
    custom_kwargs : dict, optional
        Keyword arguments for the 'custom' method. The dictionary keys must
        match the keys of the observations database and the dictionary
        content must be binary bad pixel maps (1 = bad, 0 = good) with the
        same shape as the corresponding data. The default is {}.
    timemed_kwargs : dict, optional
        Keyword arguments for the 'timemed' method. Available keywords are:
        - n/a
        The default is {}.
    localmed_kwargs : dict, optional
        Keyword arguments for the 'localmed' method. Available keywords are:
        - shift_x : list of int, optional
            Pixels in x-direction from which the median shall be computed.
            The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction from which the median shall be computed.
            The default is [-1, 0, 1].
        The default is {}.
    medfilt_kwargs : dict, optional
        Keyword arguments for the 'medfilt' method. Available keywords are:
        - size : int, optional
            Kernel size of the median filter to be used. The default is 4.
        The default is {}.
    types : list of str, optional
        List of data types for which bad pixels shall be identified and
        fixed. The default is ('SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA',
        'REF_BG').
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bpcleaned'.
    restrict_to : str, optional
        If not None, only concatenations whose key contains this string are
        processed. The default is None.

    Returns
    -------
    None.
    """
    # Copy keyword dictionaries so that neither a shared default nor the
    # caller's dictionaries can be modified by the helper routines.
    sigclip_kwargs = {} if sigclip_kwargs is None else dict(sigclip_kwargs)
    custom_kwargs = {} if custom_kwargs is None else dict(custom_kwargs)
    timemed_kwargs = {} if timemed_kwargs is None else dict(timemed_kwargs)
    localmed_kwargs = {} if localmed_kwargs is None else dict(localmed_kwargs)
    medfilt_kwargs = {} if medfilt_kwargs is None else dict(medfilt_kwargs)

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    method_split = method.split('+')
    do_not_use = dqflags.pixel['DO_NOT_USE']

    # Loop through concatenations.
    for key in self.database.obs.keys():
        # If we limit to only processing some concatenations, check whether this concatenation matches the pattern.
        if (restrict_to is not None) and (restrict_to not in key):
            continue
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] not in types:
                continue

            # Read FITS file and PSF mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)
            head, tail = os.path.split(fitsfile)

            # Working boolean bad pixel mask: NaN or DO_NOT_USE pixels that
            # are not NON_SCIENCE.
            pxmask_nonsci = ut.get_dqmask(pxdq, 'NON_SCIENCE', return_bool=True)
            pxmask_donotuse = ut.get_dqmask(pxdq, 'DO_NOT_USE', return_bool=True)
            pxdq_temp = (np.isnan(data) | pxmask_donotuse) & (~pxmask_nonsci)

            # Call bad pixel identification and cleaning routines.
            for step in method_split:
                if step == 'sigclip':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.find_bad_pixels_sigclip(data, erro, pxdq_temp, pxmask_nonsci, sigclip_kwargs)
                elif step == 'custom':
                    log.info(' --> Method ' + step + ': ' + tail)
                    if self.database.obs[key]['TYPE'][j] not in ['SCI_TA', 'REF_TA']:
                        self.find_bad_pixels_custom(data, erro, pxdq_temp, key, custom_kwargs)
                    else:
                        log.info(' --> Method ' + step + ': skipped because TA file')
                elif step == 'timemed':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.fix_bad_pixels_timemed(data, erro, pxdq_temp, timemed_kwargs)
                elif step == 'localmed':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.fix_bad_pixels_localmed(data, erro, pxdq_temp, localmed_kwargs)
                elif step == 'medfilt':
                    log.info(' --> Method ' + step + ': ' + tail)
                    self.fix_bad_pixels_medfilt(data, erro, pxdq_temp, medfilt_kwargs)
                else:
                    log.info(' --> Unknown method ' + step + ': skipped')

            # Update the DQ bit flags for the output file: retain every bit
            # except DO_NOT_USE, then add the DO_NOT_USE bit from the
            # cleaned working mask.
            new_dq = np.bitwise_and(pxdq, np.invert(do_not_use))
            new_dq = np.bitwise_or(new_dq, pxdq_temp)
            # The bitwise steps return np.int64, which isn't FITS compatible.
            new_dq = new_dq.astype(np.uint32)

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile)
[docs]
def clean_bad_pixels(self,
                     method='timemed+localmed+medfilt',
                     timemed_kwargs=None,
                     localmed_kwargs=None,
                     medfilt_kwargs=None,
                     interp2d_kwargs=None,
                     astrofix_kwargs=None,
                     types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'],
                     subdir='bpcleaned',
                     restrict_to=None,
                     plot=False):
    """
    Clean bad pixels.

    Pixels that are NaN or carry the DO_NOT_USE DQ flag are cleaned, except
    for NON_SCIENCE pixels and (if a NaN mask is available) pixels where the
    NaN mask equals 1. Files whose TYPE is not in ``types`` are not cleaned
    but are still written to ``subdir`` with an updated DQ array.

    Parameters
    ----------
    method : str, optional
        Sequence of bad pixel cleaning methods to be run on the data.
        Different methods must be joined by a '+' sign without
        whitespace. Available methods are:

        - timemed: replace pixels which are only bad in some frames with
          their median value from the good frames.
        - localmed: replace bad pixels with the median value of their
          surrounding good pixels.
        - medfilt: replace bad pixels with an image plane median filter.
        - interp2d: replace bad pixels with an interpolation of
          neighbouring pixels.
        - astrofix: an astronomical image correction algorithm based on
          Gaussian Process Regression.

        The default is 'timemed+localmed+medfilt'.
    timemed_kwargs : dict, optional
        Keyword arguments for the 'timemed' method. Available keywords are:

        - n/a

        The default is {}.
    localmed_kwargs : dict, optional
        Keyword arguments for the 'localmed' method. Available keywords are:

        - shift_x : list of int, optional
            Pixels in x-direction from which the median shall be computed.
            The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction from which the median shall be computed.
            The default is [-1, 0, 1].

        The default is {}.
    medfilt_kwargs : dict, optional
        Keyword arguments for the 'medfilt' method. Available keywords are:

        - size : int, optional
            Kernel size of the median filter to be used. The default is 4.

        The default is {}.
    interp2d_kwargs : dict, optional
        Keyword arguments for the 'interp2d' method, passed unchanged to
        ``fix_bad_pixels_interp2d``. The default is {}.
    astrofix_kwargs : dict, optional
        Keyword arguments for the 'astrofix' method, passed unchanged to
        ``fix_bad_pixels_astrofix``. The default is {}.
    types : list of str, optional
        List of data types for which bad pixels shall be identified and
        fixed. The default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA',
        'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bpcleaned'.
    restrict_to : str, optional
        Only process concatenations whose database key contains this
        substring. The default is None, i.e., process all concatenations.
    plot : bool, optional
        Save a histogram of the pixel values before and after cleaning for
        every file in the output directory (input file name with '.fits'
        replaced by '_hist.png'), and pass ``plot=True`` to the 'astrofix'
        method. The default is False.

    Returns
    -------
    None.
    """
    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    os.makedirs(output_dir, exist_ok=True)

    # Protect against mutability of default arguments.
    timemed_kwargs = {} if timemed_kwargs is None else timemed_kwargs.copy()
    localmed_kwargs = {} if localmed_kwargs is None else localmed_kwargs.copy()
    medfilt_kwargs = {} if medfilt_kwargs is None else medfilt_kwargs.copy()
    interp2d_kwargs = {} if interp2d_kwargs is None else interp2d_kwargs.copy()
    astrofix_kwargs = {} if astrofix_kwargs is None else astrofix_kwargs.copy()

    # The method sequence is the same for every file, so parse it once.
    method_split = method.split('+')
    spatial = ['localmed', 'medfilt', 'interp2d', 'astrofix']
    multiple_spatial = len(set(method_split) & set(spatial)) > 1

    # Loop through concatenations.
    for key in self.database.obs.keys():
        # If we limit to only processing some concatenations, check whether
        # this concatenation matches the pattern.
        if (restrict_to is not None) and (restrict_to not in key):
            continue
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            # Read FITS file and PSF mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Input file name for log messages and the histogram file name.
            # It must be defined for every file, including file types that
            # are not cleaned, because the histogram below is always saved.
            tail = os.path.basename(fitsfile)

            if plot:
                fig = plt.figure()
                ax = plt.gca()
                ax.hist(data.flatten(),
                        bins=int(np.sqrt(data.size)),
                        histtype='step',
                        label='Pre Cleaning')

            # Boolean map of the pixels to clean: NaN or DO_NOT_USE pixels,
            # but never non-science pixels (nor pixels where the NaN mask
            # equals 1).
            pxdq_temp = pxdq.copy()
            temp_donotuse = ut.get_dqmask(pxdq_temp, 'DO_NOT_USE', return_bool=True)
            temp_nonsci = ut.get_dqmask(pxdq_temp, 'NON_SCIENCE', return_bool=True)
            pxdq_temp = (np.isnan(data) | temp_donotuse) & (~temp_nonsci)
            if nanmask is not None:
                pxdq_temp = pxdq_temp & (nanmask != 1)

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:
                # More than one spatial cleaning routine is allowed, but the
                # user is warned about the redundancy.
                if multiple_spatial:
                    log.info(' --> WARNING: Multiple spatial cleaning routines detected!')
                    log.info(' --> The localmed/medfilt/interp2d methods clean data in a similar manner!')
                    log.info(' --> medfilt and interp2d are redundant')
                    log.info(' --> only the first method listed will affect the data')
                    log.info(' --> localmed is partially redundant with other methods')
                    log.info(' --> if run first, large clusters of bad pixels may not be fully cleaned.')

                # Loop over methods; they update data, erro and pxdq_temp in
                # place.
                for meth in method_split:
                    log.info(' --> Method ' + meth + ': ' + tail)
                    if meth == 'timemed':
                        self.fix_bad_pixels_timemed(data, erro, pxdq_temp, timemed_kwargs)
                    elif meth == 'localmed':
                        self.fix_bad_pixels_localmed(data, erro, pxdq_temp, localmed_kwargs)
                    elif meth == 'medfilt':
                        self.fix_bad_pixels_medfilt(data, erro, pxdq_temp, medfilt_kwargs)
                    elif meth == 'interp2d':
                        self.fix_bad_pixels_interp2d(data, erro, pxdq_temp, interp2d_kwargs)
                    elif meth == 'astrofix':
                        self.fix_bad_pixels_astrofix(data, erro, pxdq_temp, astrofix_kwargs, plot=plot)
                    else:
                        log.info(' --> Unknown method ' + meth + ': skipped')

            # Update the pixel DQ bit flags for the output files. pxdq_temp
            # only carries the (cleaned) DO_NOT_USE state, so retain all
            # other bits of the original DQ array and replace just the
            # DO_NOT_USE bit (bit value 1, so OR-ing the boolean map sets
            # exactly that bit).
            # NOTE(review): pixels excluded from cleaning (NON_SCIENCE, or
            # NaN mask == 1) lose their DO_NOT_USE bit here -- confirm that
            # this is intended.
            do_not_use = dqflags.pixel['DO_NOT_USE']
            new_dq = np.bitwise_and(pxdq.copy(), np.invert(do_not_use))
            new_dq = np.bitwise_or(new_dq, pxdq_temp)
            # The bitwise steps otherwise return np.int64, which isn't FITS
            # compatible; ensure the correct output type for saving.
            new_dq = new_dq.astype(np.uint32)

            # Finish figure for this file.
            if plot:
                ax.hist(data.flatten(),
                        bins=int(np.sqrt(data.size)),
                        histtype='step',
                        label='Post Cleaning')
                ax.legend()
                ax.set_yscale('log')
                ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
                ax.set_xlabel("Pixel Value", fontsize=14)
                ax.set_ylabel("Frequency", fontsize=12)
                ax.set_title(f"{os.path.basename(fitsfile)} \n Original vs. Cleaned Data", fontsize=16)
                output_file = os.path.join(output_dir, tail.replace('.fits', '_hist.png'))
                plt.savefig(output_file)
                plt.close(fig)

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(fitsfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nanmaskfile=nanmaskfile)
[docs]
def find_bad_pixels_sigclip(self,
                            data,
                            erro,
                            pxdq,
                            NON_SCIENCE,
                            sigclip_kwargs=None):
    """
    Use an iterative sigma clipping algorithm to identify additional bad
    pixels in the data.

    Parameters
    ----------
    data : 3D-array
        Input images.
    erro : 3D-array
        Input image uncertainties (used by the 'local_weighted' mode).
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to include the newly identified bad pixels.
    NON_SCIENCE : 3D-array
        Input binary non-science pixel maps (1 = non-science, 0 = science).
        These pixels are never newly flagged. Will not be modified by the
        routine.
    sigclip_kwargs : dict, optional
        Keyword arguments for the 'sigclip' method. Available keywords are:

        - sigma : float, optional
            Sigma clipping threshold. The default is 5.
        - neg_sigma : float, optional
            Sigma clipping threshold for negative outliers. The default is 1.
        - shift_x : list of int, optional
            Pixels in x-direction to which each pixel shall be compared to.
            A zero shift is always added. The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction to which each pixel shall be compared to.
            A zero shift is always added. The default is [-1, 0, 1].
        - diagonal_only : bool, optional
            Only use diagonal neighbors for the median estimate (all
            neighbors are still used for the spread). The default is False.
        - threshold_metric : str, optional
            Per-pixel spread of the neighbors used for the threshold, 'std'
            (standard deviation) or 'mad' (median absolute deviation). The
            default is 'std'.
        - max_cluster_size : int, optional
            Maximum size of bad pixel clusters to be flagged. If None, no
            limit is applied. The default is None.
        - cluster_dilate_radius : int, optional
            Radius for dilating bad pixels before checking clustering. The
            default is 6 pixels.
        - mode : str, optional
            Sigma-clipping strategy used to identify bad pixels. Available
            options are:

            'local' : Flags pixels that deviate from the median of their
            neighboring pixels. Large negative outliers are also identified
            by comparing to a background estimate.
            'local_weighted' : Same as 'local', but includes the pixel
            uncertainty when computing the clipping threshold.

            Any other value flags no new pixels. The default is 'local'.
        - mask_psf : bool, optional
            Never flag pixels within ``psf_radius`` of (crpix1, crpix2)?
            Only active if crpix1 and crpix2 are given. The default is False.
        - psf_radius : float, optional
            Radius in pixels of the protected PSF region. The default is 50.
        - crpix1/crpix2 : float/float, optional
            The x/y position of the center of the PSF. The default is None.

        The default is {}.

    Raises
    ------
    ValueError
        If ``threshold_metric`` is neither 'std' nor 'mad'.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    sigclip_kwargs = {} if sigclip_kwargs is None else sigclip_kwargs.copy()

    # Fill in defaults for any keyword that was not provided.
    defaults = {'sigma': 5.,
                'neg_sigma': 1.,
                'shift_x': [-1, 0, 1],
                'shift_y': [-1, 0, 1],
                'mode': 'local',
                'diagonal_only': False,
                'threshold_metric': 'std',
                'max_cluster_size': None,
                'cluster_dilate_radius': 6,
                'mask_psf': False,
                'psf_radius': 50,
                'crpix1': None,
                'crpix2': None}
    for name, value in defaults.items():
        sigclip_kwargs.setdefault(name, value)

    # Work on copies of the shift lists: appending the zero shift in place
    # would mutate a list owned by the caller (and fail for tuples).
    shift_x = list(sigclip_kwargs['shift_x'])
    shift_y = list(sigclip_kwargs['shift_y'])
    if 0 not in shift_x:
        shift_x.append(0)
    if 0 not in shift_y:
        shift_y.append(0)

    if sigclip_kwargs['threshold_metric'] not in ('std', 'mad'):
        raise ValueError("Unknown threshold_metric %r, must be 'std' or 'mad'"
                         % (sigclip_kwargs['threshold_metric'],))

    mode = sigclip_kwargs['mode']
    sigma = sigclip_kwargs['sigma']
    neg_sigma = sigclip_kwargs['neg_sigma']
    nonsci = np.asarray(NON_SCIENCE, dtype=bool)

    # Build optional PSF mask.
    psf_mask = None
    crpix1 = sigclip_kwargs['crpix1']
    crpix2 = sigclip_kwargs['crpix2']
    if sigclip_kwargs['mask_psf'] and crpix1 is not None and crpix2 is not None:
        ny, nx = data.shape[1:]
        Y, X = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
        psf_mask = ((X - crpix1)**2 + (Y - crpix2)**2) < sigclip_kwargs['psf_radius']**2

    # Pad data. The zero shift guarantees min <= 0 <= max.
    pad_left = abs(min(shift_x))
    pad_right = abs(max(shift_x))
    right = None if pad_right == 0 else -pad_right
    pad_bottom = abs(min(shift_y))
    pad_top = abs(max(shift_y))
    top = None if pad_top == 0 else -pad_top
    pad_vals = ((pad_bottom, pad_top), (pad_left, pad_right))

    # Optional cluster-size filter: pixels are grouped into clusters on a
    # dilated copy of the mask, and only clusters containing at most
    # max_cluster_size originally flagged pixels are kept.
    max_cluster_size = sigclip_kwargs['max_cluster_size']
    dilate_radius = sigclip_kwargs['cluster_dilate_radius']
    structure = None
    if max_cluster_size is not None and dilate_radius is not None:
        structure = np.ones((2 * dilate_radius + 1, 2 * dilate_radius + 1), dtype=bool)

    def filter_clusters(mask):
        # Dilate mask to merge nearby features, then label the features.
        labeled, num_features = label(binary_dilation(mask, structure=structure))
        # Count only original masked pixels per feature. The all-ones
        # structure keeps every masked pixel inside the dilated mask, so
        # masked pixels always have a label >= 1.
        sizes = np.bincount(labeled[mask], minlength=num_features + 1)
        keep = sizes <= max_cluster_size
        keep[0] = False
        return mask & keep[labeled]

    # Find bad pixels using median of neighbors.
    pxdq_orig = pxdq.copy()
    ww = (pxdq != 0)
    data_temp = data.copy()
    data_temp[ww] = np.nan
    for i in range(ww.shape[0]):
        if mode == 'local':
            # Large negative outliers are identified once per frame against
            # a robust background estimate (bright PSFs clipped first).
            bg_med = np.nanmedian(data_temp[i])
            bg_std = robust.medabsdev(data_temp[i])
            bg_ind = data[i] < (bg_med + 10. * bg_std)
            bg_med = np.nanmedian(data_temp[i][bg_ind])
            bg_std = robust.medabsdev(data_temp[i][bg_ind])
            threshold_neg = bg_med - neg_sigma * bg_std
            mask_neg = data[i] < threshold_neg
        ww[i][nonsci[i]] = 0
        # Restrict flagging in the PSF center.
        if psf_mask is not None:
            ww[i][psf_mask] = 0

        # Loop through max 10 iterations.
        for it in range(10):
            data_temp[i][ww[i]] = np.nan

            # Shift data to collect the neighbours of every pixel.
            pad_data = np.pad(data_temp[i], pad_vals, mode='edge')
            data_arr_med, data_arr_std = [], []
            for ix in shift_x:
                for iy in shift_y:
                    if ix == 0 and iy == 0:
                        # Don't want pixel itself.
                        continue
                    shifted = np.roll(pad_data, (iy, ix), axis=(0, 1))
                    data_arr_std.append(shifted)
                    if sigclip_kwargs['diagonal_only'] and abs(ix) != abs(iy):
                        # Only diagonal neighbors enter the median estimate.
                        continue
                    data_arr_med.append(shifted)
            data_arr_med_trim = np.array(data_arr_med)[:, pad_bottom:top, pad_left:right]
            data_med = np.nanmedian(data_arr_med_trim, axis=0)
            diff = data[i] - data_med
            data_arr_std_trim = np.array(data_arr_std)[:, pad_bottom:top, pad_left:right]
            if sigclip_kwargs['threshold_metric'] == 'std':
                data_std = np.nanstd(data_arr_std_trim, axis=0)
            else:
                # Per-pixel MAD over the neighbours (axis 0), consistent
                # with the per-pixel standard deviation above.
                data_std = robust.medabsdev(data_arr_std_trim, axis=0)

            if mode == 'local':
                # Find values N standard deviations above the neighbours.
                mask_pos = diff > sigma * data_std
            elif mode == 'local_weighted':
                data_std_weighted = np.sqrt(data_std**2 + erro[i]**2)
                mask_pos = diff > sigma * data_std_weighted
                mask_neg = diff < -neg_sigma * data_std_weighted
            else:
                mask_neg = np.zeros_like(ww[i])  # Default empty mask
                mask_pos = np.zeros_like(ww[i])  # Default empty mask

            # Restrict to specific cluster size if specified.
            if structure is not None:
                mask_pos = filter_clusters(mask_pos)
                mask_neg = filter_clusters(mask_neg)

            # Apply PSF mask.
            if psf_mask is not None:
                mask_pos[psf_mask] = False
                mask_neg[psf_mask] = False

            # Calculate how many NEW bad pixels were found. Non-science
            # pixels are never flagged; excluding them here also keeps them
            # from being re-counted as new in every iteration.
            mask_new = (mask_pos | mask_neg) & ~nonsci[i]
            nmask_new = np.sum(mask_new & ~ww[i])
            sys.stdout.write('\rFrame %.0f/%.0f, iteration %.0f' % (i + 1, ww.shape[0], it + 1))
            sys.stdout.flush()
            if it > 0 and nmask_new == 0:
                break
            ww[i] = ww[i] | mask_new
            ww[i][nonsci[i]] = 0
        pxdq[i][ww[i]] = 1
    print('')
    log.info(' --> Method sigclip: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape)))
[docs]
def find_bad_pixels_timeints(self,
                             data,
                             erro,
                             pxdq,
                             NON_SCIENCE,
                             timeints_kwargs=None):
    """
    Identify bad pixels from temporal variations across integrations.

    Parameters
    ----------
    data : 3D-array
        Input images.
    erro : 3D-array
        Input image uncertainties. Not used by this method.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to include the newly identified bad pixels.
    NON_SCIENCE : 3D-array
        Input binary non-science pixel maps (1 = non-science, 0 = science).
        These pixels are never newly flagged; pass None to disable this
        exclusion. Will not be modified by the routine.
    timeints_kwargs : dict, optional
        Keyword arguments for the 'timeints' method. Available keywords are:

        - sigma : float, optional
            Sigma clipping threshold. The default is 10.
        - mode : str, optional
            Mode for detecting bad pixels. The default is 'per_pixel'.

            'per_pixel' : Computes the variation across integrations
            independently for each pixel.
            'group_pixels' : Groups pixels by similar flux, computes the
            variation across integrations for each pixel, and compares each
            pixel's variation to that of its corresponding flux group.
            Pixels without a positive, finite median are compared to
            themselves over time. Falls back to 'per_pixel' if there are
            fewer positive pixels than groups.
        - n_groups : int, optional
            The number of groups if mode == 'group_pixels'. The default is
            25.
        - plot_diagnostics : bool, optional
            Show diagnostic plots? The key 'diagnostic_plots' is accepted
            as an alias. The default is False.

        The default is {}.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'per_pixel' nor 'group_pixels'.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    timeints_kwargs = {} if timeints_kwargs is None else timeints_kwargs.copy()

    # Check inputs.
    timeints_kwargs.setdefault('sigma', 10.)
    timeints_kwargs.setdefault('mode', "per_pixel")
    if timeints_kwargs['mode'] not in ("group_pixels", "per_pixel"):
        raise ValueError(f"Unknown timeints method: {timeints_kwargs['mode']!r}")
    timeints_kwargs.setdefault('n_groups', 25)  # Number of pixel groups.
    if 'plot_diagnostics' not in timeints_kwargs:
        timeints_kwargs['plot_diagnostics'] = timeints_kwargs.get('diagnostic_plots', False)
    sigma = timeints_kwargs['sigma']
    mode = timeints_kwargs['mode']
    n_groups = timeints_kwargs['n_groups']
    # Pixels that cannot be assigned to a flux group get this ID. It must
    # not depend on the data: taking the maximum ID present would pick the
    # brightest real group whenever every pixel could be grouped.
    ungrouped_id = n_groups + 1

    # Mask existing bad pixels.
    pxdq_orig = pxdq.copy()
    ww = pxdq != 0
    data_temp = data.copy()
    data_temp[ww] = np.nan

    # Compute per-pixel statistics across integrations.
    med_ints = np.nanmedian(data_temp, axis=0)
    mad_ints = robust.medabsdev(data_temp, axis=0)
    absdiff = np.abs(data_temp - med_ints)
    if mode == "group_pixels":
        # Low flux / negative pixels can be overflagged, so floor the
        # per-pixel spread with the spread of the fainter half of the data.
        bg_data = data_temp[data_temp < np.nanpercentile(data_temp, 50)]
        bg_mad = robust.medabsdev(bg_data)
        dev_pixel = absdiff / np.maximum(mad_ints, bg_mad)
    else:
        dev_pixel = absdiff / mad_ints

    # Compute group statistics.
    groupID_map = None
    group_bounds = None
    group_stats = []
    if mode == "group_pixels":
        # Positive, finite median values for clustering.
        pos_mask = np.isfinite(med_ints) & (med_ints > 0)
        med_ints_flat = med_ints[pos_mask]
        if med_ints_flat.size < n_groups:
            print(f"[timeints] Only {med_ints_flat.size} positive pixels; "
                  f"cannot form {n_groups} groups. "
                  "Falling back to per-pixel temporal outlier detection.")
            mode = "per_pixel"
            timeints_kwargs['mode'] = mode
        else:
            # Define flux groups in log-space.
            X = np.log10(med_ints_flat).reshape(-1, 1)
            km = KMeans(n_clusters=n_groups, n_init="auto", random_state=0).fit(X)
            # Group boundaries are midpoints between sorted cluster centers.
            centers = np.sort(10.0 ** km.cluster_centers_.ravel())  # log to linear.
            edges = np.concatenate((
                [med_ints_flat.min()],
                0.5 * (centers[:-1] + centers[1:]),
                [med_ints_flat.max()],
            ))
            group_bounds = list(zip(edges[:-1], edges[1:]))
            # Assign each pixel to a flux group.
            groupID_map = np.zeros_like(med_ints, dtype=int)
            for gid, (lo, hi) in enumerate(group_bounds, start=1):
                groupID_map[(med_ints >= lo) & (med_ints <= hi)] = gid
            # Any remaining pixels (NaN or non-positive medians) are
            # ungrouped.
            groupID_map[groupID_map == 0] = ungrouped_id
            # Calculate group statistics.
            for gid in range(1, n_groups + 1):
                mask = (groupID_map == gid)
                if not np.any(mask):
                    continue
                group_vals = data_temp[:, mask].flatten()
                group_stats.append({
                    "group_id": gid,
                    "range": group_bounds[gid - 1],
                    "median": np.nanmedian(group_vals),
                    "mad": robust.medabsdev(group_vals),
                })

    # Find bad pixels.
    bad_pixels = np.zeros_like(data_temp, dtype=bool)
    if mode == "group_pixels":
        # Compare each pixel to its associated group statistics.
        for g in group_stats:
            mask = (groupID_map == g["group_id"])
            dev_group = np.abs(data_temp[:, mask] - med_ints[mask]) / g["mad"]
            bad_pixels[:, mask] = dev_group > sigma
        # Ungrouped pixels are compared to themselves over time.
        ungrouped_mask = (groupID_map == ungrouped_id)
        if np.any(ungrouped_mask):
            bad_pixels[:, ungrouped_mask] = dev_pixel[:, ungrouped_mask] > sigma
    else:
        # Compare each pixel to itself over time.
        bad_pixels = dev_pixel > sigma

    # Never flag non-science pixels.
    if NON_SCIENCE is not None:
        bad_pixels = bad_pixels & ~np.asarray(NON_SCIENCE, dtype=bool)

    # Print/save results.
    ww = ww | bad_pixels
    pxdq[ww] = 1
    log.info(' --> Method timeints: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape)))
    print('')

    if timeints_kwargs['plot_diagnostics']:
        ############################################
        # DIAGNOSTIC PLOTS
        ############################################
        n_int = data_temp.shape[0]

        # Plot grouping diagnostics.
        if mode == "group_pixels":
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(17, 6))
            # GroupID map.
            max_gid = int(np.nanmax(groupID_map))
            cmap = plt.get_cmap("tab20", max_gid)
            im = ax1.imshow(groupID_map, origin="lower", cmap=cmap, vmin=1, vmax=max_gid)
            ax1.set(title="Group IDs", xlabel="X", ylabel="Y")
            cbar = fig.colorbar(im, ax=ax1)
            labels = [f"G{i}" for i in range(1, max_gid + 1)]
            if max_gid == ungrouped_id:
                labels[-1] = "Ungrouped"
            cbar.set_ticks(np.arange(1, max_gid + 1))
            cbar.set_ticklabels(labels)
            cbar.set_label("Group ID")
            # Histogram of positive median values + groups.
            bins = np.logspace(np.log10(med_ints_flat.min()), np.log10(med_ints_flat.max()), 200)
            ax2.hist(med_ints_flat, bins=bins, color="gray", alpha=0.6)
            ymin, ymax = ax2.get_ylim()
            for i, (lo, hi) in enumerate(group_bounds, start=1):
                color = cmap((i - 1) / max_gid)
                ax2.axvline(lo, linestyle="--", linewidth=0.8, color="black")
                ax2.axvline(hi, linestyle="--", linewidth=0.8, color="black")
                ax2.text(np.sqrt(lo * hi), ymax * 0.07, f"G{i}", color=color,
                         ha="center", va="bottom", rotation=90, fontsize=9)
            ax2.set(xscale="log", yscale="log", xlabel="DN/s [log scale]", ylabel="Counts [log scale]")
            fig.tight_layout()
            plt.show(block=False)
            plt.pause(0.001)

        # Plot bad pixels found.
        ncols = min(5, n_int)
        nrows = -(-n_int // ncols)
        fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), sharey=True)
        axes = np.asarray(axes).ravel()
        # Common display range for all integrations.
        vmin, vmax = np.nanpercentile(data_temp, (1, 99))
        for i, ax in enumerate(axes):
            if i >= n_int:
                ax.axis("off")
                continue
            ax.imshow(data_temp[i], origin="lower", vmin=vmin, vmax=vmax, cmap="gray")
            yb, xb = np.where(bad_pixels[i])
            if xb.size:
                ax.plot(xb, yb, "rx", ms=3)
            ax.set_title(f"Integration {i+1}")
            ax.set_xlabel("X")
            if i == 0:
                ax.set_ylabel("Y")
        fig.tight_layout()
        plt.show(block=False)
        plt.pause(0.001)
[docs]
def find_bad_pixels_gradient(self,
                             data,
                             erro,
                             pxdq,
                             key,
                             gradient_kwargs=None):
    """
    Identify isolated bad pixels from the gradient pattern around them.

    For each frame, invalid (non-finite or already flagged) pixels are
    filled by linear interpolation, a Gaussian-smoothed copy is
    subtracted, and the result is divided by that smoothed copy. A pixel
    is flagged as a positive outlier when the x-gradient of this
    normalized image is < -threshold at x+1 and > threshold at x-1, and
    the y-gradient is < -threshold at y+1 and > threshold at y-1. Negative
    outliers are the exact mirror case (all signs flipped).

    Parameters
    ----------
    data : 3D-array
        Input images.
    erro : 3D-array
        Input image uncertainties. Not used by this method.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to include the newly identified bad pixels.
    key : str
        Database key of the observation. Not used by this method.
    gradient_kwargs : dict, optional
        Keyword arguments for the 'gradient' method. Available keywords are:

        - sigma : float, optional
            Standard deviation in pixels of the Gaussian smoothing kernel.
            The default is 0.5.
        - threshold : float, optional
            Gradient threshold. The default is 0.05.
        - negative : bool, optional
            Also flag negative outliers? The default is True.

        The default is {}.

    Returns
    -------
    None.
    """
    print('')
    log.info(' --> Warning!: This routine has not been thoroughly tested and requires further development')

    # Protection for mutability.
    gradient_kwargs = {} if gradient_kwargs is None else gradient_kwargs.copy()
    # Check input.
    sig = gradient_kwargs.get('sigma', 0.5)
    threshold = gradient_kwargs.get('threshold', 0.05)
    negative = gradient_kwargs.get('negative', True)

    pxdq_orig = pxdq.copy()
    ww = pxdq != 0
    data_temp = data.copy()
    data_temp[ww] = np.nan
    ny, nx = data_temp.shape[1:]
    yy, xx = np.mgrid[0:ny, 0:nx]

    # Loop over the images.
    for i in range(ww.shape[0]):
        image = data_temp[i]
        invalid = ~np.isfinite(image)
        if np.any(invalid):
            # Fill invalid pixels by linear interpolation of the valid ones.
            valid = ~invalid
            image_no_nans = griddata((xx[valid], yy[valid]),
                                     image[valid],
                                     (xx, yy),
                                     method='linear')
        else:
            # Interpolating at the data points themselves reproduces the
            # image, so skip the costly triangulation.
            image_no_nans = image

        # High-pass image normalized by the smooth image.
        smimage = gaussian_filter(image_no_nans, sigma=sig)
        shimage = image_no_nans - smimage
        gr_dy, gr_dx = np.gradient(shimage / smimage)

        # Pad gradients by 1 pixel so the neighbour views below keep the
        # image shape (padded borders are 0 and never pass a threshold).
        gr_dxp = np.pad(gr_dx, 1)
        gr_dyp = np.pad(gr_dy, 1)
        dx_xp1 = gr_dxp[1:-1, 2:]   # x-gradient at x+1
        dx_xm1 = gr_dxp[1:-1, :-2]  # x-gradient at x-1
        dy_yp1 = gr_dyp[2:, 1:-1]   # y-gradient at y+1
        dy_ym1 = gr_dyp[:-2, 1:-1]  # y-gradient at y-1

        # Positive outliers.
        bad_pixels = (dx_xp1 < -threshold) & (dx_xm1 > threshold) & (dy_yp1 < -threshold) & (dy_ym1 > threshold)
        # Negative outliers: mirror of the positive condition.
        if negative:
            bad_pixels |= (dx_xp1 > threshold) & (dx_xm1 < -threshold) & (dy_yp1 > threshold) & (dy_ym1 < -threshold)

        # Flag DQ array.
        ww[i] = ww[i] | bad_pixels
        pxdq[i][ww[i]] = 1
    print('')
    log.info(' --> Method gradient: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape)))
[docs]
def find_bad_pixels_custom(self,
                           data,
                           erro,
                           pxdq,
                           key,
                           custom_kwargs=None):
    """
    Use a custom bad pixel map to flag additional bad pixels in the data.

    Parameters
    ----------
    data : 3D-array
        Input images. Not used by this method.
    erro : 3D-array
        Input image uncertainties. Not used by this method.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to include the newly flagged bad pixels.
    key : str
        Database key of the observation to be updated. Only the entry of
        ``custom_kwargs`` stored under this key is used.
    custom_kwargs : dict, optional
        Dictionary whose keys must match the keys of the observations
        database. Each value can be either

        - a binary bad pixel map (1 = bad, 0 = good), either with the same
          shape as the data (one map per integration) or with the shape of
          a single frame (applied to all integrations), or
        - a list of pixel coordinates, given as [x, y] pairs (flagged in
          all integrations) or [integration, x, y] triplets (flagged in
          that integration only). A single pair or triplet may also be
          given without the outer list.

        The default is {}.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    custom_kwargs = {} if custom_kwargs is None else custom_kwargs.copy()

    pxdq_orig = pxdq.copy()
    if key in custom_kwargs:
        # np.asarray makes nested Python lists behave like arrays.
        entry = np.asarray(custom_kwargs[key])
        if entry.shape in (pxdq.shape, pxdq.shape[1:]):
            # Full bad pixel map (per integration or per frame).
            pxdq_custom = entry != 0
        else:
            # Coordinate list; note that x (column) precedes y (row).
            pxdq_custom = np.zeros(pxdq.shape, dtype=bool)
            if entry.size > 0:
                coordinates = np.atleast_2d(entry)
                if coordinates.shape[-1] == 2:
                    pxdq_custom[:, coordinates[:, 1], coordinates[:, 0]] = True
                elif coordinates.shape[-1] == 3:
                    pxdq_custom[coordinates[:, 0], coordinates[:, 2], coordinates[:, 1]] = True
                else:
                    log.warning(' --> Method custom: could not interpret custom bad pixels for ' + str(key)
                                + '; expected a bad pixel map or [x, y] / [integration, x, y] coordinates')
        if pxdq_custom.ndim == pxdq.ndim - 1:
            # A single-frame map applies to all integrations.
            pxdq_custom = np.broadcast_to(pxdq_custom, pxdq.shape)
        pxdq[pxdq_custom] = 1
    log.info(
        ' --> Method custom: flagged %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig),
                                                                                 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape)))
[docs]
def fix_bad_pixels_timemed(self,
                           data,
                           erro,
                           pxdq,
                           timemed_kwargs=None):
    """
    Replace pixels which are only bad in some frames with their median
    value from the good frames.

    Pixels that are flagged in every frame have no good value to take a
    median of; they are left untouched and stay flagged.

    Parameters
    ----------
    data : 3D-array
        Input images. Modified in place.
    erro : 3D-array
        Input image uncertainties. Modified in place.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to exclude the fixed bad pixels.
    timemed_kwargs : dict, optional
        Keyword arguments for the 'timemed' method. Available keywords are:

        - n/a

        The default is {}.

    Returns
    -------
    None.
    """
    # No keywords are used yet; copy anyway so the caller's dict is never
    # shared.
    timemed_kwargs = {} if timemed_kwargs is None else timemed_kwargs.copy()

    bad = pxdq != 0
    # A pixel flagged in all frames cannot be repaired from the time axis.
    always_bad = np.broadcast_to(np.all(bad, axis=0), bad.shape)
    fixable = bad & ~always_bad
    log.info(' --> Method timemed: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(fixable), 100. * np.sum(fixable) / np.prod(fixable.shape)))

    for arr in (data, erro):
        # Hide the fixable samples so they do not enter their own median.
        arr[fixable] = np.nan
        frame_median = np.nanmedian(arr, axis=0)
        arr[fixable] = np.broadcast_to(frame_median, arr.shape)[fixable]
    pxdq[fixable] = 0
[docs]
def fix_bad_pixels_localmed(self,
                            data,
                            erro,
                            pxdq,
                            localmed_kwargs=None):
    """
    Replace bad pixels with the median value of their surrounding good
    pixels.

    Bad pixels whose neighbours are all bad (median is NaN) are left
    untouched and stay flagged.

    Parameters
    ----------
    data : 3D-array
        Input images. Modified in place.
    erro : 3D-array
        Input image uncertainties. Modified in place.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to exclude the fixed bad pixels.
    localmed_kwargs : dict, optional
        Keyword arguments for the 'localmed' method. Available keywords are:

        - shift_x : list of int, optional
            Pixels in x-direction from which the median shall be computed.
            A zero shift is always added. The default is [-1, 0, 1].
        - shift_y : list of int, optional
            Pixels in y-direction from which the median shall be computed.
            A zero shift is always added. The default is [-1, 0, 1].

        The default is {}.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    localmed_kwargs = {} if localmed_kwargs is None else localmed_kwargs.copy()

    # Work on copies of the shift lists: appending the zero shift in place
    # would mutate a list owned by the caller (and fail for tuples).
    shift_x = list(localmed_kwargs.get('shift_x', [-1, 0, 1]))
    shift_y = list(localmed_kwargs.get('shift_y', [-1, 0, 1]))
    if 0 not in shift_x:
        shift_x.append(0)
    if 0 not in shift_y:
        shift_y.append(0)

    # Pad data. The zero shift guarantees min <= 0 <= max.
    pad_left = abs(min(shift_x))
    pad_right = abs(max(shift_x))
    right = None if pad_right == 0 else -pad_right
    pad_bottom = abs(min(shift_y))
    pad_top = abs(max(shift_y))
    top = None if pad_top == 0 else -pad_top
    pad_vals = ((0, 0), (pad_bottom, pad_top), (pad_left, pad_right))

    # Neighbour offsets (row, column), excluding the pixel itself.
    offsets = [(iy, ix) for ix in shift_x for iy in shift_y if ix != 0 or iy != 0]

    # Fix bad pixels using median of neighbors.
    ww = pxdq != 0
    data_temp = data.copy()
    data_temp[ww] = np.nan
    pad_data = np.pad(data_temp, pad_vals, mode='edge')
    erro_temp = erro.copy()
    erro_temp[ww] = np.nan
    pad_erro = np.pad(erro_temp, pad_vals, mode='edge')
    for i in range(ww.shape[0]):
        data_arr = np.array([np.roll(pad_data[i], off, axis=(0, 1)) for off in offsets])
        data_med = np.nanmedian(data_arr[:, pad_bottom:top, pad_left:right], axis=0)
        # Pixels without any good neighbour cannot be fixed.
        ww[i][np.isnan(data_med)] = 0
        data[i][ww[i]] = data_med[ww[i]]
        erro_arr = np.array([np.roll(pad_erro[i], off, axis=(0, 1)) for off in offsets])
        erro_med = np.nanmedian(erro_arr[:, pad_bottom:top, pad_left:right], axis=0)
        erro[i][ww[i]] = erro_med[ww[i]]
        pxdq[i][ww[i]] = 0
    log.info(' --> Method localmed: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape)))
[docs]
def fix_bad_pixels_medfilt(self,
                           data,
                           erro,
                           pxdq,
                           medfilt_kwargs=None):
    """
    Replace bad pixels with an image plane median filter.

    Every pixel with a non-zero entry in ``pxdq`` is replaced by the value of
    a 2D median filter evaluated on its integration. NaNs in the data and
    uncertainties are treated as zeros while filtering. ``data``, ``erro``
    and ``pxdq`` are modified in place.

    Parameters
    ----------
    data : 3D-array
        Input images.
    erro : 3D-array
        Input image uncertainties.
    pxdq : 3D-array
        Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by
        the routine to exclude the fixed bad pixels.
    medfilt_kwargs : dict, optional
        Keyword arguments for the 'medfilt' method, forwarded to
        scipy.ndimage.median_filter. Available keywords are:

        - size : int, optional
            Kernel size of the median filter to be used. The default is 4.

        The default is {}.

    Returns
    -------
    None.
    """
    # Work on a private copy so that the caller's dictionary is never mutated.
    filt_kwargs = {} if medfilt_kwargs is None else medfilt_kwargs.copy()
    filt_kwargs.setdefault('size', 4)

    bad = pxdq != 0
    log.info(' --> Method medfilt: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(bad), 100. * np.sum(bad) / np.prod(bad.shape)))

    # median_filter cannot handle NaNs, so filter NaN-free copies.
    data_filled = data.copy()
    data_filled[np.isnan(data_filled)] = 0.
    erro_filled = erro.copy()
    erro_filled[np.isnan(erro_filled)] = 0.

    for idx, frame_bad in enumerate(bad):
        data_smooth = median_filter(data_filled[idx], **filt_kwargs)
        erro_smooth = median_filter(erro_filled[idx], **filt_kwargs)
        data[idx][frame_bad] = data_smooth[frame_bad]
        erro[idx][frame_bad] = erro_smooth[frame_bad]
        pxdq[idx][frame_bad] = 0
[docs]
def fix_bad_pixels_astrofix(self,
                            data,
                            erro,
                            pxdq,
                            astrofix_kwargs=None,
                            plot=False):
    """
    Replace bad pixels with an algorithm based on Gaussian Process Regression.

    It trains itself to apply the optimal interpolation kernel for each image,
    performing multiple times better than median replacement and interpolation
    with a fixed kernel. The kernel is trained on the first integration only;
    the trained parameters are then re-used to interpolate all remaining
    integrations.

    Parameters
    ----------
    data : 3D-array
        Input images. Updated in place at the fixed bad pixels.
    erro : 3D-array
        Input image uncertainties. NOTE: these are not modified by this
        method, even though the DQ flags of the fixed pixels are reset.
    pxdq : 3D-array
        Input bad pixel maps (non-zero = bad, 0 = good). Will be updated by
        the routine to exclude the fixed bad pixels. Pixels flagged as
        NON_SCIENCE are not fixed.
    astrofix_kwargs : dict, optional
        Keyword arguments for the 'astrofix' method. Available keywords are:

        - sig_clip : float, optional
            Pixels that are smaller than median + sig_clip * median absolute deviation
            of the image will not be used in the training process. Default: 10.
        - max_clip : float, optional
            Pixels that are greater than max(image)/max_clip will not be used in
            the training process. Default: 5.
        - sig_data : float, optional
            Measurement noise, assumed to be uniform. The kernel depends only on the ratio a/sig_data. Default: 1.
        - width : int, optional
            Size of the window (width x width) used for interpolation. Default: 9.
        - init_guess : array-like, optional
            Initial guess for the training process. By default, the 0th element gives the initial guess of a and the
            1st element gives the initial guess of h. If the size of init_guess is 3, the training optimizes h_x and
            h_y separately instead of using h for all directions. In that case, the 1st element gives the initial
            guess of h_x, and the 2nd element gives the initial guess of h_y. Default: [1,1].

        The default is {}.
    plot : bool, optional
        Plot the last integration before ("Original") and after ("Fixed")
        the correction? The default is False.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    if astrofix_kwargs is None:
        astrofix_kwargs = {}
    else:
        astrofix_kwargs = astrofix_kwargs.copy()

    # Check inputs. These are astrofix defaults.
    astrofix_kwargs.setdefault('sig_clip', 10)
    astrofix_kwargs.setdefault('max_clip', 5)
    astrofix_kwargs.setdefault('sig_data', 1)
    astrofix_kwargs.setdefault('width', 9)
    astrofix_kwargs.setdefault('init_guess', [1, 1])

    # Fix bad pixels using astrofix.
    pxmask_nonsci = ut.get_dqmask(pxdq, 'NON_SCIENCE', return_bool=True)
    ww = (pxdq != 0) & (~pxmask_nonsci)
    log.info(' --> Method astrofix: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape)))

    # ``data`` is updated in place below, so keep an untouched copy of the
    # input for the "Original" panel of the diagnostic plot.
    data_orig = data.copy() if plot else None

    # Pre-existing NaNs are zeroed so that the only NaNs left are the bad
    # pixels astrofix shall replace ("asnan" mode).
    data_temp = data.copy()
    data_temp[np.isnan(data_temp)] = 0
    data_temp[ww] = np.nan

    # Prepare array to hold fixed images.
    fixed_img = np.zeros_like(data_temp)
    para = None

    # Loop over integrations.
    for i, integration in enumerate(data_temp):
        if i == 0:
            # First integration: train the GPR kernel and fix the image.
            fixed_img[i], para, _ = astrofix.Fix_Image(integration, "asnan",
                                                       sig_clip=astrofix_kwargs['sig_clip'],
                                                       max_clip=astrofix_kwargs['max_clip'],
                                                       sig_data=astrofix_kwargs['sig_data'],
                                                       width=astrofix_kwargs['width'],
                                                       init_guess=astrofix_kwargs['init_guess'])
        else:
            # Remaining integrations: interpolate using the trained parameters.
            fixed_img[i] = astrofix.Interpolate(para[0], para[1], integration, BP="asnan")

        # Update the original data at the bad pixels and mark them as good.
        data[i][ww[i]] = fixed_img[i][ww[i]]
        pxdq[i][ww[i]] = 0

    if plot:
        from mpl_toolkits.axes_grid1 import make_axes_locatable

        # Diagnostic plot of one integration before and after the correction.
        zoom = None  # optionally (y1, y2, x1, x2), e.g. (100, 200, 100, 200)
        frame = -1
        if zoom is not None:
            y1, y2, x1, x2 = zoom
            images = [data_orig[frame, y1:y2, x1:x2],
                      fixed_img[frame, y1:y2, x1:x2]]
        else:
            images = [data_orig[frame],
                      fixed_img[frame]]
        titles = ["Original", "Fixed"]
        fig, ax = plt.subplots(1, 2, figsize=(18, 7))
        for k in range(2):
            im = ax[k].imshow(images[k], vmin=np.nanpercentile(images[k], 10), vmax=np.nanpercentile(images[k], 100))
            divider = make_axes_locatable(ax[k])
            cax = divider.append_axes("bottom", size="5%", pad=0.2)
            fig.colorbar(im, ax=ax[k], cax=cax, orientation="horizontal")
            ax[k].set_title(titles[k], fontsize=30, pad=15)
            ax[k].axis("off")
        plt.tight_layout()
        plt.show()
[docs]
def fix_bad_pixels_interp2d(self,
                            data,
                            erro,
                            pxdq,
                            interp2d_kwargs=None):
    """
    Replace bad pixels with an interpolation of neighbouring pixels.

    For every bad pixel, the good pixels inside a ``size`` x ``size`` box
    centered on it (clipped at the image edges) are interpolated onto the
    bad pixel position with ``scipy.interpolate.griddata``. Data and
    uncertainties are interpolated independently. Pixels flagged as
    NON_SCIENCE are not fixed. A bad pixel is only replaced, and its DQ flag
    cleared, if the box contains more than ``size`` good pixels and the
    interpolation yields finite values. Otherwise it is left untouched and
    stays flagged.

    Parameters
    ----------
    data : 3D-array
        Input images. Updated in place at the fixed bad pixels.
    erro : 3D-array
        Input image uncertainties. Updated in place at the fixed bad pixels.
    pxdq : 3D-array
        Input bad pixel maps (non-zero = bad, 0 = good). Will be updated by
        the routine to exclude the fixed bad pixels.
    interp2d_kwargs : dict, optional
        Keyword arguments for the 'interp2d' method. Available keywords are:

        - size : int, optional
            Width (pix) of the square box of neighbouring pixels used for
            the interpolation. The default is 5.
        - method : str, optional
            Interpolation method passed to griddata. The default is 'linear'.
            Available options are 'nearest', 'linear', and 'cubic'.

        The default is {}.

    Returns
    -------
    None.
    """
    # Protection for mutability.
    if interp2d_kwargs is None:
        interp2d_kwargs = {}
    else:
        interp2d_kwargs = interp2d_kwargs.copy()

    # Check input.
    if 'size' not in interp2d_kwargs.keys():
        interp2d_kwargs['size'] = 5
    if 'method' not in interp2d_kwargs.keys():
        interp2d_kwargs['method'] = 'linear'
    size = interp2d_kwargs['size']
    method = interp2d_kwargs['method']

    # Fix bad pixels using interpolation of neighbors.
    pxmask_nonsci = ut.get_dqmask(pxdq, 'NON_SCIENCE', return_bool=True)
    ww = (pxdq != 0) & (~pxmask_nonsci)
    log.info(' --> Method interp2d: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape)))

    # Pre-existing NaNs are zeroed so that the only NaNs left mark the
    # pixels to be replaced.
    data_temp = data.copy()
    data_temp[np.isnan(data_temp)] = 0
    data_temp[ww] = np.nan
    erro_temp = erro.copy()
    erro_temp[np.isnan(erro_temp)] = 0
    erro_temp[ww] = np.nan

    def _interpolate(box, x_min, y_min, xi):
        """Interpolate the non-NaN pixels of ``box`` onto the absolute (x, y) position ``xi``."""
        valid = ~np.isnan(box)
        # np.nonzero returns (row, col) = (y, x) indices relative to the box;
        # shift them by the box origin to get absolute (x, y) coordinates.
        yy, xx = np.nonzero(valid)
        return griddata((xx + x_min, yy + y_min),
                        box[valid],
                        xi,
                        method=method,
                        fill_value=np.nan)

    rows, cols = data_temp[0].shape
    half_box = size // 2
    for i in range(ww.shape[0]):
        # Only visit the pixels flagged for replacement.
        for ri, ci in np.argwhere(np.isnan(data_temp[i])):
            # Bounds of the box centered on the bad pixel, clipped to the image.
            x_min = max(0, ci - half_box)
            x_max = min(cols, ci + half_box + 1)
            y_min = max(0, ri - half_box)
            y_max = min(rows, ri + half_box + 1)
            box = data_temp[i][y_min:y_max, x_min:x_max]
            ebox = erro_temp[i][y_min:y_max, x_min:x_max]

            # Require enough good neighbours for a meaningful interpolation.
            if np.count_nonzero(~np.isnan(box)) <= size \
                    or np.count_nonzero(~np.isnan(ebox)) <= size:
                continue

            data_interp = _interpolate(box, x_min, y_min, (ci, ri))
            err_interp = _interpolate(ebox, x_min, y_min, (ci, ri))

            # If griddata could not produce a value (e.g. the pixel lies
            # outside the convex hull of the good neighbours), leave the
            # pixel untouched and flagged rather than writing NaN as "good".
            if not (np.all(np.isfinite(data_interp)) and np.all(np.isfinite(err_interp))):
                continue

            data[i][ri, ci] = data_interp
            erro[i][ri, ci] = err_interp
            pxdq[i][ri, ci] = 0
[docs]
def clean_background(self,
                     sigma=3.,
                     gaussian_smoothing={'skip':True, 'sigma':1.},
                     types=['SCI_BG', 'REF_BG'],
                     subdir='bgcleaned'):
    """
    Clean data from any contaminating sources in background observations using sigma clipping.
    Used for MIRI background observations.

    For each data type in ``types``, the two background exposures are
    compared integration by integration. Their difference is normalized by
    the quadrature sum of the per-pixel standard deviations (taken over the
    integrations) of both exposures. Pixels where the normalized difference
    exceeds ``sigma`` are set to NaN in the first exposure. Pixels where it
    is below ``-sigma`` are set to NaN in the second exposure. All files of
    other types are copied unchanged into the output directory.

    Parameters
    ----------
    sigma : float, optional
        Sigma clipping threshold for background cleaning. The default is 3.
    gaussian_smoothing : dict, optional
        Smoothing options. If 'skip' is False, the two per-pixel noise maps
        derived from the backgrounds are smoothed with
        scipy.ndimage.gaussian_filter using the given 'sigma' (pix) before
        the cleaning. If 'skip' is True, no Gaussian smoothing will be
        applied. The default is {'skip':True, 'sigma':1.}.
    types : list of str, optional
        List of data types for which background cleaning shall be applied.
        The default is ['SCI_BG', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'bgcleaned'.

    Returns
    -------
    None.

    Raises
    ------
    NotImplementedError
        If a concatenation does not contain exactly 2 files of one of the
        requested background types.
    """
    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)
        for data_type in types:
            log.info('--> Starting cleaning process for ' + data_type + ' files')

            # Find background files of this type.
            ww_bg = np.where(self.database.obs[key]['TYPE'] == data_type)[0]
            if len(ww_bg) != 2:
                raise NotImplementedError('Background cleaning currently only implemented if 2 background files available.')

            # Read both background files once; everything except the data is
            # re-used unchanged when writing the cleaned files.
            bg_obs = [ut.read_obs(self.database.obs[key]['FITSFILE'][j]) for j in ww_bg]
            bg_data = np.array([obs[0] for obs in bg_obs])

            # Per-pixel noise maps (standard deviation over integrations).
            snr_map1 = np.nanstd(bg_data[0], axis=0)
            snr_map2 = np.nanstd(bg_data[1], axis=0)
            if not gaussian_smoothing['skip']:
                log.info(' --> Applying Gaussian smoothing with sigma = %.2f' % gaussian_smoothing['sigma'])
                # Smooth the 2D noise maps themselves. Smoothing the 3D data
                # cubes here would yield maps of the wrong shape for the
                # per-integration comparison below.
                snr_map1 = gaussian_filter(snr_map1, sigma=gaussian_smoothing['sigma'])
                snr_map2 = gaussian_filter(snr_map2, sigma=gaussian_smoothing['sigma'])
            noise = np.sqrt(snr_map1**2 + snr_map2**2)

            # Loop over integrations and mask significant outliers of the
            # normalized difference in the respective exposure.
            cleaned_bg_data = np.empty_like(bg_data)
            for k in range(bg_data.shape[1]):
                diff = (bg_data[0, k] - bg_data[1, k]) / noise
                cleaned_bg_data[0, k] = np.where(diff > sigma, np.nan, bg_data[0, k])
                cleaned_bg_data[1, k] = np.where(diff < -sigma, np.nan, bg_data[1, k])

            # Write FITS files and PSF masks.
            for idx, j in enumerate(ww_bg):
                fitsfile = self.database.obs[key]['FITSFILE'][j]
                _, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = bg_obs[idx]
                maskfile = self.database.obs[key]['MASKFILE'][j]
                mask = ut.read_msk(maskfile)

                # Replace data with cleaned data.
                fitsfile = ut.write_obs(fitsfile, output_dir, cleaned_bg_data[idx], erro, pxdq, head_pri, head_sci, is2d,
                                        align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                        center_mask=center_mask, maskoffs=maskoffs)
                maskfile = ut.write_msk(maskfile, mask, fitsfile)

                # Update spaceKLIP database.
                self.database.update_obs(key, j, fitsfile, maskfile)

        # Copy all files whose type was not cleaned, so that the whole
        # database points to the output directory.
        log.info('--> Skipping cleaning process for non background files.')
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            if self.database.obs[key]['TYPE'][j] not in types:
                # Read FITS file.
                fitsfile = self.database.obs[key]['FITSFILE'][j]
                data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
                maskfile = self.database.obs[key]['MASKFILE'][j]
                mask = ut.read_msk(maskfile)

                # Write FITS file and PSF mask.
                fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                        align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                        center_mask=center_mask, maskoffs=maskoffs)
                maskfile = ut.write_msk(maskfile, mask, fitsfile)

                # Update spaceKLIP database.
                self.database.update_obs(key, j, fitsfile, maskfile)
[docs]
def persistence_trimming(self,
                         radius_pxl=4,
                         ints_to_trim=2,
                         types=['SCI', 'REF'],
                         subdir='persistence_trimmed'):
    """
    Remove persistence artifacts from the initial N integrations by trimming
    a circular region around the persistence location. This is useful for
    MIRI 4QPM data where persistence artifacts from target acquisition can
    remain in early integrations.

    The persistence locations are taken as the position of the brightest
    pixel in each target acquisition (TA) exposure of type ``<TYPE>_TA``
    whose ROLL_REF matches (to 2 decimals) that of the observation. Pixels
    closer than ``radius_pxl`` to any of these locations are set to NaN in
    the first ``ints_to_trim`` integrations.

    Note that for this function to work, TA images must be included in the
    database and previous data reduction steps must have been performed to
    determine the persistence location.

    Parameters
    ----------
    radius_pxl : int, optional
        Radius (in pixels) of the circular region around the persistence
        location to trim/mask. The default is 4.
    ints_to_trim : int, optional
        Number of initial integrations to apply the persistence trimming to.
        The default is 2.
    types : list of str, optional
        List of data types for which persistence trimming shall be applied.
        The default is ['SCI', 'REF'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'persistence_trimmed'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If fewer than two matching TA observations are found, or if
        ``ints_to_trim`` is not smaller than the number of integrations.
    """
    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            # Read FITS file.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)
            obs_type = self.database.obs[key]['TYPE'][j]

            # Skip file types that are not in the list of types.
            if obs_type in types:
                log.info(f' --> {obs_type} Persistence trimming: ' + os.path.basename(fitsfile))

                # Get persistence locations from the TA observations taken at
                # the same roll angle.
                ta_key = obs_type + '_TA'
                ww_ta = np.where(self.database.obs[key]['TYPE'] == ta_key)[0]
                roll = round(self.database.obs[key]['ROLL_REF'][j], 2)
                pers_loc = []
                log.info(' --> Finding persistence location in ' + ta_key)
                for k in ww_ta:
                    if round(self.database.obs[key]['ROLL_REF'][k], 2) != roll:
                        continue
                    fitsfile_ta = self.database.obs[key]['FITSFILE'][k]
                    log.info(' --> Matching TA found: ' + os.path.basename(fitsfile_ta))
                    data_ta = ut.read_obs(fitsfile_ta)[0]
                    # (y, x) of the brightest TA pixel; the leading
                    # integration index is dropped. If several pixels share
                    # the maximum, the first one in C order is used.
                    y0, x0 = np.unravel_index(np.nanargmax(data_ta), data_ta.shape)[-2:]
                    pers_loc.append((y0, x0))
                if len(pers_loc) == 0:
                    raise ValueError('No matching TA observation found for ' + os.path.basename(fitsfile) + '. Cannot determine persistence location.')
                elif len(pers_loc) == 1:
                    raise ValueError('Only one matching TA observation found for ' + os.path.basename(fitsfile) + '. Need the two TA observations to find both persistence locations.')

                # Check for valid number of integrations to trim.
                if ints_to_trim >= data.shape[0]:
                    raise ValueError(f'Number of integrations to trim ({ints_to_trim}) is greater than or equal to total number of integrations in {obs_type} observations ({data.shape[0]}). \n Try again with a smaller number of integrations to trim or remove persistence trimming for {obs_type} observations.')

                # Trim persistence regions in the initial integrations. The
                # guard keeps a non-positive ints_to_trim from selecting
                # integrations counted from the end via negative slicing.
                if ints_to_trim > 0:
                    yy, xx = np.indices(data.shape[1:])
                    trim = np.zeros(data.shape[1:], dtype=bool)
                    for y0, x0 in pers_loc:
                        trim |= (yy - y0)**2 + (xx - x0)**2 < radius_pxl**2
                    data[:ints_to_trim, trim] = np.nan

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile)
[docs]
def replace_nans(self,
                 cval=0.,
                 types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                 subdir='nanreplaced'):
    """
    Replace all nans in the data with a constant value.

    Every observation of every concatenation is re-written into ``subdir``
    together with its PSF mask and NaN mask, and the database is updated to
    point at the new files. Only observations whose TYPE is listed in
    ``types`` have their NaN pixels overwritten.

    Parameters
    ----------
    cval : float, optional
        Fill value for the nan pixels. The default is 0.
    types : list of str, optional
        List of data types for which nans shall be replaced. The default is
        ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'nanreplaced'.

    Returns
    -------
    None.
    """
    # Make sure the destination folder exists.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)
        for j in range(len(self.database.obs[key])):
            # Load the observation, its PSF mask and its NaN mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            (data, erro, pxdq, head_pri, head_sci, is2d, align_shift,
             center_shift, align_mask, center_mask, maskoffs) = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Only the requested data types get their NaNs overwritten.
            if self.database.obs[key]['TYPE'][j] in types:
                log.info(' --> Nan replacement: ' + os.path.basename(fitsfile))
                is_nan = np.isnan(data)
                n_nan = np.sum(is_nan)
                data[is_nan] = cval
                fraction = 100. * n_nan / np.prod(is_nan.shape)
                log.info(' --> Nan replacement: replaced %.0f nan pixel(s) with value ' % (n_nan) + str(cval) + ' -- %.2f%%' % (fraction))

            # Write the products and point the database at them.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')
            self.database.update_obs(key, j, fitsfile, maskfile, nanmaskfile=nanmaskfile)
[docs]
def update_frames_with_nans_from_nanmask(self,
                                         cval=np.nan,
                                         types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                                         subdir='nansback'):
    """
    Replace values in the data with ``cval`` wherever the NaN mask equals 1.

    The NaN mask is applied to every integration of an observation. Every
    observation of every concatenation is re-written into ``subdir``
    together with its PSF mask and NaN mask, and the database is updated to
    point at the new files. Only observations whose TYPE is listed in
    ``types`` are modified.

    Parameters
    ----------
    cval : float, optional
        Fill value for the masked pixels. The default is np.nan.
    types : list of str, optional
        List of data types for which values shall be replaced. The default
        is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'nansback'.

    Returns
    -------
    None.
    """
    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        nfitsfiles = len(self.database.obs[key])
        for j in range(nfitsfiles):
            # Read FITS file and PSF mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Skip file types that are not in the list of types.
            if self.database.obs[key]['TYPE'][j] in types:
                head, tail = os.path.split(fitsfile)
                log.info(' --> Value replacement: ' + tail)
                if nanmask is None:
                    # Without a NaN mask there is nothing to put back; an
                    # array-less comparison would otherwise give a
                    # mis-shaped boolean index and crash.
                    log.info(' --> Value replacement: skipped (no nan mask available)')
                else:
                    ww = (nanmask == 1)
                    # Apply the (per-frame) mask to every integration.
                    data[np.broadcast_to(ww, data.shape)] = cval
                    log.info(' --> Value replacement: replaced %.0f pixel(s) with value ' % (np.sum(ww)) + str(cval) + ' -- %.2f%%' % (100. * np.sum(ww)/np.prod(ww.shape)))

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, nanmaskfile=nanmaskfile)
[docs]
def blur_frames(self,
                fact='auto',
                types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
                subdir='blurred'):
    """
    Blur frames with a Gaussian filter.

    Every observation of every concatenation is re-written into ``subdir``
    and the database is updated to point at the new files, with the applied
    blur FWHM (pix) recorded as ``blurfwhm`` (NaN if not blurred). Data,
    uncertainties and, if present, the PSF mask are blurred with the same
    kernel.

    Parameters
    ----------
    fact : 'auto' or 'fix23' or float or dict of list of float or None, optional
        Blurring to apply.

        - 'auto': the kernel is chosen such that the data reach a FWHM of
          1.5 x 2.3 pix, taking into account the native FWHM lambda_min / D
          of the data (lambda_min = CWAVEL - DWAVEL, D = 5.2 m for NIRCam
          coronagraphy and TA, JWST_CIRCUMSCRIBED_DIAMETER otherwise).
        - 'fix23': the kernel is chosen such that data with a native FWHM of
          1 pix reach a FWHM of 2.3 pix, so that even bad pixels cause no
          more Fourier ripples.
        - float: used directly as the standard deviation (sigma, pix) of the
          Gaussian kernel.
        - dict of list of float: the dictionary keys must match the keys of
          the observations database, and the number of entries in the lists
          must match the number of observations in the corresponding
          concatenation, so that a different value can be used for each
          observation.
        - None: the corresponding observation will not be blurred.

        If the data are already at least as blurred as requested, the
        observation is not blurred. The default is 'auto'.
    types : list of str, optional
        List of data types for which the frames shall be blurred. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'blurred'.

    Returns
    -------
    None.
    """
    # Conversion factor between a Gaussian's standard deviation and its FWHM.
    sigma_to_fwhm = np.sqrt(8. * np.log(2.))

    def _kernel_sigma(fwhm_desired, fwhm_current):
        """Return the kernel sigma (pix) that blurs ``fwhm_current`` to ``fwhm_desired``, or NaN if no blurring is needed."""
        excess = fwhm_desired**2 - fwhm_current**2
        if excess < 0:
            return np.nan
        return np.sqrt(excess) / sigma_to_fwhm  # fix from Marshall

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        Nfitsfiles = len(self.database.obs[key])
        for j in range(Nfitsfiles):
            # Read FITS file.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)

            # Skip file types that are not in the list of types.
            fact_temp = None
            if self.database.obs[key]['TYPE'][j] in types:
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame blurring: ' + tail)

                # ``fact`` is either a dict of per-observation lists or a
                # single value applying to all observations.
                try:
                    fact_temp = fact[key][j]
                except (TypeError, KeyError, IndexError):
                    fact_temp = fact
                if self.database.obs[key]['TELESCOP'][j] == 'JWST':
                    if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON', 'NRC_TACONFIRM', 'NRC_TACQ']:
                        diam = 5.2
                    else:
                        diam = JWST_CIRCUMSCRIBED_DIAMETER
                else:
                    raise UserWarning('Data originates from unknown telescope')

                if fact_temp is not None:
                    if str(fact_temp) == 'auto':
                        wave_min = self.database.obs[key]['CWAVEL'][j] - self.database.obs[key]['DWAVEL'][j]  # micron
                        fwhm_current = wave_min * 1e-6 / diam * 180. / np.pi * 3600. / self.database.obs[key]['PIXSCALE'][j]  # pix
                        fwhm_desired = 2.3  # pix; see, e.g., Pawley 2006
                        fwhm_desired *= 1.5  # go to 1.5 times the theoretically required bluring to safely avoid numerical ringing effects
                        fact_temp = _kernel_sigma(fwhm_desired, fwhm_current)
                    elif str(fact_temp) == 'fix23':
                        fwhm_current = 1.  # pix
                        fwhm_desired = 2.3  # pix; see, e.g., Pawley 2006
                        fact_temp = _kernel_sigma(fwhm_desired, fwhm_current)
                    if np.isnan(fact_temp):
                        # Data are already at least as blurred as requested.
                        # Still fall through so that the file is written and
                        # the database updated like any other skipped one.
                        fact_temp = None

                if fact_temp is None:
                    log.info(' --> Frame blurring: skipped')
                else:
                    log.info(' --> Frame blurring: factor = %.3f' % fact_temp)
                    for k in range(data.shape[0]):
                        data[k] = gaussian_filter(data[k], fact_temp)
                        erro[k] = gaussian_filter(erro[k], fact_temp)
                    if mask is not None:
                        mask = gaussian_filter(mask, fact_temp)

            # Write FITS file.
            if fact_temp is not None:
                head_pri['BLURFWHM'] = fact_temp * sigma_to_fwhm  # Factor to convert from sigma to FWHM
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            if fact_temp is None:
                self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=np.nan)
            else:
                self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=fact_temp * sigma_to_fwhm)
[docs]
def hpf(self,
        size='auto',
        types=['SCI', 'SCI_BG', 'REF', 'REF_BG'],
        subdir='filtered'):
    """
    High-pass filter frames with pyKLIP's Fourier-space Gaussian filter.

    Every observation of every concatenation is re-written into ``subdir``
    and the database is updated to point at the new files. Data and
    uncertainties of the selected observations are filtered, and the filter
    size is recorded in the HPFSIZE keyword of the primary header.

    Parameters
    ----------
    size : float or dict of list of float or None, optional
        FWHM (pix) of the image-space Gaussian that sets the filter scale.
        It is converted to the Fourier-space sigma expected by pyKLIP as
        (ny / size) / (2 * sqrt(2 * ln(2))). If dict, the dictionary keys must
        match the keys of the observations database. Each value is either
        a single size for the whole concatenation or a list with one entry
        per observation in the corresponding concatenation. If None, the
        corresponding observation will not be filtered. Automatic size
        determination ('auto', the default) is not implemented and raises a
        ValueError.
    types : list of str, optional
        List of data types for which the frames shall be filtered. The
        default is ['SCI', 'SCI_BG', 'REF', 'REF_BG'].
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'filtered'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If the filter size for a selected observation is 'auto' or
        otherwise cannot be converted to a float.
    """

    def _resolve_size(key, j):
        """Return the HPF FWHM (pix) for observation ``j`` of concatenation ``key``, or None to skip."""
        value = size
        if isinstance(value, dict):
            value = value[key]
            # Per-observation lists are indexed by the observation number.
            if isinstance(value, (list, tuple, np.ndarray)):
                value = value[j]
        if value is None:
            return None
        if isinstance(value, str) and value == 'auto':
            raise ValueError("Automatic HPF size determination (size='auto') is not implemented; please provide the FWHM in pixels explicitly.")
        return float(value)

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Loop through FITS files.
        Nfitsfiles = len(self.database.obs[key])
        for j in range(Nfitsfiles):
            # Read FITS file.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)

            # Skip file types that are not in the list of types.
            size_temp = None
            if self.database.obs[key]['TYPE'][j] in types:
                head, tail = os.path.split(fitsfile)
                log.info(' --> Frame filtering: ' + tail)
                size_temp = _resolve_size(key, j)
                if size_temp is not None:
                    log.info(' --> Frame filtering: HPF FWHM = %.2f pix' % size_temp)
                    # Convert the image-space FWHM to pyKLIP's Fourier-space sigma.
                    fourier_sigma_size = (data.shape[1] / size_temp) / (2. * np.sqrt(2. * np.log(2.)))
                    data = parallelized.high_pass_filter_imgs(data, numthreads=None, filtersize=fourier_sigma_size)
                    erro = parallelized.high_pass_filter_imgs(erro, numthreads=None, filtersize=fourier_sigma_size)
                else:
                    log.info(' --> Frame filtering: skipped')

            # Write FITS file.
            if size_temp is not None:
                head_pri['HPFSIZE'] = size_temp
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile)
[docs]
def inject_companions(self,
                      companions,
                      starfile,
                      spectral_type,
                      output_dir,
                      highpass=False,
                      subdir='test',
                      date='auto',
                      kwargs=None):
    """
    Inject synthetic companion PSFs into the SCI frames of each
    concatenation and save all frames (SCI and non-SCI) to a new
    subdirectory, updating the spaceKLIP database accordingly.

    Parameters
    ----------
    companions : list of list of three float
        Companions to be injected. For each companion, there should be a
        three element sequence containing
        [RA offset (arcsec), Dec offset (arcsec), contrast].
        A single companion may also be given as one flat three element
        sequence.
    starfile : path
        Host star file passed to `get_stellar_magnitudes`. If it ends with
        '.txt' it is also read with `read_spec_file` and used as the SED of
        the model PSF (unless kwargs['planetfile'] is given).
    spectral_type : str
        Host star spectral type passed to `get_stellar_magnitudes`.
    output_dir : path
        Directory passed to `get_stellar_magnitudes` as `output_dir`. The
        injected FITS files are NOT written here but to
        `self.database.output_dir/subdir`.
    highpass : bool, optional
        Currently unused. The default is False.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'test'.
    date : str or None, optional
        Observation date passed to `JWST_PSF`. If 'auto', the 'DATE-BEG'
        primary header keyword of the first SCI file of each concatenation
        is used. The default is 'auto'.
    kwargs : dict, optional
        Extra keyword arguments. All of them are forwarded to
        `get_stellar_magnitudes`. Additionally recognized keys are
        'planetfile' (spectrum file used as the SED of the injected PSF),
        'sigma_xy' (two Gaussian blur sigmas in pixels applied to the
        injected PSF, for testing) and 'theta_degrees' (rotation of that
        Gaussian, default 0). The caller's dict is not modified. The
        default is None.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If `companions` is empty or a companion does not have three
        elements.
    """
    # Work on a copy so that neither the caller's dict nor a shared
    # default is ever modified.
    kwargs = {} if kwargs is None else dict(kwargs)

    # Check input. A single companion may be given as a flat sequence.
    if len(companions) == 0:
        raise UserWarning('The companions list is empty')
    if not isinstance(companions[0], (list, tuple, np.ndarray)):
        companions = [companions]
    for companion in companions:
        if len(companion) != 3:
            raise UserWarning('There should be three elements for each companion in the companions list')
    Ncompanions = len(companions)

    # For testing only: optional Gaussian blur of the injected PSF, e.g. to
    # check that Analysis.extract_companions can recover it later.
    blur_kernel = None
    if 'sigma_xy' in kwargs:
        sigma_xy = kwargs['sigma_xy']
        blur_kernel = gaussian_kernel(sigma_x=sigma_xy[0], sigma_y=sigma_xy[1],
                                      theta_degrees=kwargs.get('theta_degrees', 0), n=6)

    # SED of the model PSF: the planet spectrum if given, otherwise the
    # stellar spectrum if it is a '.txt' file, otherwise None.
    if kwargs.get('planetfile') is not None:
        sed = read_spec_file(kwargs['planetfile'])
    elif starfile is not None and starfile.endswith('.txt'):
        sed = read_spec_file(starfile)
    else:
        sed = None

    # Set output directory once, up front, so that every file (SCI or not)
    # ends up in the same place. Use a separate name so that the
    # `output_dir` argument keeps being forwarded unchanged to
    # get_stellar_magnitudes.
    save_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for key in self.database.obs.keys():
        obs = self.database.obs[key]
        ww_type = list(obs['TYPE'])
        ww_sci = np.where(obs['TYPE'] == 'SCI')[0]
        log.info('--> Concatenation ' + key)

        # Resolve the PSF date per concatenation; an 'auto' date must not
        # leak from one concatenation into the next.
        if date == 'auto' and len(ww_sci) > 0:
            date_use = fits.getheader(obs['FITSFILE'][ww_sci[0]], 0)['DATE-BEG']
        else:
            date_use = date

        filepaths, psflib_filepaths = get_pyklip_filepaths(self.database, key)
        raw_dataset = JWSTData(filepaths, psflib_filepaths,
                               center_keywords=['STARCENX', 'STARCENY'])

        for ww in range(len(ww_type)):
            # Read input files and store values that we just want to save
            # in the output directory.
            fitsfile = obs['FITSFILE'][ww]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = obs['MASKFILE'][ww]
            mask = ut.read_msk(maskfile)
            crpix1 = obs['CRPIX1'][ww]
            crpix2 = obs['CRPIX2'][ww]
            maskcenx = obs['MASKCENX'][ww]
            maskceny = obs['MASKCENY'][ww]
            starcenx = obs['STARCENX'][ww]
            starceny = obs['STARCENY'][ww]
            # Propagate the database centers into the header to be written.
            head_sci['CRPIX1'] = crpix1
            head_sci['CRPIX2'] = crpix2
            head_sci['MASKCENX'] = maskcenx
            head_sci['MASKCENY'] = maskceny
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny

            # Inject only into SCI type data.
            if ww_type[ww] == 'SCI':
                filt = obs['FILTER'][ww]
                # Get stellar magnitudes (vegamag) and filter zero points
                # (Jy).
                mstar, fzero, _, _ = get_stellar_magnitudes(starfile, spectral_type,
                                                            obs['INSTRUME'][ww],
                                                            obs['DETECTOR'][ww],
                                                            obs['EXP_TYPE'][ww],
                                                            output_dir=output_dir,
                                                            **kwargs)
                # Compute the pixel area in steradian.
                pxsc_arcsec = obs['PIXSCALE'][ww]  # arcsec
                pxsc_rad = pxsc_arcsec / 3600. / 180. * np.pi  # rad
                pxar = pxsc_rad ** 2  # sr
                roll_ref = obs['ROLL_REF'][ww]  # deg

                # Fresh copy of the dataset for this file.
                dataset = copy.deepcopy(raw_dataset)
                offsetpsf_func = JWST_PSF(obs['APERNAME'][ww],
                                          filt,
                                          date=date_use,
                                          fov_pix=65,
                                          oversample=2,
                                          sp=sed,
                                          use_coeff=False)
                # Offset PSF that is not affected by the coronagraphic
                # mask, but only the Lyot stop.
                psf_no_coronmsk = offsetpsf_func.psf_off

                # Shift between star and coronagraphic mask position
                # (identical for every companion, so computed once). If
                # positive, the coronagraphic mask center is to the
                # left/bottom of the star position. NIRCam.
                if maskoffs is not None:
                    mask_xoff = -maskoffs[:, 0]  # pix
                    mask_yoff = -maskoffs[:, 1]  # pix
                    # Rotate by the roll angle (CCW) and flip the x-axis so
                    # that positive RA is to the left.
                    cos_roll = np.cos(np.deg2rad(roll_ref))
                    sin_roll = np.sin(np.deg2rad(roll_ref))
                    mask_raoff = -(mask_xoff * cos_roll - mask_yoff * sin_roll)  # pix
                    mask_deoff = mask_xoff * sin_roll + mask_yoff * cos_roll  # pix

                log.info('--> Injecting companions, writing FITS files and updating spaceKLIP database: ')
                for i in trange(Ncompanions):
                    ra_off, dec_off, contrast = companions[i]
                    guess_dx = ra_off / pxsc_arcsec  # pix
                    guess_dy = dec_off / pxsc_arcsec  # pix
                    guess_sep = np.sqrt(guess_dx ** 2 + guess_dy ** 2)  # pix
                    guess_pa = np.rad2deg(np.arctan2(guess_dx, guess_dy))  # deg
                    if maskoffs is not None:
                        # True offset between the companion and the
                        # coronagraphic mask center; take the median over
                        # integrations (each dither position is typically a
                        # separate observation).
                        sim_dx = guess_dx - mask_raoff  # pix
                        sim_dy = guess_dy - mask_deoff  # pix
                        sim_sep = np.median(np.sqrt(sim_dx ** 2 + sim_dy ** 2) * pxsc_arcsec)  # arcsec
                        sim_pa = np.median(np.rad2deg(np.arctan2(sim_dx, sim_dy)))  # deg
                    else:
                        sim_sep = guess_sep * pxsc_arcsec  # arcsec
                        sim_pa = guess_pa  # deg
                    # Generate offset PSF for this roll angle. Do not add
                    # the V3Yidl angle as it has already been added to the
                    # roll angle by spaceKLIP. This is only for estimating
                    # the coronagraphic mask throughput!
                    offsetpsf_coronmsk = offsetpsf_func.gen_psf([sim_sep, sim_pa],
                                                                mode='rth',
                                                                PA_V3=roll_ref,
                                                                do_shift=False,
                                                                quick=True,
                                                                addV3Yidl=False)
                    # Coronagraphic mask throughput is not incorporated into
                    # the JWST pipeline flux calibration, so the injected
                    # companion must be made fainter by the ratio of the
                    # model PSF flux with and without the mask.
                    scale_factor = np.sum(offsetpsf_coronmsk) / np.sum(psf_no_coronmsk)
                    # Model offset PSF normalized to a total integrated flux
                    # of 1 at infinity ('exit_pupil' normalization).
                    offsetpsf = offsetpsf_func.gen_psf([sim_sep, sim_pa],
                                                       mode='rth',
                                                       PA_V3=roll_ref,
                                                       do_shift=False,
                                                       quick=False,
                                                       addV3Yidl=False,
                                                       normalize='exit_pupil')
                    # Scale the model offset PSF by the companion flux.
                    mcomp = mstar[filt] - 2.5 * np.log10(contrast)
                    offsetpsf *= fzero[filt] / 10 ** (mcomp / 2.5) / 1e6 / pxar  # MJy/sr
                    # Apply the coronagraphic mask throughput.
                    offsetpsf *= scale_factor
                    if blur_kernel is not None:
                        offsetpsf = scipy.ndimage.convolve(offsetpsf, blur_kernel)
                    # Injected PSF needs to be a 3D array that matches the
                    # dataset.
                    inj_psf_3d = np.repeat(offsetpsf[np.newaxis], dataset.input.shape[0], axis=0)
                    # Inject the PSF (modifies dataset.input in place).
                    fakes.inject_planet(frames=dataset.input, centers=dataset.centers, inputflux=inj_psf_3d,
                                        astr_hdrs=dataset.wcs, radius=guess_sep, pa=guess_pa, stampsize=65)
                # NOTE(review): dataset.input holds the frames of every file
                # in `filepaths`, not only those of `fitsfile`; with several
                # SCI files per concatenation the written cube will not
                # match `erro`/`pxdq`. Confirm whether a per-file slice of
                # dataset.input was intended.
                data = dataset.input

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(fitsfile, save_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            # Update spaceKLIP database.
            self.database.update_obs(key, ww, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny)
[docs]
def update_nircam_centers(self,
                          force_siaf_center=False,
                          force_db_center=False):
    """
    Update the NIRCam coronagraphy mask centers (and CRPIX) in the database.

    The mask center comes from the SIAF reference pixel (XSciRef/YSciRef)
    if the SIAF PRD version is newer than the PRD_VER keyword of the file's
    primary header (or if force_siaf_center is True). Otherwise it comes
    from the database. A filter-dependent shift from the
    `filter_shifts_jarron` lookup table is then added, and CRPIX1/2 are set
    to the resulting mask center. Only the database is updated; the FITS
    files are not rewritten. Might not be required for simulated data.

    The lookup tables of information were derived from NIRCam
    commissioning activities CAR-30 and CAR-31, by J. Leisenring and J.
    Girard, and subsequent reanalyses using additional data from PSF SGD
    observations.

    Parameters
    ----------
    force_siaf_center : bool, optional
        Force the use of the SIAF reference pixel position irrespective of
        versions. The default is False
    force_db_center : bool, optional
        Force the use of the database reference pixel position irrespective
        of versions. The default is False

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If both force_siaf_center and force_db_center are True.
    """
    if force_siaf_center and force_db_center:
        raise UserWarning('Both force_siaf_center and force_db_center are set to True. Only one can be True.')
    if not force_db_center:
        siaf = pysiaf.Siaf('NIRCAM')
        # Use same RegEx as pysiaf to get sorted PRDs for later comparison:
        # keep only PRDs of the format PRDOPSSOC-### (drop e.g.
        # PRDOPSSOC-H-014).
        prds = [prd for prd in pysiaf.prd_list
                if not re.match(r"^[A-Z]-\d+", prd.split("PRDOPSSOC-")[1])]
        # Position of the installed SIAF PRD in the sorted list; loop
        # invariant, so computed once.
        siaf_prd_index = np.searchsorted(prds, pysiaf.JWST_PRD_VERSION)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)
        obs = self.database.obs[key]
        # Loop through FITS files.
        for j in range(len(obs)):
            # Skip file types that are not NIRCam coronagraphy.
            if obs['EXP_TYPE'][j] != 'NRC_CORON':
                continue
            fitsfile = obs['FITSFILE'][j]
            maskfile = obs['MASKFILE'][j]
            log.info(' --> Update NIRCam coronagraphy centers: ' + os.path.basename(fitsfile))

            # Use SIAF for reference pixel positions if its PRD is newer
            # than the one used for the file, or if force_siaf_center,
            # unless force_db_center. Only the primary header is needed to
            # compare versions, so the data cube and mask are not read.
            if force_db_center:
                use_siaf = False
            elif force_siaf_center:
                use_siaf = True
            else:
                file_prd_ver = fits.getheader(fitsfile, 0)['PRD_VER']
                use_siaf = np.searchsorted(prds, file_prd_ver) < siaf_prd_index

            if use_siaf:
                log.info(' --> Update NIRCam coronagraphy centers: using MASKCEN from pysiaf')
                apsiaf = siaf[obs['APERNAME'][j]]
                maskcenx = apsiaf.XSciRef
                maskceny = apsiaf.YSciRef
            else:
                log.info(' --> Update NIRCam coronagraphy centers: using MASKCEN from database')
                maskcenx = obs['MASKCENX'][j]
                maskceny = obs['MASKCENY'][j]

            # Get filter shift from Jarron.
            try:
                xoff, yoff = filter_shifts_jarron[obs['FILTER'][j]]
            except KeyError:
                log.warning(' --> Update NIRCam coronagraphy centers: no filter shift found for ' + obs['FILTER'][j])
                xoff, yoff = 0., 0.
            log.info(' --> Update NIRCam coronagraphy centers: old = (%.2f, %.2f), new = (%.2f, %.2f)' % (maskcenx, maskceny, maskcenx + xoff, maskceny + yoff))
            maskcenx += xoff
            maskceny += yoff

            # Update spaceKLIP database. Change also CRPIX to be the same
            # as maskcen.
            self.database.update_obs(key, j, fitsfile, maskfile, crpix1=maskcenx, crpix2=maskceny, maskcenx=maskcenx, maskceny=maskceny)
[docs]
def update_miri_offsets(self):
    """
    Zero the SCI X/Y offsets of older MIRI coronagraphy datasets and shift
    the other SCI and REF offsets accordingly.

    For each concatenation, the XOFFSET/YOFFSET of the last SCI file in
    the database is subtracted from the offsets of every SCI and REF file
    whose EXP_TYPE is 'MIR_4QPM' or 'MIR_LYOT'. The new offsets are written
    to the primary FITS header (the file is rewritten in its current
    directory) and to the database. Concatenations without any SCI file
    are skipped.

    Returns
    -------
    None.
    """
    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)
        obs = self.database.obs[key]
        # Find science and reference files.
        ww_sci = np.where(obs['TYPE'] == 'SCI')[0]
        ww_ref = np.where(obs['TYPE'] == 'REF')[0]
        if len(ww_sci) == 0:
            # Without a SCI file there is no offset to remove; previously
            # this raised NameError or silently reused the offset of the
            # previous concatenation.
            log.warning(' --> Update MIRI coronagraphy offsets: no SCI files found, skipped')
            continue
        # Offset that is set to 0. NOTE(review): this is the offset of the
        # last SCI file (as in the original loop); confirm whether the
        # first SCI file was intended.
        xoffset_fix = obs['XOFFSET'][ww_sci[-1]]  # arcsec
        yoffset_fix = obs['YOFFSET'][ww_sci[-1]]  # arcsec

        # Loop through FITS files.
        for j in np.append(ww_sci, ww_ref):
            # Skip file types that are not MIRI coronagraphy.
            if obs['EXP_TYPE'][j] not in ['MIR_4QPM', 'MIR_LYOT']:
                continue
            # Read FITS file.
            fitsfile = obs['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = obs['MASKFILE'][j]
            log.info(' --> Update MIRI coronagraphy offsets: ' + os.path.basename(fitsfile))
            # Apply offset fixes to SCI and REF data.
            xoffset_old = obs['XOFFSET'][j]  # arcsec
            yoffset_old = obs['YOFFSET'][j]  # arcsec
            xoffset_new = xoffset_old - xoffset_fix  # arcsec
            yoffset_new = yoffset_old - yoffset_fix  # arcsec
            log.info(f' --> Update MIRI coronagraphy offsets: old = ({xoffset_old:.3g}, {yoffset_old:.3g}), new = ({xoffset_new:.3g}, {yoffset_new:.3g})')
            head_pri['XOFFSET'] = xoffset_new
            head_pri['YOFFSET'] = yoffset_new
            # Rewrite the file in its current directory.
            output_dir = os.path.dirname(fitsfile)
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset_new, yoffset=yoffset_new)
[docs]
def recenter_frames(self,
                    method='fourier',
                    subpix_first_sci_only=False,
                    first_sci_only=True,
                    spectral_type='G2V',
                    shft_exp=1,
                    kwargs=None,
                    highpass=False,
                    subdir='recentered'):
    """
    Deprecated: use `calculate_centers` followed by `shift_frames` instead.

    This method always raises DeprecationWarning and does nothing else.
    Its former implementation was unreachable (it followed the raise) and
    relied on obsolete signatures of `ut.read_obs`/`ut.write_obs`, so it
    has been removed. The parameters are kept only for backward
    compatibility of the call signature and are ignored.

    The former behavior was to recenter frames so that the host star
    position is data.shape // 2. For NIRCam coronagraphy, a WebbPSF model
    was used to determine the star position behind the coronagraphic mask
    for the first SCI frame. All other SCI and REF frames were then shifted
    by the same amount. MIRI coronagraphy frames were left unchanged. All
    other data types simply had the host star PSF recentered.

    Parameters
    ----------
    method : 'fourier' or 'spline' (not recommended), optional
        Ignored. The default is 'fourier'.
    subpix_first_sci_only : bool, optional
        Ignored. The default is False.
    first_sci_only : bool, optional
        Ignored. The default is True.
    spectral_type : str, optional
        Ignored. The default is 'G2V'.
    shft_exp : float, optional
        Ignored. The default is 1.
    kwargs : dict, optional
        Ignored. The default is None.
    highpass : bool, optional
        Ignored. The default is False.
    subdir : str, optional
        Ignored. The default is 'recentered'.

    Returns
    -------
    None.

    Raises
    ------
    DeprecationWarning
        Always.
    """
    raise DeprecationWarning('This function is deprecated. Use `calculate_centers` and `shift_frames` instead.')
[docs]
def calculate_centers_binary(self,
                             kwargs=None,
                             subdir='recentered'):
    """
    Measure the host star position in the SCI and REF frames by fitting a
    model PSF with `mcmc_tools.MCMCTools`, and record it without shifting
    the frames.

    For each SCI/REF file, the median frame is fitted once against an
    unocculted, peak-normalized `JWST_PSF` model. The offset of the fitted
    position from data.shape // 2 is used for every integration. It is
    added to any existing center shift, and the result is passed to
    `ut.write_obs` as `center_shift`. STARCENX/STARCENY are set to the
    fitted position plus 1 (1-indexed, assuming the fit returns 0-indexed
    pixel coordinates — TODO confirm). XOFFSET/YOFFSET are set to 0. Files
    are written to `self.database.output_dir/subdir` and the database is
    updated. SCI_TA/REF_TA files are skipped.

    Parameters
    ----------
    kwargs : dict, optional
        Keyword arguments for the mcmc_tools.MCMCTools class. 'debug' is
        always set to True. The caller's dict is not modified. The default
        is None.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'recentered'.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If a SCI or REF frame contains NaN pixels.
    """
    # The fit is always run in debug mode; set it on a copy so that neither
    # the caller's dict nor a shared default is modified.
    mcmc_kwargs = {} if kwargs is None else dict(kwargs)
    mcmc_kwargs['debug'] = True

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)
        obs = self.database.obs[key]
        # Find science and reference files.
        ww_sci = np.where(obs['TYPE'] == 'SCI')[0]
        ww_ref = np.where(obs['TYPE'] == 'REF')[0]
        ww_ta = np.where((obs['TYPE'] == 'SCI_TA') | (obs['TYPE'] == 'REF_TA'))[0]
        if len(ww_ta) > 0:
            # No center measurement is defined for TA files here (they used
            # to crash further down), so leave them untouched.
            log.warning(' --> Calculate centers: TA files are not supported, skipped')

        # Observation date for the model PSF, taken from the first SCI file
        # (read once per concatenation). Falls back to the fitted file's
        # own header if the concatenation has no SCI file.
        if len(ww_sci) > 0:
            date_sci = fits.getheader(obs['FITSFILE'][ww_sci[0]], 0)['DATE-BEG']
        else:
            date_sci = None

        # Loop through FITS files.
        for j in np.append(ww_sci, ww_ref):
            # Read FITS file and PSF masks.
            fitsfile = obs['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = obs['MASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmaskfile = obs['NANMASKFILE'][j]
            nanmask = ut.read_msk(nanmaskfile)
            tail = os.path.basename(fitsfile)
            log.info(' --> Recenter frames: ' + tail)
            if np.isnan(data).any():
                raise UserWarning('Please replace nan pixels before attempting to recenter frames')

            ny, nx = data.shape[-2], data.shape[-1]
            nints = data.shape[0]
            mcmc = mcmc_tools.MCMCTools(data, type=obs['TYPE'][j], kwargs=mcmc_kwargs)

            # Model offset PSF without the coronagraphic mask, normalized to
            # a peak of 1. The field of view is forced to an odd size.
            date = date_sci if date_sci is not None else head_pri['DATE-BEG']
            offsetpsf_func = JWST_PSF(obs['APERNAME'][j],
                                      obs['FILTER'][j],
                                      date=date,
                                      fov_pix=nx + 1 if nx % 2 == 0 else nx,
                                      oversample=2,
                                      sp=None,
                                      use_coeff=False)
            psf_no_coronmsk = offsetpsf_func.gen_psf([0, 0], return_oversample=False, quick=False)
            psf_no_coronmsk /= np.nanmax(psf_no_coronmsk)

            # Fit the median frame once.
            mcmc.run(np.median(data, axis=0).copy(),
                     psf_no_coronmsk,
                     x_guess=mcmc.x_guess,
                     y_guess=mcmc.y_guess,
                     r=mcmc.r,
                     nsteps=mcmc.nsteps,
                     ndim=len(mcmc.initial_guess),
                     nwalkers=mcmc.nwalkers,
                     initial_guess=mcmc.initial_guess,
                     limits=mcmc.limits,
                     verbose=mcmc.verbose,
                     size=mcmc.size,
                     binarity=mcmc.binarity,
                     filename=os.path.join(output_dir, tail.split('.fits')[0]))

            # Shift between star position and image center
            # (data.shape // 2); the same shift applies to all integrations.
            shift = np.array([-(mcmc.best_fit_params[0] - nx // 2),
                              -(mcmc.best_fit_params[1] - ny // 2)])
            shifts = np.tile(shift, (nints, 1))
            # Shift between star and coronagraphic mask position, and
            # between mask position and image center: not measured here.
            maskoffs_temp = np.zeros((nints, 2))
            maskshifts_temp = np.zeros(2)
            xoffset = 0  # arcsec
            yoffset = 0  # arcsec
            starcenx = nx // 2. - shift[0] + 1  # 1-indexed
            starceny = ny // 2. - shift[1] + 1  # 1-indexed

            # Accumulate onto any previously stored values (new arrays, so
            # integer-typed inputs do not break an in-place add).
            center_shift = shifts if center_shift is None else center_shift + shifts
            center_mask = maskshifts_temp if center_mask is None else center_mask + maskshifts_temp
            maskoffs = maskoffs_temp if maskoffs is None else maskoffs + maskoffs_temp

            # Compute shift distances.
            dist = np.sqrt(np.sum(shifts[:, :2] ** 2, axis=1))  # pix
            dist *= obs['PIXSCALE'][j] * 1000  # mas
            log.info(' --> Calculate centers: median measured shift = %.2f mas' % np.median(dist))

            # Write FITS file and PSF masks.
            head_pri['XOFFSET'] = xoffset  # arcsec
            head_pri['YOFFSET'] = yoffset  # arcsec
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            # Reading in CRPIX1/2 from database for updates from
            # update_nircam_centers.
            head_sci['CRPIX1'] = obs['CRPIX1'][j]
            head_sci['CRPIX2'] = obs['CRPIX2'][j]
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')
            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile,
                                     xoffset=xoffset, yoffset=yoffset,
                                     starcenx=starcenx, starceny=starceny,
                                     maskcenx=None, maskceny=None,
                                     center_shift=center_shift, center_mask=center_mask,
                                     nanmaskfile=nanmaskfile)
[docs]
def calculate_centers(self,
                      method='fourier',
                      use_ta=False,
                      subpix_first_sci_only=False,
                      first_sci_only=True,
                      spectral_type='G2V',
                      shft_exp=1,
                      kwargs=None,
                      highpass=False,
                      subdir='recentered'):
    """
    Calculate the shifts necessary to recenter the host star in each frame.

    The shifts are only measured here, not applied to the pixel data. They
    are added to the ``center_shift``, ``center_mask`` and ``maskoffs``
    values read from each file. The results are written back with
    ``ut.write_obs`` into ``subdir`` and recorded in the spaceKLIP database.

    Depending on the data type:

    - NIRCam coronagraphy (``NRC_CORON``). The star position behind the
      coronagraphic mask is determined either from target acquisition data
      (``use_ta=True``) or by cross-correlating with a WebbPSF model (see
      ``find_nircam_centers``). Shifts refer the star to the image center
      ``(shape - 1) / 2`` (0-indexed).
    - MIRI coronagraphy (``MIR_4QPM``, ``MIR_LYOT``). The star position is
      determined from target acquisition data if ``use_ta=True``.
      Otherwise a zero shift is recorded and the database STARCENX/STARCENY
      values are kept.
    - Other SCI/REF data. The star is located with the 'BCEN' algorithm of
      XARA and referred to ``shape // 2`` (0-indexed).
    - SCI_TA/REF_TA data. The shift is found by minimizing
      ``ut.recenterlsq``, which centers the frame on the nearest pixel
      center (not necessarily the image center).

    Parameters
    ----------
    method : 'fourier' or 'spline' (not recommended), optional
        Shift method passed to ``ut.recenterlsq`` when measuring the centers
        of TA frames. The default is 'fourier'.
    use_ta : bool, optional
        Use target acquisition data to determine the star's position behind
        the coronagraphic mask (NIRCam/MIRI coronagraphy only)? The default
        is False.
    subpix_first_sci_only : bool, optional
        Only applicable to non-coronagraphic SCI/REF data. If False (default),
        the star position is measured to subpixel precision in every frame.
        If True, it is measured only in the first frame of the first SCI
        file; all other non-coronagraphic SCI/REF frames get a zero shift
        here. This can help with poorly sampled data to avoid another
        interpolation step if the 'align_frames' routine is run subsequently.
    first_sci_only : bool, optional
        Only applicable to NIRCam/MIRI coronagraphy. If True (default), the
        star position is measured only on the first SCI file of each
        concatenation. The same shift is recorded for all other SCI and REF
        files, and XOFFSET/YOFFSET are rewritten relative to that first SCI
        file. If False, every SCI and REF file is measured individually and
        XOFFSET/YOFFSET are left unchanged.
    spectral_type : str, optional
        Host star spectral type for the WebbPSF model used to determine the
        star position behind the coronagraphic mask. The default is 'G2V'.
    shft_exp : float, optional
        Take image to the given power before cross correlating for shifts.
        The default is 1. For instance, 1/2 helps align NIRCam bar/narrow
        data (or other data with weird speckles).
    kwargs : dict, optional
        Keyword arguments for the shift routine (scipy.ndimage.shift), passed
        to ``ut.recenterlsq`` for TA frames. The default is None, i.e. {}.
    highpass : bool or float, optional
        Passed to ``find_nircam_centers`` (NIRCam coronagraphy without TA).
        A non-bool value high-pass filters the model PSF. The default is
        False.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'recentered'.

    Returns
    -------
    None.

    Raises
    ------
    UserWarning
        If a file still contains NaN pixels.
    """
    # Avoid a shared mutable default argument.
    if kwargs is None:
        kwargs = {}

    # Update NIRCam coronagraphy centers, i.e., change SIAF CRPIX/MASKCEN position
    # to true mask center determined by Jarron.
    # self.update_nircam_centers()  # Shall be run purposely by the user.
    # Update MIRI coronagraphy offsets, of older data.
    # self.update_miri_offsets()  # Shall be run purposely by the user.

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Find science, reference, and target acquisition files.
        file_types = self.database.obs[key]['TYPE']
        ww_sci = np.where(file_types == 'SCI')[0]
        ww_sci_ta = np.where(file_types == 'SCI_TA')[0]
        ww_ref = np.where(file_types == 'REF')[0]
        ww_ref_ta = np.where(file_types == 'REF_TA')[0]

        # The offsets of the first SCI file are required below, so a
        # concatenation without SCI files cannot be processed.
        if len(ww_sci) == 0:
            log.warning(' --> Calculate centers: could not find any science files, skipping ' + key)
            continue

        # SCI files come first so that, with first_sci_only=True, the star
        # position measured on ww_sci[0] is available for all other files.
        ww_all = np.concatenate([ww_sci, ww_ref, ww_sci_ta, ww_ref_ta])

        # Preserve the ww_sci[0] offsets before the database entry of that
        # file is overwritten, in case we are operating on first_sci_only.
        xoffset_orig = np.copy(self.database.obs[key]['XOFFSET'][ww_sci[0]])
        yoffset_orig = np.copy(self.database.obs[key]['YOFFSET'][ww_sci[0]])

        for j in ww_all:
            # Read FITS file and PSF mask.
            fitsfile = self.database.obs[key]['FITSFILE'][j]
            (data, erro, pxdq, head_pri, head_sci, is2d,
             align_shift, center_shift, align_mask, center_mask, maskoffs) = ut.read_obs(fitsfile)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            nanmaskfile = self.database.obs[key]['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Find center of frames. Use different algorithms based on data type.
            head, tail = os.path.split(fitsfile)
            log.info(' --> Calculating centers: ' + tail)
            if np.isnan(data).any():
                raise UserWarning('Please replace nan pixels before attempting to calculate centers')
            shifts = []  # Shift between star position and image center.
            maskoffs_temp = []  # Offset between star and coronagraphic mask center.
            mask_shifts = []  # Shift applied to the PSF mask to match the recentered image.

            # SCI and REF data.
            if j in ww_sci or j in ww_ref:
                exp_type = self.database.obs[key]['EXP_TYPE'][j]

                if exp_type in ['NRC_CORON']:
                    # NIRCam coronagraphy. The mask center and the header
                    # offsets are the same for all frames of this file.
                    maskcenx = self.database.obs[key]['MASKCENX'][j]  # 1-indexed
                    maskceny = self.database.obs[key]['MASKCENY'][j]  # 1-indexed
                    xoffset = self.database.obs[key]['XOFFSET'][j]  # arcsec
                    yoffset = self.database.obs[key]['YOFFSET'][j]  # arcsec
                    for k in range(data.shape[0]):
                        # Measure the star position and its offset from the
                        # coronagraphic mask center on the first frame of the
                        # first SCI file (or of every file if first_sci_only
                        # is False). Other frames reuse the last measurement.
                        if (not first_sci_only or j == ww_sci[0]) and k == 0:
                            if use_ta:
                                # Use TA data to determine the star position
                                # behind the coronagraphic mask.
                                xc, yc = ta_analysis(self.database.obs[key]['FITSFILE'][j],
                                                     plot=True, verbose=True,
                                                     output_dir=output_dir)  # 0-indexed
                                # Calculate the star - mask offset.
                                xshift, yshift = (xc - (maskcenx - 1), yc - (maskceny - 1))
                            else:
                                # Use WebbPSF model to determine the star
                                # position behind the coronagraphic mask.
                                xc, yc, xshift, yshift = self.find_nircam_centers(data0=data.copy(),
                                                                                  key=key,
                                                                                  j=j,
                                                                                  shft_exp=shft_exp,
                                                                                  spectral_type=spectral_type,
                                                                                  date=head_pri['DATE-BEG'],
                                                                                  output_dir=output_dir,
                                                                                  highpass=highpass)
                        # Apply the same shift to all SCI and REF frames and masks.
                        shift = np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])
                        shifts += [shift]
                        mask_shifts += [shift.copy()]
                        maskoffs_temp += [np.array([xshift, yshift])]  # pixels
                    if first_sci_only:
                        # Adjust centers based on first science frame. Typically these should be zero, but in some
                        # cases manual offsets have been applied in APT to ensure better coronagraph alignment.
                        # As such, these aren't "true" offsets from the coronagraph, and need to be removed.
                        xoffset -= xoffset_orig  # arcsec
                        yoffset -= yoffset_orig  # arcsec
                        log.info(' --> Calculate centers: adjusted XOFFSET/YOFFSET relative to first SCI frame.')
                    else:
                        # XOFFSET/YOFFSET remain the same.
                        log.info(' --> Calculate centers: no adjustment made to XOFFSET/YOFFSET, all SCI/REF files recentered individually.')
                    # Set star center (image center - shift).
                    starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1  # 1-indexed
                    starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1  # 1-indexed

                elif exp_type in ['MIR_4QPM', 'MIR_LYOT']:
                    # MIRI coronagraphy. The mask center and the header
                    # offsets are the same for all frames of this file.
                    maskcenx = self.database.obs[key]['MASKCENX'][j]  # 1-indexed
                    maskceny = self.database.obs[key]['MASKCENY'][j]  # 1-indexed
                    xoffset = self.database.obs[key]['XOFFSET'][j]  # arcsec
                    yoffset = self.database.obs[key]['YOFFSET'][j]  # arcsec
                    for k in range(data.shape[0]):
                        # Measure the star position and its offset from the
                        # coronagraphic mask center (if TA data available)
                        # on the first frame of the first SCI file, or of
                        # every file if first_sci_only is False.
                        if (not first_sci_only or j == ww_sci[0]) and k == 0:
                            if use_ta:
                                # Use TA data to determine the star position
                                # behind the coronagraphic mask.
                                xc, yc = ta_analysis(self.database.obs[key]['FITSFILE'][j],
                                                     plot=True, verbose=True,
                                                     output_dir=output_dir)  # 0-indexed
                                # Adjust for any cropping done prior.
                                xc -= self.database.obs[key]['CROP_SHIFTX'][j]  # Adjust X for left cropping; 0-indexed.
                                yc -= self.database.obs[key]['CROP_SHIFTY'][j]  # Adjust Y for bottom cropping; 0-indexed.
                                # Calculate the star - mask offset.
                                xshift, yshift = (xc - (maskcenx - 1), yc - (maskceny - 1))
                            else:
                                # Do nothing: assume the star is at the image center.
                                log.warning(' --> Calculate centers: not implemented for MIRI coronagraphy unless use_ta=True, skipped')
                                xc, yc = (data.shape[-1] - 1.) / 2., (data.shape[-2] - 1.) / 2.  # Shift will be 0.
                                xshift, yshift = 0., 0.
                        # Apply the same shift to all SCI and REF frames and masks.
                        shift = np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])
                        shifts += [shift]
                        mask_shifts += [shift.copy()]
                        maskoffs_temp += [np.array([xshift, yshift])]
                    if first_sci_only:
                        xoffset -= xoffset_orig  # arcsec
                        yoffset -= yoffset_orig  # arcsec
                        log.info(' --> Calculate centers: adjusted XOFFSET/YOFFSET relative to first SCI frame.')
                    else:
                        # XOFFSET/YOFFSET remain the same.
                        log.info(' --> Calculate centers: no adjustment made to XOFFSET/YOFFSET, all SCI/REF files measured individually.')
                    # Star center. If TA was used, shift from image center;
                    # otherwise, use provided header values.
                    if use_ta:
                        starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1  # 1-indexed
                        starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1  # 1-indexed
                    else:
                        starcenx = self.database.obs[key]['STARCENX'][j]  # 1-indexed
                        starceny = self.database.obs[key]['STARCENY'][j]  # 1-indexed

                else:
                    # Other data types.
                    for k in range(data.shape[0]):
                        # Recenter SCI and REF frames to subpixel precision
                        # using the 'BCEN' routine from XARA.
                        # https://github.com/fmartinache/xara
                        if not subpix_first_sci_only or (j == ww_sci[0] and k == 0):
                            pp = core.determine_origin(data[k], algo='BCEN')
                            shifts += [np.array([-(pp[0] - data.shape[-1] // 2), -(pp[1] - data.shape[-2] // 2)])]
                        else:
                            shifts += [np.array([0., 0.])]
                        mask_shifts += [np.array([0., 0.])]
                        maskoffs_temp += [np.array([0., 0.])]
                    xoffset = 0.  # arcsec
                    yoffset = 0.  # arcsec
                    # Update star center (image center - shift).
                    starcenx = data.shape[-1] // 2. - shifts[0][0] + 1  # 1-indexed
                    starceny = data.shape[-2] // 2. - shifts[0][1] + 1  # 1-indexed
                    maskcenx = None
                    maskceny = None

            elif j in ww_sci_ta or j in ww_ref_ta:
                # TA data.
                for k in range(data.shape[0]):
                    # Center TA frames on the nearest pixel center. This
                    # pixel center is not necessarily the image center,
                    # which is why a subsequent integer pixel recentering
                    # is required.
                    p0 = np.array([0., 0.])
                    pp = minimize(ut.recenterlsq,
                                  p0,
                                  args=(data[k], method, kwargs))['x']
                    shifts += [np.array([pp[0], pp[1]])]
                    mask_shifts += [np.array([0., 0.])]
                    maskoffs_temp += [np.array([0., 0.])]
                xoffset = 0.  # arcsec
                yoffset = 0.  # arcsec
                # Update star center (image center - shift).
                starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1  # 1-indexed
                starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1  # 1-indexed
                maskcenx = None
                maskceny = None

            # Accumulate the new shifts on top of any previously stored ones.
            shifts = np.array(shifts)
            maskoffs_temp = np.array(maskoffs_temp)
            maskshifts_temp = np.median(mask_shifts, axis=0)
            if center_shift is not None:
                center_shift += shifts
            else:
                center_shift = shifts
            if center_mask is not None:
                center_mask += maskshifts_temp
            else:
                center_mask = maskshifts_temp
            if maskoffs is not None:
                maskoffs += maskoffs_temp
            else:
                maskoffs = maskoffs_temp

            # Compute shift distances.
            dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1))  # pix
            dist *= self.database.obs[key]['PIXSCALE'][j] * 1000  # mas
            log.info(' --> Calculate centers: median measured shift = %.2f mas' % np.median(dist))

            # Write FITS file and PSF mask.
            head_pri['XOFFSET'] = xoffset  # arcsec
            head_pri['YOFFSET'] = yoffset  # arcsec
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            if maskcenx is not None:
                head_sci['MASKCENX'] = maskcenx
                head_sci['MASKCENY'] = maskceny
            # Reading in CRPIX1/2 from database for updates from update_nircam_centers.
            head_sci['CRPIX1'] = self.database.obs[key]['CRPIX1'][j]
            head_sci['CRPIX2'] = self.database.obs[key]['CRPIX2'][j]
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile,
                                     xoffset=xoffset, yoffset=yoffset,
                                     starcenx=starcenx, starceny=starceny,
                                     maskcenx=maskcenx, maskceny=maskceny,
                                     center_shift=center_shift, center_mask=center_mask,
                                     nanmaskfile=nanmaskfile)
[docs]
def resample_frames(self, subdir='resampled'):
    """
    Resample the SCI and REF frames, applying the distortion correction.

    Every integration of each SCI and REF file is wrapped in its own
    ImageModel and passed through the JWST pipeline ``ResampleStep``. The
    PSF mask of the file, if present, is resampled the same way. The
    resampled data, errors, DQ arrays and mask are written to ``subdir``
    and the spaceKLIP database is updated. Concatenations without SCI files
    are skipped with a warning.

    Parameters
    ----------
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'resampled'.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If no concatenation in the database contains SCI files.
    """

    def to_container(model, target=True):
        """
        Convert a datamodel to a ModelContainer of ImageModels, one per plane.

        With ``target=True`` the science-related arrays are copied. With
        ``target=False`` only 'data' is copied (used for the PSF mask).
        Arrays that ``model`` does not provide are skipped.
        """
        container = ModelContainer()
        if target:
            attr_list = ['data', 'dq', 'err', 'zeroframe', 'area',
                         'var_poisson', 'var_rnoise', 'var_flat']
        else:
            attr_list = ['data']
        for plane in range(model.data.shape[0]):
            image = datamodels.ImageModel()
            for attribute in attr_list:
                try:
                    setattr(image, attribute, model.getarray_noinit(attribute)[plane])
                except AttributeError:
                    # Best effort: not every model carries every array.
                    pass
            image.update(model)
            try:
                image.meta.wcs = model.meta.wcs
            except AttributeError:
                pass
            container.append(image)
        return container

    def resample_planes(container):
        """Run ResampleStep separately on each plane of ``container``."""
        results = []
        for model in container:
            resample_input = ModelContainer()
            resample_input.append(model)
            # For compatibility with the image3 pipeline use of ModelLibrary,
            # convert the ModelContainer to a ModelLibrary.
            resample_library = ModelLibrary(resample_input, on_disk=False)
            # Output is a single datamodel.
            results.append(resample_step.ResampleStep.call(resample_library))
        return results

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through concatenations.
    sci_good = False
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Find science and reference files.
        ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0]
        if len(ww_sci) == 0:
            log.warning(f'Could not find any science files. Skipping {key}.')
            continue
        sci_good = True
        ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0]
        ww_all = np.append(ww_sci, ww_ref)

        # Bind the JWST resample step module locally; it is only needed once
        # there are files to resample.
        from jwst.resample import resample_step

        log.info(' --> Resampling: ' + key)
        resample_step.ResampleStep.blendheaders = False
        for j in ww_all:
            target_file = self.database.obs[key]['FITSFILE'][j]
            # Read headers and shift bookkeeping of the file being resampled
            # (previously read from an undefined/stale ``fitsfile``).
            (data, erro, pxdq, head_pri, head_sci, is2d,
             align_shift, center_shift, align_mask, center_mask, maskoffs) = ut.read_obs(target_file)
            maskfile = self.database.obs[key]['MASKFILE'][j]
            mask = ut.read_msk(maskfile)

            with datamodels.open(target_file) as target:
                data_list = []
                dq_list = []
                err_list = []
                for result in resample_planes(to_container(target)):
                    data_list.append(result.data)
                    dq_list.append(result.dq)
                    err_list.append(result.err)

                # Resample the PSF mask with the same WCS, if there is one.
                if mask is not None:
                    target.data = np.array([mask])
                    for result in resample_planes(to_container(target, target=False)):
                        mask = result.data

            # Write FITS file and PSF mask.
            fitsfile = ut.write_obs(target_file, output_dir, data_list, err_list, dq_list, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile)

    if not sci_good:
        raise ValueError('No science frames found in database concatenations')
[docs]
def find_nircam_centers(self,
                        data0,
                        key,
                        j,
                        spectral_type='G2V',
                        shft_exp=1,
                        date=None,
                        output_dir=None,
                        fov_pix=65,
                        oversample=2,
                        use_coeff=False,
                        highpass=False,
                        save_figures=True,
                        plot_style=None):
    """
    Find the star position behind the coronagraphic mask using a WebbPSF
    model.

    Each frame of ``data0`` is cross-correlated (skimage
    ``phase_cross_correlation``) against a model PSF. Both images are
    multiplied by the coronagraphic transmission mask within a
    ``fov_pix`` x ``fov_pix`` cutout. The search starts at the mask center,
    and the cutout is re-centered on the updated star position three times.

    Parameters
    ----------
    data0 : list or array of 2D arrays
        Frames for which the star position shall be determined.
    key : str
        Database key of the observation containing the first frame in data0.
    j : int
        Database index of the observation containing the first frame in data0.
    spectral_type : str, optional
        Host star spectral type for the WebbPSF model used to determine the
        star position behind the coronagraphic mask. The default is 'G2V'.
    shft_exp : float, optional
        Take image to the given power before cross correlating for shifts.
        The default is 1.
    date : str, optional
        Observation date in the format 'YYYY-MM-DDTHH:MM:SS.MMM'. The
        default is None.
    output_dir : path, optional
        Path of the directory where the data products shall be saved. If
        None, no diagnostic plot is made. The default is None.
    fov_pix : int, optional
        Size (pix) of the model PSF and of the cutouts used for the cross
        correlation. An odd size keeps the PSF centered on a pixel center.
        The default is 65.
    oversample : int, optional
        Factor by which the WebbPSF model shall be oversampled. The
        default is 2.
    use_coeff : bool, optional
        Use pre-computed coefficients to generate the WebbPSF model. The
        default is False.
    highpass : bool or float, optional
        If not a bool, the model PSF is high-pass filtered with pyklip's
        ``high_pass_filter_imgs``. The filter size is derived from
        ``model_psf.shape[0] / highpass``. ``True`` is not implemented. The
        default is False.
    save_figures : bool, optional
        Save the plots in a PDF? The default is True.
    plot_style : optional
        Matplotlib style passed to ``load_plt_style``. The default is None.

    Returns
    -------
    xc : float
        Star x-position (pix, 0-indexed) measured on the LAST frame of data0.
    yc : float
        Star y-position (pix, 0-indexed) measured on the LAST frame of data0.
    xshift : float
        Median over all frames of the x-shift between star and coronagraphic
        mask position (pix).
    yshift : float
        Median over all frames of the y-shift between star and coronagraphic
        mask position (pix).

    Raises
    ------
    ValueError
        If data0 contains no frames.
    NotImplementedError
        If highpass is True.
    """
    # Fail fast: without frames there is no position to return.
    if len(data0) == 0:
        raise ValueError('data0 must contain at least one frame')

    # Generate host star spectrum.
    spectrum = webbpsf_ext.stellar_spectrum(spectral_type)

    # Get true mask center.
    maskcenx = self.database.obs[key]['MASKCENX'][j] - 1  # 0-indexed
    maskceny = self.database.obs[key]['MASKCENY'][j] - 1  # 0-indexed

    # Initialize JWST_PSF object. Use odd image size so that PSF is
    # centered in pixel center.
    log.info(' --> Recenter frames: generating WebbPSF image for absolute centering (this might take a while)')
    FILTER = self.database.obs[key]['FILTER'][j]
    APERNAME = self.database.obs[key]['APERNAME'][j]
    psf_kwargs = {
        'fov_pix': fov_pix,
        'oversample': oversample,
        'date': date,
        'use_coeff': use_coeff,
        'sp': spectrum
    }
    psf = JWST_PSF(APERNAME, FILTER, **psf_kwargs)

    # Get SIAF reference pixel position.
    apsiaf = psf.inst_on.siaf_ap
    xsciref, ysciref = (apsiaf.XSciRef, apsiaf.YSciRef)

    # Generate model PSF. Apply offset between SIAF reference pixel
    # position and true mask center.
    xoff = (maskcenx + 1) - xsciref
    yoff = (maskceny + 1) - ysciref
    model_psf = psf.gen_psf_idl((0, 0), coord_frame='idl', return_oversample=False, quick=True)
    if not isinstance(highpass, bool):
        highpass = float(highpass)
        # Convert the FWHM-like size into a Gaussian sigma.
        fourier_sigma_size = (model_psf.shape[0] / highpass) / (2. * np.sqrt(2. * np.log(2.)))
        model_psf = parallelized.high_pass_filter_imgs(np.array([model_psf]), numthreads=None, filtersize=fourier_sigma_size)[0]
    elif highpass:
        raise NotImplementedError()
    if not np.isnan(self.database.obs[key]['BLURFWHM'][j]):
        # Match any blurring previously applied to the data (FWHM -> sigma).
        gauss_sigma = self.database.obs[key]['BLURFWHM'][j] / np.sqrt(8. * np.log(2.))
        model_psf = gaussian_filter(model_psf, gauss_sigma)

    shift_list = []
    mask = None
    mask_shape = None
    for count, data in enumerate(data0):
        # Get transmission mask. It only depends on the frame shape, so it
        # is computed once per distinct shape.
        if data.shape != mask_shape:
            yi, xi = np.indices(data.shape)
            xidl, yidl = apsiaf.sci_to_idl(xi + 1 - xoff, yi + 1 - yoff)
            mask = psf.inst_on.gen_mask_transmission_map((xidl, yidl), 'idl')
            mask_shape = data.shape

        # Determine relative shift between data and model PSF. Iterate 3
        # times to improve precision, starting from the assumption that the
        # star is exactly at the center of the coronagraph.
        xc, yc = (maskcenx, maskceny)
        for _ in range(3):
            # Crop data and transmission mask.
            datasub, xsub_indarr, ysub_indarr = ut.crop_image(image=data,
                                                              xycen=(xc, yc),
                                                              npix=fov_pix,
                                                              return_indices=True)
            masksub = ut.crop_image(image=mask,
                                    xycen=(xc, yc),
                                    npix=fov_pix)
            if shft_exp == 1:
                img1 = datasub * masksub
                img2 = model_psf * masksub
            else:
                img1 = np.power(np.abs(datasub), shft_exp) * masksub
                img2 = np.power(np.abs(model_psf), shft_exp) * masksub

            # Determine relative shift between data and model PSF.
            shift, _, _ = phase_cross_correlation(img1,
                                                  img2,
                                                  upsample_factor=1000,
                                                  normalization=None)
            yshift, xshift = shift

            # Update star position.
            xc = np.mean(xsub_indarr) + xshift
            yc = np.mean(ysub_indarr) + yshift
        xshift, yshift = (xc - maskcenx, yc - maskceny)
        shift_list.append([xshift, yshift])
        log.info(' --> Recenter frames: star offset between frame %i and coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (count, xshift, yshift))

    median_xshift, median_yshift = np.median(np.array(shift_list), axis=0)
    std_xshift, std_yshift = np.std(np.array(shift_list), axis=0)
    log.info(' --> Recenter frames: median star offset from coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (median_xshift, median_yshift))
    log.info(' --> Recenter frames: std for the star offset from coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (std_xshift, std_yshift))

    # Plot data, model PSF, and scene overview (last frame of data0).
    if output_dir is not None:
        # Initialize the matplotlib style.
        load_plt_style(plot_style)
        fig, ax = plt.subplots(1, 3, figsize=(3 * 6.4, 1 * 4.8))
        ax[0].imshow(datasub, origin='lower', cmap='Reds')
        ax[0].contourf(masksub, levels=[0.00, 0.25, 0.50, 0.75], cmap='Greys_r', vmin=0., vmax=2., alpha=0.5)
        ax[0].set_title('1. SCI frame & transmission mask')
        ax[1].imshow(model_psf, origin='lower', cmap='Reds')
        ax[1].contourf(masksub, levels=[0.00, 0.25, 0.50, 0.75], cmap='Greys_r', vmin=0., vmax=2., alpha=0.5)
        ax[1].set_title('Model PSF & transmission mask')
        ax[2].scatter((xsciref), (ysciref), marker='+', color='black', label='SIAF reference point')
        ax[2].scatter((maskcenx + 1), (maskceny + 1), marker='x', color='skyblue', label='True mask center')
        ax[2].scatter((xc + 1), (yc + 1), marker='*', color='red', label='Computed star position')
        ax[2].set_aspect('equal')
        # Make the overview panel square around the plotted points.
        xlim = ax[2].get_xlim()
        ylim = ax[2].get_ylim()
        half_range = max(xlim[1] - xlim[0], ylim[1] - ylim[0])
        ax[2].set_xlim(np.mean(xlim) - half_range, np.mean(xlim) + half_range)
        ax[2].set_ylim(np.mean(ylim) - half_range, np.mean(ylim) + half_range)
        ax[2].set_xlabel('x-position [pix]')
        ax[2].set_ylabel('y-position [pix]')
        ax[2].legend(loc='upper right', fontsize=12)
        ax[2].set_title('Scene overview (1-indexed)')
        plt.tight_layout()
        if save_figures:
            output_file = os.path.split(self.database.obs[key]['FITSFILE'][j])[1]
            output_file = output_file.replace('.fits', '.pdf')
            output_file = os.path.join(output_dir, output_file)
            plt.savefig(output_file)
            log.info(f" Plot saved in {output_file}")
        plt.show()
        plt.close(fig)

    # Return star position.
    # NOTE(review): xc/yc come from the last frame while the offsets are
    # medians over all frames; confirm this mix is intended by the callers.
    return xc, yc, median_xshift, median_yshift
[docs]
def align_frames(self,
method='fourier',
align_algo='leastsq',
mask_override=None,
msk_shp=8,
shft_exp=1,
align_to_file=None,
scale_prior=False,
kwargs={},
subdir='aligned',
save_figures=True,
plot_style=None):
"""
Align all SCI and REF frames to the first SCI frame.
Parameters
----------
method : 'fourier' or 'spline' (not recommended), optional
Method for shifting the frames. The default is 'fourier'.
align_algo : 'leastsq' or 'header'
Algorithm to determine the alignment offsets. Default is 'leastsq',
'header' assumes perfect header offsets.
mask_override : str, optional
Mask some pixels when cross correlating for shifts
msk_shp : int, optional
Shape (height or radius, or [inner radius, outer radius]) for custom mask invoked by "mask_override"
shft_exp : float, optional
Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles)
align_to_file : str, optional
Path to FITS file to which all images shall be aligned. Needs to be
a file with the same observational setup as all concatenations in
the spaceKLIP database. Hence, this can only be applied to one
observational setup at a time. The default is None.
scale_prior : bool, optional
If True, tries to find a better prior for the scale factor instead
of simply using 1. The default is False.
kwargs : dict, optional
Keyword arguments for the scipy.ndimage.shift routine. The default
is {}.
subdir : str, optional
Name of the directory where the data products shall be saved. The
default is 'aligned'.
save_figures : bool, optional
Save the plots in a PDF?
Returns
-------
None.
"""
#### DEPRECATION WARNING ####
raise DeprecationWarning('This function is deprecated. Use `calculate_alignment` and `shift_frames` instead.')
# Set output directory.
output_dir = os.path.join(self.database.output_dir, subdir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Useful masks for computing shifts:
def create_annulus_mask(h, w, center=None, radius=None):
if center is None: # use the middle of the image
center = (int(w/2), int(h/2))
if radius is None: # use the smallest distance between the center and image walls
radius = min(center[0], center[1], w-center[0], h-center[1])
Y, X = np.ogrid[:h, :w]
dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
mask = (dist_from_center <= radius[0]) | (dist_from_center >= radius[1])
return mask
def create_circular_mask(h, w, center=None, radius=None):
if center is None: # use the middle of the image
center = (int(w/2), int(h/2))
if radius is None: # use the smallest distance between the center and image walls
radius = min(center[0], center[1], w-center[0], h-center[1])
Y, X = np.ogrid[:h, :w]
dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
mask = dist_from_center <= radius
return mask
def create_rec_mask(h, w, center=None, z=None):
if center is None: # use the middle of the image
center = (int(w/2), int(h/2))
if z is None:
z = h//4
mask = np.zeros((h, w), dtype=bool)
mask[center[1]-z:center[1]+z, :] = True
return mask
# Loop through concatenations.
database_temp = deepcopy(self.database.obs)
sci_good = False
for i, key in enumerate(self.database.obs.keys()):
log.info('--> Concatenation ' + key)
# Find science and reference files.
ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0]
if len(ww_sci) == 0:
# raise UserWarning('Could not find any science files')
print(f'Warning: Could not find any science files. Skipping {key}.')
continue
sci_good = True
ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0]
ww_all = np.append(ww_sci, ww_ref)
# Loop through FITS files.
if align_to_file is not None:
try:
ref_image = pyfits.getdata(align_to_file, 'SCI')
except:
ref_image = pyfits.getdata(align_to_file, 0)
if ref_image.ndim == 3:
ref_image = np.nanmedian(ref_image, axis=0)
shifts_all = []
for j in ww_all:
# Read FITS file and PSF mask.
fitsfile = self.database.obs[key]['FITSFILE'][j]
data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile)
maskfile = self.database.obs[key]['MASKFILE'][j]
mask = ut.read_msk(maskfile)
if mask_override is not None:
if mask_override == 'ann':
mask_circ = create_annulus_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp)
elif mask_override == 'circ':
mask_circ = create_circular_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp)
elif mask_override == 'rec':
mask_circ = create_rec_mask(data[0].shape[0], data[0].shape[1], z=msk_shp)
else:
raise ValueError('There are `circ` and `rec` custom masks available')
mask_temp = data[0].copy()
mask_temp[~mask_circ] = 1
mask_temp[mask_circ] = 0
elif mask is None:
mask_temp = np.ones_like(data[0])
else:
mask_temp = mask.copy()
# Align frames.
head, tail = os.path.split(fitsfile)
log.info(' --> Align frames: ' + tail)
if np.sum(np.isnan(data)) != 0:
raise UserWarning('Please replace nan pixels before attempting to align frames')
shifts = []
maskcen = []
crpix = []
for k in range(data.shape[0]):
# Take the first science frame as reference frame.
if j == ww_sci[0] and k == 0:
if align_to_file is None:
ref_image = data[k].copy()
pp = np.array([0., 0., 1.])
xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec
yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec
starcenx = self.database.obs[key]['STARCENX'][j] # pixels
starceny = self.database.obs[key]['STARCENY'][j] # pixels
pxsc = self.database.obs[key]['PIXSCALE'][j] # arcsec
# Align all other SCI and REF frames to the first science
# frame.
if align_to_file is not None or j != ww_sci[0] or k != 0:
# Calculate shifts relative to first frame, work in pixels.
xfirst = crpix1 + (xoffset/pxsc)
xoff_curr_pix = self.database.obs[key]['XOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j]
xcurrent = self.database.obs[key]['CRPIX1'][j] + xoff_curr_pix
xshift = xfirst - xcurrent
yfirst = crpix2 + (yoffset/pxsc)
yoff_curr_pix = self.database.obs[key]['YOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j]
ycurrent = self.database.obs[key]['CRPIX2'][j] + yoff_curr_pix
yshift = yfirst - ycurrent
# Get mask center and crpix to also register the shift in their locations.
maskcenx = self.database.obs[key]['MASKCENX'][j] # pixels
maskceny = self.database.obs[key]['MASKCENY'][j] # pixels
crpix1 = self.database.obs[key]['CRPIX1'][j] # pixels
crpix2 = self.database.obs[key]['CRPIX2'][j] # pixels
maskcen += [np.array([maskcenx, maskceny])]
crpix += [np.array([crpix1, crpix2])]
if scale_prior:
ww = mask < 0.5
sh = mask.shape
bw = 100
ww[:bw, :] = 0.
ww[:, :bw] = 0.
ww[sh[0] - bw:, :] = 0.
ww[:, sh[1] - bw:] = 0.
# plt.imshow(ww, origin='lower')
# plt.show()
scale = np.nanmedian(np.true_divide(ref_image, data[k])[ww])
if shft_exp != 1:
scale = np.power(np.abs(scale), shft_exp)
p0 = np.array([xshift, yshift, scale])
else:
p0 = np.array([xshift, yshift, 1.])
# Fix for weird numerical behaviour if shifts are small
# but not exactly zero.
if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3):
p0 = np.array([0., 0., p0[-1]])
if align_algo == 'leastsq':
if shft_exp != 1:
args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs)
else:
args = (data[k], ref_image, mask_temp, method, kwargs)
# Use header values to initiate least squares fit
pp = leastsq(ut.alignlsq,
p0,
args=args)[0]
elif align_algo == 'header':
# Just assume the header values are correct
pp = p0
# Append shifts to array and apply shift to image
# using defined method.
shifts += [np.array([pp[0], pp[1], pp[2]])]
if align_to_file is not None or j != ww_sci[0] or k != 0:
data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs)
erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs)
shifts = np.array(shifts)
maskcen = np.array(maskcen)
crpix = np.array(crpix)
if mask is not None:
if align_to_file is not None or j != ww_sci[0]:
temp = np.median(shifts, axis=0)
mask = spline_shift(mask, [temp[1], temp[0]], order=0, mode='constant', cval=np.nanmedian(mask))
shifts_all += [shifts]
if imshifts is not None:
imshifts += shifts[:, :-1]
else:
imshifts = shifts[:, :-1]
if maskoffs is not None:
maskoffs -= shifts[:, :-1]
else:
maskoffs = -shifts[:, :-1]
# Compute shift distances.
dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix
dist *= self.database.obs[key]['PIXSCALE'][j]*1000 # mas
if j == ww_sci[0]:
dist = dist[1:]
log.info(' --> Align frames: median required shift = %.2f mas' % np.median(dist))
if self.database.obs[key]['TELESCOP'][j] == 'JWST':
ww = (dist < 1e-5) | (dist > 100.)
else:
ww = (dist < 1e-5)
if np.sum(ww) != 0:
if j == ww_sci[0]:
ww = np.append(np.array([False]), ww)
ww = np.where(ww == True)[0]
if align_algo != 'header':
log.warning(' --> The following frames might not be properly aligned: '+str(ww))
# Write FITS file and PSF mask.
head_pri['XOFFSET'] = xoffset # arcseconds
head_pri['YOFFSET'] = yoffset # arcseconds
head_sci['STARCENX'] = starcenx
head_sci['STARCENY'] = starceny
# Change mask location too: take first values of the shifts of the current fits file for now.
maskcenx = maskcen[0, 0] + shifts[0, 0]
maskceny = maskcen[0, 1] + shifts[0, 1]
crpix1 = crpix[0, 0] + shifts[0, 0]
crpix2 = crpix[0, 1] + shifts[0, 1]
head_sci['MASKCENX'] = maskcenx
head_sci['MASKCENY'] = maskceny
head_sci['CRPIX1'] = crpix1
head_sci['CRPIX2'] = crpix2
fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs)
maskfile = ut.write_msk(maskfile, mask, fitsfile)
# Update spaceKLIP database.
self.database.update_obs(key, j, fitsfile, maskfile,
xoffset=xoffset, yoffset=yoffset,
starcenx=starcenx, starceny=starceny,
maskcenx=maskcenx, maskceny=maskceny,
crpix1=crpix1, crpix2=crpix2)
# Plot science frame alignment.
# Intialize the matplotlib style.
load_plt_style(plot_style)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
fig = plt.figure(figsize=(6.4, 4.8))
ax = plt.gca()
for index, j in enumerate(ww_sci):
ax.scatter(shifts_all[index][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000,
shifts_all[index][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000,
s=5, color=colors[index % len(colors)], marker='o',
label='PA = %.0f deg' % self.database.obs[key]['ROLL_REF'][j])
ax.axhline(0., color='gray', lw=1, zorder=-1) # set zorder to ensure lines are drawn behind all the scatter points
ax.axvline(0., color='gray', lw=1, zorder=-1)
ax.set_aspect('equal')
xlim = ax.get_xlim()
ylim = ax.get_ylim()
xrng = xlim[1]-xlim[0]
yrng = ylim[1]-ylim[0]
if xrng > yrng:
ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng)
ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng)
else:
ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng)
ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng)
ax.set_xlabel('x-shift [mas]')
ax.set_ylabel('y-shift [mas]')
ax.legend(loc='upper right')
ax.set_title(f'Science frame alignment\nfor {self.database.obs[key]["TARGPROP"][ww_sci[0]]}, {self.database.obs[key]["FILTER"][ww_sci[0]]}')
if save_figures:
output_file = os.path.join(output_dir, key + '_align_sci.pdf')
plt.savefig(output_file)
log.info(f" Plot saved in {output_file}")
plt.show()
plt.close(fig)
# Plot reference frame alignment.
if len(ww_ref) > 0:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
fig = plt.figure(figsize=(6.4, 4.8))
ax = plt.gca()
seen = []
reps = []
syms = ['o', 'v', '^', '<', '>'] * (1 + len(ww_ref) // 5)
add = len(ww_sci)
for index, j in enumerate(ww_ref):
this = '%.3f_%.3f' % (database_temp[key]['XOFFSET'][j], database_temp[key]['YOFFSET'][j])
if this not in seen:
ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000,
shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000,
s=5, color=colors[len(seen) % len(colors)], marker=syms[0],
label='dither %.0f' % (len(seen) + 1))
ax.hlines((-database_temp[key]['YOFFSET'][j] + yoffset) * 1000,
(-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 - 4.,
(-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 + 4.,
color=colors[len(seen) % len(colors)], lw=1)
ax.vlines((-database_temp[key]['XOFFSET'][j] + xoffset) * 1000,
(-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 - 4.,
(-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 + 4.,
color=colors[len(seen) % len(colors)], lw=1)
seen += [this]
reps += [1]
else:
ww = np.where(np.array(seen) == this)[0][0]
ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000,
shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000,
s=5, color=colors[ww % len(colors)], marker=syms[reps[ww]])
reps[ww] += 1
ax.set_aspect('equal')
xlim = ax.get_xlim()
ylim = ax.get_ylim()
xrng = xlim[1]-xlim[0]
yrng = ylim[1]-ylim[0]
if xrng > yrng:
ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng)
ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng)
else:
ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng)
ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng)
ax.set_xlabel('x-shift [mas]')
ax.set_ylabel('y-shift [mas]')
ax.legend(loc='upper right', fontsize='small')
ax.set_title(f'Reference frame alignment\n showing {len(ww_ref)} PSF refs for {self.database.obs[key]["FILTER"][ww_ref[0]]}')
if save_figures:
output_file = os.path.join(output_dir, key + '_align_ref.pdf')
plt.savefig(output_file)
log.info(f" Plot saved in {output_file}")
plt.show()
plt.close(fig)
if not sci_good:
raise(ValueError('No science frames found in database concatenations'))
[docs]
def calculate_alignment(self,
                        method='fourier',
                        align_algo='leastsq',
                        mask_override=None,
                        msk_shp=8,
                        shft_exp=1,
                        align_to_file=None,
                        scale_prior=False,
                        kwargs={},
                        subdir='aligned',
                        save_figures=True,
                        plot_style=None):
    """
    Calculate shifts necessary to align all SCI and REF frames to the first SCI frame.

    Only the shifts are determined here; the image data are written back
    unshifted. For every file, the per-integration (x, y) shifts are added
    to the file's ``align_shift`` array (and subtracted from ``maskoffs``),
    the median shift of the file is added to ``align_mask`` (stored in
    (y, x) order), and everything is saved with ``ut.write_obs`` and
    recorded in the spaceKLIP database via ``update_obs``.

    Parameters
    ----------
    method : 'fourier' or 'spline' (not recommended), optional
        Method for shifting the frames while fitting. The default is
        'fourier'.
    align_algo : 'leastsq' or 'header'
        Algorithm to determine the alignment offsets. Default is 'leastsq',
        'header' assumes perfect header offsets.
    mask_override : 'ann', 'circ', 'rec' or None, optional
        Custom mask, centered on (CRPIX1, CRPIX2), whose pixels are set to
        zero in the weight map passed to ``ut.alignlsq``: 'ann' zeroes
        everything outside an annulus, 'circ' zeroes a disk and 'rec' zeroes
        a horizontal band. If None, the file's PSF mask is used (or a map of
        ones if there is none). The default is None.
    msk_shp : int or [int, int], optional
        Size of the custom mask: [inner radius, outer radius] for 'ann',
        radius for 'circ', half-height for 'rec'. The default is 8.
    shft_exp : float, optional
        Take image to the given power before cross correlating for shifts,
        default is 1. For instance, 1/2 helps align nircam bar/narrow data
        (or other data with weird speckles).
    align_to_file : str, optional
        Path to FITS file to which all images shall be aligned. Needs to be
        a file with the same observational setup as all concatenations in
        the spaceKLIP database. Hence, this can only be applied to one
        observational setup at a time. The default is None.
    scale_prior : bool, optional
        If True, tries to find a better prior for the scale factor instead
        of simply using 1 (needs a PSF mask). The default is False.
    kwargs : dict, optional
        Keyword arguments for the scipy.ndimage.shift routine. The default
        is {}.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The
        default is 'aligned'.
    save_figures : bool, optional
        Save the plots in a PDF? The default is True.
    plot_style : str, optional
        Plot style passed to ``load_plt_style``. The default is None.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If ``align_algo`` or ``mask_override`` is not recognized, or if no
        concatenation contains science frames.
    UserWarning
        If a file still contains NaN pixels.
    """
    if align_algo not in ('leastsq', 'header'):
        raise ValueError("align_algo must be 'leastsq' or 'header', not %r" % (align_algo,))

    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Useful masks for computing shifts. Each helper returns a boolean (h, w)
    # array that is True on the pixels which get zero weight in the fit.
    def _edge_distance(h, w, center):
        # Smallest distance between the center and the image walls.
        return min(center[0], center[1], w - center[0], h - center[1])

    def create_annulus_mask(h, w, center=None, radius=None):
        # radius = (inner, outer); True inside the inner or outside the outer radius.
        if center is None:  # Use the middle of the image.
            center = (int(w/2), int(h/2))
        if radius is None:  # Annulus reaching out to the nearest image wall.
            radius = (0, _edge_distance(h, w, center))
        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
        return (dist_from_center <= radius[0]) | (dist_from_center >= radius[1])

    def create_circular_mask(h, w, center=None, radius=None):
        # True within `radius` of `center`.
        if center is None:  # Use the middle of the image.
            center = (int(w/2), int(h/2))
        if radius is None:
            radius = _edge_distance(h, w, center)
        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
        return dist_from_center <= radius

    def create_rec_mask(h, w, center=None, z=None):
        # True on rows center[1]-z ... center[1]+z-1, all columns.
        if center is None:  # Use the middle of the image.
            center = (int(w/2), int(h/2))
        if z is None:
            z = h//4
        mask = np.zeros((h, w), dtype=bool)
        mask[center[1]-z:center[1]+z, :] = True
        return mask

    def _square_limits(ax):
        # Make both axes span twice the larger data range, keeping their centers.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        rng = max(xlim[1] - xlim[0], ylim[1] - ylim[0])
        ax.set_xlim(np.mean(xlim) - rng, np.mean(xlim) + rng)
        ax.set_ylim(np.mean(ylim) - rng, np.mean(ylim) + rng)

    # Keep the original database: update_obs below overwrites XOFFSET/YOFFSET,
    # but the reference frame plot needs the original dither offsets.
    database_temp = deepcopy(self.database.obs)
    sci_good = False
    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Find science and reference files.
        ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0]
        if len(ww_sci) == 0:
            log.warning(f'Could not find any science files. Skipping {key}.')
            continue
        sci_good = True
        ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0]
        ww_all = np.append(ww_sci, ww_ref)

        # External reference image, if requested.
        if align_to_file is not None:
            try:
                ref_image = pyfits.getdata(align_to_file, 'SCI')
            except (KeyError, IndexError):
                # No (data in a) SCI extension: fall back to the primary HDU.
                ref_image = pyfits.getdata(align_to_file, 0)
            if ref_image.ndim == 3:
                ref_image = np.nanmedian(ref_image, axis=0)

        # Loop through FITS files.
        shifts_all = []
        for j in ww_all:
            obs = self.database.obs[key]
            is_first_sci = (j == ww_sci[0])

            # Read FITS file and PSF masks.
            fitsfile = obs['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = obs['MASKFILE'][j]
            nanmaskfile = obs['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Weight map handed to ut.alignlsq.
            if mask_override is not None:
                h, w = data[0].shape[0], data[0].shape[1]
                center = (int(obs['CRPIX1'][j]), int(obs['CRPIX2'][j]))
                if mask_override == 'ann':
                    mask_circ = create_annulus_mask(h, w, center=center, radius=msk_shp)
                elif mask_override == 'circ':
                    mask_circ = create_circular_mask(h, w, center=center, radius=msk_shp)
                elif mask_override == 'rec':
                    mask_circ = create_rec_mask(h, w, center=center, z=msk_shp)
                else:
                    raise ValueError('There are `ann`, `circ` and `rec` custom masks available')
                # 1 outside the custom mask, 0 inside, with the data's dtype.
                mask_temp = (~mask_circ).astype(data.dtype)
            elif mask is None:
                mask_temp = np.ones_like(data[0])
            else:
                mask_temp = mask.copy()

            # Align frames.
            head, tail = os.path.split(fitsfile)
            log.info(' --> Align frames: ' + tail)
            if np.isnan(data).any():
                raise UserWarning('Please replace nan pixels before attempting to align frames')
            shifts = []
            for k in range(data.shape[0]):
                is_ref_frame = is_first_sci and k == 0
                if is_ref_frame:
                    # The first integration of the first SCI file defines the
                    # reference position for the whole concatenation.
                    if align_to_file is None:
                        ref_image = data[k].copy()
                    xoffset = obs['XOFFSET'][j]  # arcsec
                    yoffset = obs['YOFFSET'][j]  # arcsec
                    starcenx_ref = obs['STARCENX'][j]  # pixels, 1 indexed
                    starceny_ref = obs['STARCENY'][j]  # pixels, 1 indexed
                    pxsc = obs['PIXSCALE'][j]  # arcsec
                    if align_to_file is None:
                        # Aligned to itself: zero shift by construction.
                        shifts += [np.array([0, 0, 0])]
                        continue

                # Header-based shift relative to the first frame, in pixels.
                # Always measured against the reference star center, never
                # against a value refined for a previous file.
                xfirst = starcenx_ref + (xoffset/pxsc)
                xcurrent = obs['STARCENX'][j] + obs['XOFFSET'][j]/obs['PIXSCALE'][j]
                xshift = xfirst - xcurrent
                yfirst = starceny_ref + (yoffset/pxsc)
                ycurrent = obs['STARCENY'][j] + obs['YOFFSET'][j]/obs['PIXSCALE'][j]
                yshift = yfirst - ycurrent
                if scale_prior:
                    # Median ref/data ratio where the PSF mask is < 0.5,
                    # ignoring a 100 pixel border (needs a PSF mask).
                    ww = mask < 0.5
                    bw = 100
                    ww[:bw, :] = False
                    ww[:, :bw] = False
                    ww[mask.shape[0] - bw:, :] = False
                    ww[:, mask.shape[1] - bw:] = False
                    scale = np.nanmedian(np.true_divide(ref_image, data[k])[ww])
                    if shft_exp != 1:
                        scale = np.power(np.abs(scale), shft_exp)
                    p0 = np.array([xshift, yshift, scale])
                else:
                    p0 = np.array([xshift, yshift, 1.])
                # Fix for weird numerical behaviour if shifts are small
                # but not exactly zero.
                if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3):
                    p0 = np.array([0., 0., p0[-1]])
                if align_algo == 'leastsq':
                    if shft_exp != 1:
                        args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs)
                    else:
                        args = (data[k], ref_image, mask_temp, method, kwargs)
                    # Use header values to initiate least squares fit.
                    pp = leastsq(ut.alignlsq, p0, args=args)[0]
                else:
                    # 'header': just assume the header values are correct.
                    pp = p0
                shifts += [np.array([pp[0], pp[1], pp[2]])]

            # Mask shift: median shift of this file, stored in (y, x) order.
            # NOTE(review): shift_frames passes this to ut.imshift as
            # [mask_shift[0], mask_shift[1]], like the (x, y) star shifts;
            # confirm ut.imshift's axis convention.
            if mask is None and nanmask is None:
                mask_shifts = np.array([0., 0.])
            elif align_to_file is not None or not is_first_sci:
                temp = np.median(shifts, axis=0)
                mask_shifts = np.array([temp[1], temp[0]])
            else:
                mask_shifts = np.array([0, 0])
            shifts = np.array(shifts)
            shifts_all += [shifts]
            align_shift = shifts[:, :-1] if align_shift is None else align_shift + shifts[:, :-1]
            align_mask = mask_shifts if align_mask is None else align_mask + mask_shifts
            maskoffs = -shifts[:, :-1] if maskoffs is None else maskoffs - shifts[:, :-1]

            # Compute shift distances.
            dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1))  # pix
            dist *= obs['PIXSCALE'][j] * 1000  # mas
            # The reference integration is zero by construction; leave it out.
            skip_first = is_first_sci and align_to_file is None
            if skip_first:
                dist = dist[1:]
            if dist.size > 0:
                log.info(' --> Align frames: median required shift = %.2f mas' % np.median(dist))
            if obs['TELESCOP'][j] == 'JWST':
                suspicious = (dist < 1e-5) | (dist > 100.)
            else:
                suspicious = (dist < 1e-5)
            if np.any(suspicious):
                if skip_first:
                    suspicious = np.append(np.array([False]), suspicious)
                if align_algo != 'header':
                    log.warning(' --> The following frames might not be properly aligned: '+str(np.where(suspicious)[0]))

            # Write FITS file and PSF masks.
            head_pri['XOFFSET'] = xoffset  # arcseconds
            head_pri['YOFFSET'] = yoffset  # arcseconds
            if is_first_sci:
                starcenx, starceny = starcenx_ref, starceny_ref
            else:
                # Use the median alignment shift to update STARCENX/Y.
                starcenx = obs['STARCENX'][j] - np.nanmedian(shifts[:, 0])
                starceny = obs['STARCENY'][j] - np.nanmedian(shifts[:, 1])
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database; STARCENX/Y of the very first science
            # file is left unchanged.
            star_kwargs = {} if is_first_sci else {'starcenx': starcenx, 'starceny': starceny}
            self.database.update_obs(key, j, fitsfile, maskfile,
                                     xoffset=xoffset, yoffset=yoffset,
                                     align_shift=align_shift, align_mask=align_mask,
                                     nanmaskfile=nanmaskfile, **star_kwargs)

        # Plot science frame alignment.
        # Initialize the matplotlib style.
        load_plt_style(plot_style)
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        fig = plt.figure(figsize=(6.4, 4.8))
        ax = plt.gca()
        for index, j in enumerate(ww_sci):
            pix = self.database.obs[key]['PIXSCALE'][j]
            ax.scatter(shifts_all[index][:, 0] * pix * 1000,
                       shifts_all[index][:, 1] * pix * 1000,
                       s=5, color=colors[index % len(colors)], marker='o',
                       label='PA = %.0f deg' % self.database.obs[key]['ROLL_REF'][j])
        # zorder=-1 keeps the lines behind all the scatter points.
        ax.axhline(0., color='gray', lw=1, zorder=-1)
        ax.axvline(0., color='gray', lw=1, zorder=-1)
        ax.set_aspect('equal')
        _square_limits(ax)
        ax.set_xlabel('x-shift [mas]')
        ax.set_ylabel('y-shift [mas]')
        ax.legend(loc='upper right')
        ax.set_title(f'Science frame alignment\nfor {self.database.obs[key]["TARGPROP"][ww_sci[0]]}, {self.database.obs[key]["FILTER"][ww_sci[0]]}')
        if save_figures:
            output_file = os.path.join(output_dir, key + '_align_sci.pdf')
            plt.savefig(output_file)
            log.info(f" Plot saved in {output_file}")
        plt.show()
        plt.close(fig)

        # Plot reference frame alignment.
        if len(ww_ref) > 0:
            colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
            fig = plt.figure(figsize=(6.4, 4.8))
            ax = plt.gca()
            seen = []  # distinct original dither offsets, in order of appearance
            reps = []  # number of files seen per dither (selects the marker)
            syms = ['o', 'v', '^', '<', '>'] * (1 + len(ww_ref) // 5)
            add = len(ww_sci)  # shifts_all holds SCI files first, then REF
            for index, j in enumerate(ww_ref):
                this = '%.3f_%.3f' % (database_temp[key]['XOFFSET'][j], database_temp[key]['YOFFSET'][j])
                pix = self.database.obs[key]['PIXSCALE'][j]
                xs = shifts_all[index + add][:, 0] * pix * 1000
                ys = shifts_all[index + add][:, 1] * pix * 1000
                if this not in seen:
                    color = colors[len(seen) % len(colors)]
                    ax.scatter(xs, ys, s=5, color=color, marker=syms[0],
                               label='dither %.0f' % (len(seen) + 1))
                    # Cross at the offset expected from the headers.
                    xc = (-database_temp[key]['XOFFSET'][j] + xoffset) * 1000
                    yc = (-database_temp[key]['YOFFSET'][j] + yoffset) * 1000
                    ax.hlines(yc, xc - 4., xc + 4., color=color, lw=1)
                    ax.vlines(xc, yc - 4., yc + 4., color=color, lw=1)
                    seen += [this]
                    reps += [1]
                else:
                    idx = seen.index(this)
                    ax.scatter(xs, ys, s=5, color=colors[idx % len(colors)], marker=syms[reps[idx]])
                    reps[idx] += 1
            ax.set_aspect('equal')
            _square_limits(ax)
            ax.set_xlabel('x-shift [mas]')
            ax.set_ylabel('y-shift [mas]')
            ax.legend(loc='upper right', fontsize='small')
            ax.set_title(f'Reference frame alignment\n showing {len(ww_ref)} PSF refs for {self.database.obs[key]["FILTER"][ww_ref[0]]}')
            if save_figures:
                output_file = os.path.join(output_dir, key + '_align_ref.pdf')
                plt.savefig(output_file)
                log.info(f" Plot saved in {output_file}")
            plt.show()
            plt.close(fig)
    if not sci_good:
        raise(ValueError('No science frames found in database concatenations'))
[docs]
def shift_frames(self,
                 method='fourier',
                 kwargs={},
                 subdir='shifted'):
    """
    Apply previously calculated shifts to all SCI, REF, SCI_TA and REF_TA frames.

    The shift of every integration is the sum of the database's ALIGN_SHIFT
    and CENTER_SHIFT entries. All files of a concatenation get the same
    padding (``ut.estimate_padding_for_shift``). PSF / NaN masks of
    coronagraphic SCI and REF files are shifted by ALIGN_MASK + CENTER_MASK.
    Non-coronagraphic SCI/REF frames and TA frames are additionally rolled by
    whole pixels so that their brightest pixel lands on the image center.
    Star center, mask centers and CRPIX values are moved by the first
    integration's shift plus the padding, then written to the FITS header
    and the spaceKLIP database.

    Parameters
    ----------
    method : 'fourier' or 'spline' (not recommended), optional
        Method for shifting the frames. The default is 'fourier'.
    kwargs : dict, optional
        Keyword arguments for the scipy.ndimage.shift routine. The default is {}.
    subdir : str, optional
        Name of the directory where the data products shall be saved. The default is 'shifted'.

    Returns
    -------
    None.
    """
    # Set output directory.
    output_dir = os.path.join(self.database.output_dir, subdir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    def _stored_shift(shift_table, j, k):
        # (x, y) shift of integration k of file j. Entries that are built-in
        # function objects (apparently "not set" placeholders) count as zero.
        entry = shift_table[j]
        if isinstance(entry, types.BuiltinFunctionType):
            return np.zeros(2)
        return np.array([entry[k][0], entry[k][1]], dtype=float)

    def _shift_image(img, shift, pad):
        # Pad and shift one image with the user-selected method.
        return ut.imshift(img, [shift[0], shift[1]], pad_amount=pad, method=method, kwargs=kwargs)

    def _shift_mask(msk, shift, pad):
        # Pad and shift a mask with a spline and constant fill.
        return ut.imshift(msk, [shift[0], shift[1]], method='spline',
                          pad_amount=pad, kwargs={'mode': 'constant'})

    def _roll_peak_to_center(img, err, shift):
        # Roll img/err by whole pixels so the brightest pixel of img sits on
        # the image center; the roll is added to `shift` in place.
        iy, ix = np.unravel_index(np.nanargmax(img), img.shape)
        dx = img.shape[-1] // 2 - ix
        dy = img.shape[-2] // 2 - iy
        if dx == 0 and dy == 0:
            return img, err
        shift[0] += dx
        shift[1] += dy
        return (np.roll(np.roll(img, dx, axis=1), dy, axis=0),
                np.roll(np.roll(err, dx, axis=1), dy, axis=0))

    coron_exp_types = ['NRC_CORON', 'MIR_4QPM', 'MIR_LYOT']

    # Loop through concatenations.
    for key in self.database.obs.keys():
        log.info('--> Concatenation ' + key)

        # Find science, reference and target acquisition files.
        file_type = self.database.obs[key]['TYPE']
        ww_sci = np.where(file_type == 'SCI')[0]
        ww_sci_ta = np.where(file_type == 'SCI_TA')[0]
        ww_ref = np.where(file_type == 'REF')[0]
        ww_ref_ta = np.where(file_type == 'REF_TA')[0]
        ww_all = np.concatenate([ww_sci, ww_ref, ww_sci_ta, ww_ref_ta])

        # Load in previously calculated image shifts.
        align_shift_star = self.database.obs[key]['ALIGN_SHIFT']
        align_shift_mask = self.database.obs[key]['ALIGN_MASK']
        center_shift_star = self.database.obs[key]['CENTER_SHIFT']
        center_shift_mask = self.database.obs[key]['CENTER_MASK']

        # Need to determine largest potential shift for padding purposes;
        # every file of this concatenation gets the same padding.
        shiftpad = ut.estimate_padding_for_shift(align_shift_star, center_shift_star)
        log.info(f' --> Estimated padding for shifting: {shiftpad} pixels')

        # Loop through FITS files.
        for j in ww_all:
            obs = self.database.obs[key]

            # Read FITS file and PSF masks.
            fitsfile = obs['FITSFILE'][j]
            data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile)
            maskfile = obs['MASKFILE'][j]
            nanmaskfile = obs['NANMASKFILE'][j]
            mask = ut.read_msk(maskfile)
            nanmask = ut.read_msk(nanmaskfile)

            # Shift frames.
            head, tail = os.path.split(fitsfile)
            log.info(' --> Shift frames: ' + tail)

            # TO IMPROVE DQ DATA
            # The DQ also needs to be padded so we don't have array mismatches further down the line.
            # However, it's not really meaningful to shift the DQ array by sub-pixel amounts.
            # For now we will pad the array with zeros and assume that DQ is not important following shift_frames.
            pxdq = np.pad(pxdq, pad_width=((0, 0), (shiftpad, shiftpad), (shiftpad, shiftpad)),
                          mode='constant', constant_values=0)

            # Mask centers default to this file's own database values
            # (1 indexed), so nothing carries over from a previous file.
            maskcenx = obs['MASKCENX'][j]
            maskceny = obs['MASKCENY'][j]
            nanmaskcenx = obs['NANMASKCENX'][j]
            nanmaskceny = obs['NANMASKCENY'][j]

            is_sci_or_ref = j in ww_sci or j in ww_ref
            shifts = []  # Store recenter + align shifts per integration.
            data_shift, erro_shift = [], []
            if is_sci_or_ref and obs['EXP_TYPE'][j] in coron_exp_types:
                # NIRCam / MIRI coronagraphy: recenter and align the frames.
                for k in range(data.shape[0]):
                    shift = _stored_shift(align_shift_star, j, k) + _stored_shift(center_shift_star, j, k)
                    shifts += [shift]
                    data_shift += [_shift_image(data[k], shift, shiftpad)]
                    erro_shift += [_shift_image(erro[k], shift, shiftpad)]
                if mask is not None:
                    mask_shift = center_shift_mask[j] + align_shift_mask[j]
                    mask = _shift_mask(mask, mask_shift, shiftpad)
                if nanmask is not None:
                    # Shift (and pad) the NaN mask too, so it keeps matching the data.
                    nanmask_shift = center_shift_mask[j] + align_shift_mask[j]
                    nanmask = _shift_mask(nanmask, nanmask_shift, shiftpad)
                # Update mask centers.
                maskcenx = obs['MASKCENX'][j] + shifts[0][0] + shiftpad
                maskceny = obs['MASKCENY'][j] + shifts[0][1] + shiftpad
                nanmaskcenx = obs['NANMASKCENX'][j] + shifts[0][0] + shiftpad
                nanmaskceny = obs['NANMASKCENY'][j] + shifts[0][1] + shiftpad
            else:
                # Other SCI/REF data types and TA frames: shift, then recenter
                # to integer pixel precision by rolling the shifted image.
                # NOTE(review): the PSF mask is neither shifted nor padded in
                # this branch; confirm such files never carry one.
                for k in range(data.shape[0]):
                    shift = _stored_shift(align_shift_star, j, k) + _stored_shift(center_shift_star, j, k)
                    this_data = _shift_image(data[k], shift, shiftpad)
                    this_erro = _shift_image(erro[k], shift, shiftpad)
                    this_data, this_erro = _roll_peak_to_center(this_data, this_erro, shift)
                    shifts += [shift]
                    data_shift += [this_data]
                    erro_shift += [this_erro]
                if is_sci_or_ref:
                    if nanmask is not None:
                        # Shift the NaN mask once (by the first integration's
                        # shift, like its center) and keep it binary.
                        nanmask = _shift_mask(nanmask, shifts[0], shiftpad)
                        nanmask[np.isnan(nanmask)] = 1
                        nanmask = (nanmask >= 0.5).astype(np.float32)
                    nanmaskcenx = obs['NANMASKCENX'][j] + shifts[0][0] + shiftpad
                    nanmaskceny = obs['NANMASKCENY'][j] + shifts[0][1] + shiftpad
            data = np.array(data_shift)
            erro = np.array(erro_shift)

            # Update star center and CRPIX values.
            starcenx = obs['STARCENX'][j] + shifts[0][0] + shiftpad
            starceny = obs['STARCENY'][j] + shifts[0][1] + shiftpad
            crpix1 = obs['CRPIX1'][j] + shifts[0][0] + shiftpad
            crpix2 = obs['CRPIX2'][j] + shifts[0][1] + shiftpad

            # Write FITS file and PSF mask.
            head_sci['STARCENX'] = starcenx
            head_sci['STARCENY'] = starceny
            head_sci['MASKCENX'] = maskcenx
            head_sci['MASKCENY'] = maskceny
            head_sci['NANMASKCENX'] = nanmaskcenx
            head_sci['NANMASKCENY'] = nanmaskceny
            head_sci['CRPIX1'] = crpix1
            head_sci['CRPIX2'] = crpix2
            fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d,
                                    align_shift=align_shift, center_shift=center_shift, align_mask=align_mask,
                                    center_mask=center_mask, maskoffs=maskoffs)
            maskfile = ut.write_msk(maskfile, mask, fitsfile)
            nanmaskfile = ut.write_msk(nanmaskfile, nanmask, fitsfile, '_nanmask.fits')

            # Update spaceKLIP database.
            self.database.update_obs(key, j, fitsfile, maskfile,
                                     maskcenx=maskcenx, maskceny=maskceny,
                                     nanmaskcenx=nanmaskcenx, nanmaskceny=nanmaskceny,
                                     starcenx=starcenx, starceny=starceny,
                                     crpix1=crpix1, crpix2=crpix2, nanmaskfile=nanmaskfile)
[docs]
def subtract_nircam_coron_background(self,
subdir='bgsub',
mask_snr_threshold=2,
r_excl_nfwhm=40,
q_clip=5.,
align_wrapped=True,
include_global_offset=True,
include_stellar_psf_component=True,
generate_plot=True,
save_model=False,
use_jbt_background=False,
bgmodel_dir=None,
background_sb={},
restrict_to=None,
plot_style=None):
"""
Fits and subtracts the astrophysical background in NIRCam coronagraphic
data following the procedure described in Lawson et al. (2024).
Note: This step should only be applied to data that has already been
aligned. Otherwise, it will crash.
For SW filters using a LW coronagraph, the field of view excludes the
neutral density squares. In this case, the astrophysical background and
the artificial background offset that we fit are fully degenerate in
the regions we consider (away from the coronagraph). Since SW
backgrounds should be low anyway, the default is to assume an
astrophysical background of zero here. Alternatively, the JWST
Backgrounds Tool can be used to estimate the background for affected
data instead (if use_jbt_background=True and the JWST Backgrounds Tool
is installed).
Parameters
----------
subdir : str, optional
Name of the sub-directory where the data products will be saved.
The default is 'bgsub'.
mask_snr_threshold : float, optional
SNR threshold for features to be masked during fitting of the
background. SNR is estimated using the ERR FITS extension. The
default is 2.
r_excl_nfwhm : float, optional
Radius (in units of the effective PSF FWHM) of the region around
the star to exclude from the fit. The default is 50.
q_clip : float, optional
After computing BG model residuals, exclude q_clip% of pixels from
both ends of the residual distribution before computing chisq. This
is intended to avoid over-/under-estimation of the background due
to unmasked sources or artifacts. Default is 5.
align_wrapped : bool, optional
Whether input data were aligned using a Fourier shift without
padding first (such that values wrapped at the edges). Default is
True.
include_global_offset : bool, optional
Whether to fit a uniform background offset along with the
astrophysical background model. This corrects for offsets induced
by ramp fitting or use of the median subtraction step. Default is
True.
include_stellar_psf_component : bool, optional
Whether to include a stellar PSF model component when optimizing
the background model. Default is True.
generate_plot : bool, optional
Whether to generate a plot showing the data before and after
subtraction along with the model and masked residuals. Default is
True.
save_model : bool, optional
Whether to save the optimized background model. Default is False.
use_jbt_background : bool, optional
Whether to use the JWST Backgrounds Tool to estimate the background
surface brightness for data without coverage of ND squares.
Requires that jwst_backgrounds is installed. Default is False.
bgmodel_dir : str, optional
Path to the directory containing the normalized background model
component FITS files to use (or to which they should be
downloaded). If None, uses spaceKLIP/resources/nircam_bg_models/.
Default is None.
background_sb : dict, optional
A dictionary of fixed background surface brightness (SB) values (in
the same units as the data) to adopt for any included concatenation
keys. For each key in database.obs, if background_sb[key] is None
or if the key is not in background_sb, the background SB will be
fit for all observations of that concatenation if possible.
Otherwise, background_sb[key] should be an array-like of float or
None having the same length as database.obs[key]. If
background_sb[key][j] is None, the jth frame's background SB will
be fit, otherwise it will be fixed to the value
background_sb[key][j]. Default is {}.
restrict_to : str or None, optional
Restrict the background subtraction to a specific key in the
database. Default is None.
"""
def day_of_year_from_mjd(mjd):
    """
    Return the 1-indexed calendar day of the year (UTC) containing the
    given Modified Julian Date.

    The fractional day is truncated, not rounded, so a time in the
    afternoon of a given date maps to that date rather than the next one.
    """
    import datetime
    # MJD 0 corresponds to 1858-11-17T00:00:00.
    moment = datetime.datetime(1858, 11, 17) + datetime.timedelta(days=float(mjd))
    return moment.timetuple().tm_yday

def get_jbt_background_est(t, ra, dec, wavelength):
    """
    Uses the JWST Backgrounds tool to estimate the background surface
    brightness at a given time, position, and wavelength.

    Parameters
    ----------
    t : float
        Observation time as a Modified Julian Date (the caller passes the
        EXPSTART value).
    ra, dec : float
        Target coordinates, passed through to jwst_backgrounds.
    wavelength : float
        Wavelength passed through to jwst_backgrounds (the caller passes
        CWAVEL; presumably microns -- confirm against jwst_backgrounds).

    Returns
    -------
    float
        The 'total_thiswave' bathtub value for the observation's day of
        year.

    Raises
    ------
    ModuleNotFoundError
        If jwst_backgrounds is not installed.
    ValueError
        If the observation's day of year is not present in the JBT
        calendar for this target.
    """
    from jwst_backgrounds import jbt
    bkg = jbt.background(ra, dec, wavelength)
    calendar = np.asarray(bkg.bkg_data['calendar'])
    # Truncate to the calendar day; rounding shifted afternoon exposures to
    # the following day.
    thisday = day_of_year_from_mjd(t)
    matches = np.flatnonzero(calendar == thisday)
    if matches.size == 0:
        raise ValueError(
            f"Day {thisday} of the year (MJD {t}) is not in the JWST Backgrounds "
            f"Tool calendar for RA={ra}, Dec={dec}; no background estimate is "
            "available for that date."
        )
    Fbg = bkg.bathtub['total_thiswave'][matches[0]]
    return Fbg
def background_objective(p, im, bg0, psf0, optmask, q=5):
    """
    Objective function for fitting the multi-component background model
    using LMFit.

    Parameters
    ----------
    p : lmfit.Parameters or mapping
        Must contain 'fbg' (scale factor of the background component),
        'bg_offset' (uniform additive offset) and 'fpsf' (scale factor of
        the stellar PSF component). Values are looked up by name, so the
        order in which the parameters were added does not matter.
    im : ndarray
        Image being fit.
    bg0, psf0 : ndarray
        Background and stellar PSF model components; must broadcast
        against im.
    optmask : ndarray of bool
        Pixels included in the fit.
    q : float, optional
        Percentile used for clipping: residuals below the q-th or above the
        (100-q)-th percentile are discarded. Default is 5.

    Returns
    -------
    ndarray
        Absolute values of the retained residuals.
    """
    # Look up by name rather than unpacking positionally, so a change in
    # parameter insertion order cannot silently swap the components.
    fbg, bg_offset, fpsf = p['fbg'], p['bg_offset'], p['fpsf']
    res = (im - fbg*bg0 - bg_offset - fpsf*psf0)[optmask]
    # Clip both tails so outliers do not dominate the fit.
    low, upp = np.nanpercentile(res, [q, 100.-q])
    # NaN residuals fail both comparisons and are therefore dropped.
    return np.abs(res[(res >= low) & (res <= upp)])
def get_model_component_path(key, bgmodel_dir, suffix):
    """
    Return the path to the model component file '<key>_<suffix>.fits' in
    bgmodel_dir, first downloading it from the nircam_bgsub_go4050
    repository if it is not already on disk.

    The path is built with os.path.join, so bgmodel_dir may be given with
    or without a trailing path separator.
    """
    fname = f'{key}_{suffix}.fits'
    path = os.path.join(bgmodel_dir, fname)
    if not os.path.exists(path):
        url = f'https://github.com/kdlawson/nircam_bgsub_go4050/raw/main/nominal_bgmodels/{fname}'
        with fits.open(url) as hdul:
            hdul.writeto(path)
    return path

def get_stellar_model_path(key, bgmodel_dir):
    """
    Searches for the correct stellar model component on disk and
    fetches from an online repository if needed. Returns the path to
    the FITS file ('<key>_psf0.fits' inside bgmodel_dir).
    """
    return get_model_component_path(key, bgmodel_dir, 'psf0')

def get_background_model_path(key, bgmodel_dir):
    """
    Searches for the correct background model component on disk and
    fetches from an online repository if needed. Returns the path to
    the FITS file ('<key>_background0.fits' inside bgmodel_dir).
    """
    return get_model_component_path(key, bgmodel_dir, 'background0')
output_dir = os.path.join(self.database.output_dir, subdir+'/')
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
if bgmodel_dir is None:
bgmodel_dir = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'resources/nircam_bg_models/')
if not os.path.exists(bgmodel_dir):
os.makedirs(bgmodel_dir)
# Copy input SB dictionary so we can fill out / change values as needed.
bg_sb_dict = background_sb.copy()
for key in self.database.obs:
if ((restrict_to is not None) and (restrict_to not in key)) or not np.any(np.isin(self.database.obs[key]['TYPE'], ['SCI', 'REF'])):
continue
log.info('--> Concatenation ' + key)
# Fill out dictionary of SB values, replacing None with a fixed value where fitting is not possible (SW filters with LW coronagraphs).
if (key not in bg_sb_dict) or (bg_sb_dict[key] is None):
bg_sb_dict[key] = np.repeat(None, len(self.database.obs[key]))
for j, entry in enumerate(self.database.obs[key]):
if (entry['DETECTOR'] == 'NRCA2') and (entry['CORONMSK'] in ['MASK335R', 'MASK430R', 'MASKLWB']) and (entry['SUBARRAY'] != 'FULL') and (bg_sb_dict[key][j] is None):
if use_jbt_background:
try:
bg_sb_dict[key][j] = get_jbt_background_est(entry['EXPSTART'], entry['TARG_RA'], entry['TARG_DEC'], entry['CWAVEL'])
except ModuleNotFoundError:
raise ModuleNotFoundError("""
JBT background estimation requires
the jwst_backgrounds package. Either
install jwst_backgrounds or rerun
with use_jbt_background=False.
""")
else:
bg_sb_dict[key][j] = 0.
db_tab = self.database.obs[key]
fwhm = db_tab['CWAVEL'][0] * 1e-6 / 5.2 * 180. / np.pi * 3600. / db_tab['PIXSCALE'][0]
blur_fwhm = db_tab['BLURFWHM'][0]
if np.isfinite(blur_fwhm):
blur_sigma = blur_fwhm/np.sqrt(8.*np.log(2.))
else:
blur_sigma = None
# Load the normalized stellar PSF model component
if include_stellar_psf_component:
psffile = get_stellar_model_path(key, bgmodel_dir=bgmodel_dir)
with fits.open(psffile) as hdul:
psf0_osamp, hpsf0 = hdul['OVERSAMP'].data, hdul['OVERSAMP'].header
osamp = hpsf0['OSAMP']
c_psf0_osamp = np.array([hpsf0['CRPIX1'], hpsf0['CRPIX2']])-1
if blur_sigma is not None:
psf0_osamp = gaussian_filter(psf0_osamp, blur_sigma*osamp)
# Load the normalized BG model component
if not np.all(bg_sb_dict[key] == 0):
bgfile = get_background_model_path(key, bgmodel_dir=bgmodel_dir)
with fits.open(bgfile) as hdul:
bg0_osamp, hbg0 = hdul['OVERSAMP'].data, hdul['OVERSAMP'].header
c_coron_bg0 = np.array([hbg0['CRPIX1'], hbg0['CRPIX2']])-1
osamp = hbg0['OSAMP']
# Apply any blurring used for the data:
if blur_sigma is not None:
bg0_osamp = gaussian_filter(bg0_osamp, blur_sigma*osamp)
files = db_tab['FITSFILE']
c_star = np.array([db_tab['STARCENX'][0], db_tab['STARCENY'][0]])-1
h1 = fits.getheader(files[0], ext=1)
ny, nx = h1['NAXIS2'], h1['NAXIS1']
rmap = np.sqrt((np.arange(0, nx, dtype=np.float32)-c_star[0])**2 + (np.arange(0, ny, dtype=np.float32)-c_star[1])[:, np.newaxis]**2)
rmap_nfwhm = rmap/fwhm # Map of each pixel's distance from the location of the star in units of the effective PSF FWHM
# With no alignment wrapping, the stellar PSF model should be the same for all frames, so we'll just set it up once outside the loop over files.
if not align_wrapped:
if include_stellar_psf_component:
c_star_osamp = c_star*osamp + 0.5*(osamp-1)
psf0_osamp_crop = webbpsf_ext.image_manip.crop_image(psf0_osamp, [ny*osamp, nx*osamp],
xyloc=c_psf0_osamp, delx=c_star_osamp[0]-(nx*osamp-1)/2.,
dely=c_star_osamp[1]-(ny*osamp-1)/2.)
psf0_crop = webbpsf_ext.image_manip.frebin(psf0_osamp_crop, scale=1/osamp, total=False)
else:
psf0_crop = np.zeros((ny, nx), dtype=np.float32)
for j,f in enumerate(files):
if db_tab[j]['TYPE'] not in ['SCI', 'REF']:
continue
head, tail = os.path.split(f)
log.info(' --> NIRCam Background Subtraction: ' + tail)
# Assume alignment and background differences between integrations are negligible, so we can use the higher SNR coadded exposure
with fits.open(f) as hdul:
ints = hdul['SCI'].data
errs = hdul['ERR'].data
h1 = hdul['SCI'].header
mask_offset = np.nanmedian(hdul['MASKOFFS'].data, axis=0)
imshift = np.nanmedian(hdul['ALIGN_SHIFT'].data + hdul['CENTER_SHIFT'].data, axis=0)
if np.ndim(ints) == 3:
med = np.nanmedian(ints, axis=0)
nsample = np.sum(np.isfinite(ints), axis=0)
err = np.sqrt(np.nansum(errs**2, axis=0))/nsample
else: # In case using coadded data saved with only two dims
med, err = ints, errs
c_coron = c_star - mask_offset # post-alignment mask center position
if align_wrapped:
if bg_sb_dict[key][j] == 0:
bg0_crop = np.zeros_like(med)
else:
c_coron_osamp_preshift = (c_coron-imshift)*osamp + 0.5*(osamp-1)
bg0_osamp_crop = webbpsf_ext.image_manip.crop_image(bg0_osamp, [ny*osamp, nx*osamp],
xyloc=c_coron_bg0, delx=c_coron_osamp_preshift[0]-(nx*osamp-1)/2.,
dely=c_coron_osamp_preshift[1]-(ny*osamp-1)/2.)
bg0_osamp_crop = ut.imshift(bg0_osamp_crop, imshift*osamp, pad=False)
bg0_crop = webbpsf_ext.image_manip.frebin(bg0_osamp_crop, scale=1/osamp, total=False) # Downsample to detector resolution
if include_stellar_psf_component:
c_star_osamp_preshift = (c_star-imshift)*osamp + 0.5*(osamp-1)
psf0_osamp_crop = webbpsf_ext.image_manip.crop_image(psf0_osamp, [ny*osamp, nx*osamp],
xyloc=c_psf0_osamp, delx=c_star_osamp_preshift[0]-(nx*osamp-1)/2.,
dely=c_star_osamp_preshift[1]-(ny*osamp-1)/2.)
psf0_osamp_crop = ut.imshift(psf0_osamp_crop, imshift*osamp, pad=False)
psf0_crop = webbpsf_ext.image_manip.frebin(psf0_osamp_crop, scale=1/osamp, total=False)
else:
psf0_crop = np.zeros_like(med)
else:
if bg_sb_dict[key][j] == 0:
bg0_crop = np.zeros_like(med)
else:
c_coron_osamp = c_coron*osamp + 0.5*(osamp-1)
bg0_osamp_crop = webbpsf_ext.image_manip.crop_image(bg0_osamp, [ny*osamp, nx*osamp],
xyloc=c_coron_bg0, delx=c_coron_osamp[0]-(nx*osamp-1)/2.,
dely=c_coron_osamp[1]-(ny*osamp-1)/2.)
bg0_crop = webbpsf_ext.image_manip.frebin(bg0_osamp_crop, scale=1/osamp, total=False) # Downsample to detector resolution
if bg_sb_dict[key][j] is None:
fbg0 = 1
fbg_vary = True
else:
fbg0 = bg_sb_dict[key][j]
fbg_vary = False
optmask = rmap_nfwhm > r_excl_nfwhm
snr = med/err # SNR estimate using FITS ERR extension
med_snr = np.nanmedian(snr[optmask]) # Median SNR in the nominal background area
low_snr = snr <= (med_snr+mask_snr_threshold) # High SNR features are those more than mask_snr_threshold sigma above the approximate BG SNR
optmask = optmask & low_snr
bg_offset0 = 0
fpsf0 = 1. if not include_stellar_psf_component else np.nansum((med-np.nanmedian(med[optmask])) * psf0_crop) / np.nansum((psf0_crop ** 2))
# Prepare the lmfit Parameters object with default values and sensible bounds
p = lmfit.Parameters()
p.add('fbg', value=fbg0, min=0, max=np.inf, vary=fbg_vary)
p.add('bg_offset', value=bg_offset0, min=-np.inf, max=0, vary=include_global_offset)
p.add('fpsf', value=fpsf0, min=0, max=fpsf0*10, vary=include_stellar_psf_component)
# Optimize the background model
result = lmfit.minimize(background_objective, p, args=(med, bg0_crop, psf0_crop, optmask, q_clip), method='powell')
pfin = result.params.valuesdict()
logstr = ', '.join([f"{key}:{value:.3f}" for key, value in pfin.items()])
log.info(' --> NIRCam Background Subtraction: ' + logstr)
# Compute the final background model and stellar PSF component:
bg = pfin['fbg']*bg0_crop + pfin['bg_offset']
psf = psf0_crop*pfin['fpsf']
f_out = output_dir+os.path.basename(os.path.normpath(f))
with fits.open(f) as hdul:
hdul[1].data -= bg # Subtract the BG model from the original file
hdul.writeto(f_out, overwrite=True) # Save to disk in the output directory
if save_model:
f_model = f_out.replace('.fits', '_background_model.fits')
hdu1 = fits.ImageHDU(bg, name='BG')
hdul_model = fits.HDUList([hdul[0], hdu1])
if include_stellar_psf_component:
hdul_model.append(fits.ImageHDU(psf, name='STELLAR_PSF'))
# Add fit params to header
hdul_model[0].header.update(pfin)
# Add all relevant settings to the header
hdul_model[0].header.update(dict(include_global_offset=include_global_offset,
mask_snr_threshold=mask_snr_threshold, r_excl_nfwhm=r_excl_nfwhm, q=q_clip,
include_stellar_psf_component=include_stellar_psf_component))
hdul_model.writeto(f_model, overwrite=True)
hdul_model.close()
if generate_plot:
# Intialize the matplotlib style.
load_plt_style(plot_style)
res = med - bg
res_psfsub = res - psf
low, upp = np.nanpercentile((res_psfsub)[optmask], [q_clip, 100.-q_clip])
res_inliers = np.where((res_psfsub >= low) & (res_psfsub <= upp) & optmask, res_psfsub, np.nan)
cmap = copy.copy(mpl.colormaps.get_cmap('RdBu_r'))
cmap.set_bad('white')
clim = np.array([-1, 1])*np.nanpercentile(np.abs(med), 90)
plot_mask = np.isfinite(med)
plot_ims = np.where(plot_mask, [med, bg, res_inliers, med-bg], np.nan)
fig,axes = plt.subplots(1, 4,figsize=(15, 3.5), sharex=True, sharey=True)
labels = ['Data', 'BG Model', 'Masked Residuals', 'Data (BG-subtracted)']
norm = mpl.colors.Normalize(*clim)
for ind,ax in enumerate(axes):
ax.imshow(plot_ims[ind], norm=norm, interpolation='None', origin='lower', cmap=cmap)
ax.set_title(labels[ind], pad=10)
fig.tight_layout(w_pad=1.00)
fig.colorbar(mpl.cm.ScalarMappable(norm, cmap), ax=axes, pad=0.015, label='[MJy / Sr]')
plt.savefig(output_dir+os.path.basename(os.path.normpath(f)).replace('.fits', '_background_model.pdf'), bbox_inches='tight')
plt.close(fig)
mask_in = db_tab['MASKFILE'][j]
mask = ut.read_msk(mask_in)
mask_out = ut.write_msk(mask_in, mask, f_out)
self.database.update_obs(key, j, f_out, mask_out)