Source code for spaceKLIP.imagetools

from __future__ import division

import matplotlib

# =============================================================================
# IMPORTS
# =============================================================================

# general imports
import os
import re
import logging
import sys
import json
import types
import copy
import lmfit
import numpy as np
from copy import deepcopy
from tqdm.auto import trange

# astropy imports
import pysiaf
import astropy.stats
import astropy.io.fits as pyfits
from astropy.io import fits

# plotting imports
import matplotlib.pyplot as plt
import matplotlib as mpl
from skimage.registration import phase_cross_correlation

# scipy imports
import scipy.ndimage
from scipy.ndimage import gaussian_filter, median_filter
from scipy.ndimage import shift as spline_shift
from scipy.optimize import leastsq, minimize
from scipy.interpolate import griddata

# webbpsf_ext imports
import webbpsf_ext
from webbpsf_ext import robust
from webbpsf_ext.coords import dist_image
from webbpsf_ext.webbpsf_ext_core import _transmission_map
from stpsf.constants import JWST_CIRCUMSCRIBED_DIAMETER

# spaceKLIP imports
from spaceKLIP import utils as ut
from spaceKLIP.psf import JWST_PSF
from spaceKLIP.xara import core
from spaceKLIP.utils import gaussian_kernel
from spaceKLIP.psf import get_offsetpsf
from spaceKLIP.pyklippipeline import get_pyklip_filepaths
from spaceKLIP.target_acq_tools import ta_analysis
from spaceKLIP.starphot import get_stellar_magnitudes, read_spec_file

# pyklip imports
import pyklip.fakes as fakes
from pyklip import parallelized
from pyklip.instruments.JWST import JWSTData

# jwst imports
import jwst.datamodels
from stdatamodels.jwst import datamodels
from jwst.datamodels import ModelContainer, ModelLibrary
from jwst.resample import resample_step

# Set up log.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# =============================================================================
# MAIN
# =============================================================================

# Load NIRCam true mask centers and filter-dependent shifts from Jarron.
path = 'resources/crpix_jarron.json'
path = os.path.join(os.path.split(os.path.abspath(__file__))[0], path)
file = open(path, 'r')
crpix_jarron = json.load(file)
file.close()
path = 'resources/filter_shifts_jarron.json'
path = os.path.join(os.path.split(os.path.abspath(__file__))[0], path)
file = open(path, 'r')
filter_shifts_jarron = json.load(file)
file.close()


[docs] class ImageTools(): """ The spaceKLIP image manipulation tools class. """ def __init__(self, database): """ Initialize the spaceKLIP image manipulation tools class. Parameters ---------- database : spaceKLIP.Database SpaceKLIP database on which the image manipulation steps shall be run. Returns ------- None. """ # Make an internal alias of the spaceKLIP database class. self.database = database pass def _get_output_dir(self, subdir): """Utility function to get full output dir path, and create it if needed""" # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) return output_dir def _iterate_function_over_files(self, types, file_transformation_function, restrict_to=None): """ Iterate some callable function over all files in a database. This is a repetitive pattern used in many of the image processing functions, so we abstract it here to reduce code repetition. The file transformation function should take one filename as an input, perform some transformation or image processing write out the file to some new path, and return the output filename. Any other arguments should be provided prior to passing in the function, for instance via functools.partial if necessary. """ # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # if we limit to only processing some concatenations, check whether this concatenation matches the pattern if (restrict_to is not None) and (restrict_to not in key): continue log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file. filename = self.database.obs[key]['FITSFILE'][j] # Only process files of the specified types. # (skip any files with types that are not in the list of types.) if self.database.obs[key]['TYPE'][j] in types: output_filename = file_transformation_function(filename) # Update spaceKLIP database. self.database.update_obs(key, j, output_filename)
[docs] def remove_frames(self, index=[0], types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='removed'): """ Remove individual frames from the data. Parameters ---------- index : int or list of int or dict of list of list of int, optional Indices (0-indexed) of the frames to be removed. If int, then only a single frame will be removed from each observation. If list of int, then multiple frames can be removed from each observation. If dict of list of list of int, then the dictionary keys must match the keys of the observations database, and the number of entries in the lists must match the number of observations in the corresponding concatenation. Then, a different list of int can be used for each individual observation to remove different frames. The default is [0]. types : list of str, optional List of data types from which the frames shall be removed. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'removed'. Returns ------- None. """ # Check input. if isinstance(index, int): index = [index] # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) nints = self.database.obs[key]['NINTS'][j] # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Remove frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame removal: ' + tail) try: index_temp = index[key][j] except: index_temp = index.copy() log.info(' --> Frame removal: removing frame(s) ' + str(index_temp)) data = np.delete(data, index_temp, axis=0) erro = np.delete(erro, index_temp, axis=0) pxdq = np.delete(pxdq, index_temp, axis=0) if align_shift is not None: align_shift = np.delete(align_shift, index_temp, axis=0) if center_shift is not None: center_shift = np.delete(center_shift, index_temp, axis=0) # There is only a single value for align_mask and center_mask, so we won't reshape # This should be improved in the future once the PSF mask is properly handled at the integration level. # if align_mask is not None: # align_mask = np.delete(align_mask, index_temp, axis=0) # if center_mask is not None: # center_mask = np.delete(center_mask, index_temp, axis=0) if maskoffs is not None: maskoffs = np.delete(maskoffs, index_temp, axis=0) nints = data.shape[0] # Write FITS file and PSF mask. head_pri['NINTS'] = nints fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, nints=nints) pass
[docs] def crop_frames(self, npix=1, types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='cropped'): """ Crop all frames. Parameters ---------- npix : int or list of four int, optional Number of pixels to be cropped from the frames. If int, the same number of pixels will be cropped on each side. If list of four int, a different number of pixels can be cropped from the [left, right, bottom, top] of the frames. The default is 1. types : list of str, optional List of data types from which the frames shall be cropped. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'cropped'. Returns ------- None. """ # Check input. if isinstance(npix, int): npix = [npix, npix, npix, npix] # left, right, bottom, top if len(npix) != 4: raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)') # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) crpix1 = self.database.obs[key]['CRPIX1'][j] crpix2 = self.database.obs[key]['CRPIX2'][j] starcenx = self.database.obs[key]['STARCENX'][j] starceny = self.database.obs[key]['STARCENY'][j] maskcenx = self.database.obs[key]['MASKCENX'][j] maskceny = self.database.obs[key]['MASKCENY'][j] # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Crop frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame cropping: ' + tail) sh = data.shape data = data[:, npix[2]:-npix[3], npix[0]:-npix[1]] erro = erro[:, npix[2]:-npix[3], npix[0]:-npix[1]] pxdq = pxdq[:, npix[2]:-npix[3], npix[0]:-npix[1]] if mask is not None: mask = mask[npix[2]:-npix[3], npix[0]:-npix[1]] crop_shiftx = npix[0] crop_shifty = npix[2] crpix1 -= crop_shiftx crpix2 -= crop_shifty starcenx -= crop_shiftx starceny -= crop_shifty maskcenx -= crop_shiftx maskceny -= crop_shifty log.info(' --> Frame cropping: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:])) # Write FITS file and PSF mask. head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny head_sci['CROP_SHIFTX'] = crop_shiftx # Store crop shift. head_sci['CROP_SHIFTY'] = crop_shifty fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny, crop_shiftx=crop_shiftx, crop_shifty=crop_shifty) pass
[docs] def pad_frames(self, npix=1, cval=np.nan, types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='padded'): """ Pad all frames. Parameters ---------- npix : int or list of four int, optional Number of pixels to be padded around the frames. If int, the same number of pixels will be padded on each side. If list of four int, a different number of pixels can be padded on the [left, right, bottom, top] of the frames. The default is 1. cval : float, optional Fill value for the padded pixels. The default is nan. types : list of str, optional List of data types from which the frames shall be padded. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'padded'. Returns ------- None. """ # Check input. if isinstance(npix, int): npix = [npix, npix, npix, npix] # left, right, bottom, top if len(npix) != 4: raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)') # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) crpix1 = self.database.obs[key]['CRPIX1'][j] crpix2 = self.database.obs[key]['CRPIX2'][j] starcenx = self.database.obs[key]['STARCENX'][j] starceny = self.database.obs[key]['STARCENY'][j] maskcenx = self.database.obs[key]['MASKCENX'][j] maskceny = self.database.obs[key]['MASKCENY'][j] # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Crop frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame padding: ' + tail) sh = data.shape data = np.pad(data, ((0, 0), (npix[2], npix[3]), (npix[0], npix[1])), mode='constant', constant_values=cval) erro = np.pad(erro, ((0, 0), (npix[2], npix[3]), (npix[0], npix[1])), mode='constant', constant_values=cval) pxdq = np.pad(pxdq, ((0, 0), (npix[2], npix[3]), (npix[0], npix[1])), mode='constant', constant_values=0) if mask is not None: mask = np.pad(mask, ((npix[2], npix[3]), (npix[0], npix[1])), mode='constant', constant_values=np.nan) crpix1 += npix[0] crpix2 += npix[2] starcenx += npix[0] starceny += npix[2] maskcenx += npix[0] maskceny += npix[2] log.info(' --> Frame padding: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:]) + ', fill value = %.2f' % cval) # Write FITS file and PSF mask. head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny) pass
[docs] def coadd_frames(self, nframes=None, types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='coadded'): """ Coadd frames. Parameters ---------- nframes : int, optional Number of frames to be coadded. Modulo frames will be removed. If None, will coadd all frames in an observation. The default is None. types : list of str, optional List of data types from which the frames shall be coadded. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'coadded'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # The starting value. nframes0 = nframes # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) nints = self.database.obs[key]['NINTS'][j] effinttm = self.database.obs[key]['EFFINTTM'][j] # If nframes is not provided, collapse everything. if nframes0 is None: nframes = nints # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Coadd frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame coadding: ' + tail) ncoadds = data.shape[0] // nframes data = np.nanmedian(data[:nframes * ncoadds].reshape((nframes, ncoadds, data.shape[-2], data.shape[-1])), axis=0) erro_reshape = erro[:nframes * ncoadds].reshape((nframes, ncoadds, erro.shape[-2], erro.shape[-1])) nsample = np.sum(np.logical_not(np.isnan(erro_reshape)), axis=0) erro = np.true_divide(np.sqrt(np.nansum(erro_reshape**2, axis=0)), nsample) pxdq_temp = pxdq[:nframes * ncoadds].reshape((nframes, ncoadds, pxdq.shape[-2], pxdq.shape[-1])) pxdq = pxdq_temp[0] for k in range(1, nframes): pxdq = np.bitwise_or(pxdq, pxdq_temp[k]) if align_shift is not None: align_shift = np.mean(align_shift[:nframes * ncoadds].reshape((nframes, ncoadds, align_shift.shape[-1])), axis=0) if center_shift is not None: center_shift = np.mean(center_shift[:nframes * ncoadds].reshape((nframes, ncoadds, center_shift.shape[-1])), axis=0) # There is only a single value for align_mask and center_mask, so we won't reshape # This should be improved in the future once the PSF mask is properly handled at the integration level. # if align_mask is not None: # align_mask = np.mean(align_mask[:nframes * ncoadds].reshape((nframes, ncoadds, align_mask.shape[-1])), axis=0) # if center_mask is not None: # center_mask = np.mean(center_mask[:nframes * ncoadds].reshape((nframes, ncoadds, center_mask.shape[-1])), axis=0) if maskoffs is not None: maskoffs = np.mean(maskoffs[:nframes * ncoadds].reshape((nframes, ncoadds, maskoffs.shape[-1])), axis=0) nints = data.shape[0] effinttm *= nframes log.info(' --> Frame coadding: %.0f coadd(s) of %.0f frame(s)' % (ncoadds, nframes)) # Write FITS file and PSF mask. head_pri['NINTS'] = nints head_pri['EFFINTTM'] = effinttm fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, nints=nints, effinttm=effinttm) pass
[docs] def subtract_median(self, types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'], method='border', sigma=3.0, borderwidth=32, subdir='medsub'): """ Subtract the median from each frame. Clip everything brighter than 5- sigma from the background before computing the median. Parameters ---------- types : list of str, optional List of data types for which the median shall be subtracted. The default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'medsub'. method : str, optional - 'robust' for a robust median after masking out bright stars, - 'sigma_clipped' for another version of robust median using astropy sigma_clipped_stats on the whole image, - 'border' for robust median on the outer border region only, to ignore the bright stellar PSF in the center, - 'simple' for a simple np.nanmedian. sigma : float, optional number of standard deviations to use for the clipping limit in sigma_clipped_stats, if the robust option is selected. borderwidth : int, optional number of pixels to use when defining the outer border region, if the border option is selected. Default is to use the outermost 32 pixels around all sides of the image. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) log.info(f'Median subtraction using method={method}') # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Subtract median. head, tail = os.path.split(fitsfile) log.info(' --> Median subtraction: ' + tail) data_temp = data.copy() # if self.database.obs[key]['TELESCOP'][j] == 'JWST' and self.database.obs[key]['INSTRUME'][j] == 'NIRCAM': # data_temp[pxdq != 0] = np.nan data_temp[pxdq & 1 == 1] = np.nan # else: # data_temp[pxdq & 1 == 1] = np.nan if method == 'robust': # Robust median, using a method by Jens bg_med = np.nanmedian(data_temp, axis=(1, 2), keepdims=True) bg_std = robust.medabsdev(data_temp, axis=(1, 2), keepdims=True) bg_ind = data_temp > (bg_med + 5. * bg_std) # clip bright PSFs for final calculation data_temp[bg_ind] = np.nan bg_median = np.nanmedian(data_temp, axis=(1, 2), keepdims=True) elif method == 'sigma_clipped': # Robust median using astropy.stats.sigma_clipped_stats if len(data.shape) == 2: mean, median, stddev = astropy.stats.sigma_clipped_stats(data_temp, sigma=sigma) elif len(data.shape) == 3: bg_median = np.zeros([data.shape[0], 1, 1]) for iint in range(data.shape[0]): mean_i, median_i, stddev_i = astropy.stats.sigma_clipped_stats(data[iint]) bg_median[iint] = median_i else: raise NotImplementedError("data must be 2d or 3d for this method") elif method == 'border': # Use only the outer border region of the image, near the edges of the FOV shape = data.shape if len(shape) == 2: # only one int y, x = np.indices(shape) bordermask = (x < borderwidth) | (x > shape[1] - borderwidth) | (y < borderwidth) | (y > shape[0] - borderwidth) mean, bg_median, stddev = astropy.stats.sigma_clipped_stats(data[bordermask]) elif len(shape) == 3: # perform robust stats on border region of each int y, x = np.indices(data.shape[1:]) bordermask = (x < borderwidth) | (x > shape[1] - borderwidth) | (y < borderwidth) | (y > shape[0] - borderwidth) bg_median = np.zeros([shape[0], 1, 1]) for iint in range(shape[0]): mean_i, median_i, stddev_i = astropy.stats.sigma_clipped_stats(data[iint][bordermask]) bg_median[iint] = median_i else: raise NotImplementedError("data must be 2d or 3d for this method") else: # Plain vanilla median of the image bg_median = np.nanmedian(data_temp, axis=(1, 2), keepdims=True) data -= bg_median log.info(' --> Median subtraction: mean of frame median = %.2f' % np.mean(bg_median)) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile)
[docs] def subtract_background_godoy(self, types=['SCI', 'REF'], subdir='bgsub'): """ Subtract the corresponding background observations from the SCI and REF data in the spaceKLIP database using a method developed by Nico Godoy. Parameters ---------- types : list of str File types to run the subtraction over. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bgsub'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # Load in bunch of stuff # Find science, reference, and background files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_sci_bg = np.where(self.database.obs[key]['TYPE'] == 'SCI_BG')[0] ww_ref_bg = np.where(self.database.obs[key]['TYPE'] == 'REF_BG')[0] # Loop over science and reference files for typ in types: if typ == 'SCI': ww, ww_bg = ww_sci, ww_sci_bg elif typ == 'REF': ww, ww_bg = ww_ref, ww_ref_bg # Gather background files. if len(ww_bg) == 0: raise UserWarning('Could not find any background files.') else: bg_data, bg_erro, bg_pxdq = [], [], [] for j in ww_bg: # Read background file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) # Compute median science background. bg_data += [data] bg_erro += [erro] bg_pxdq += [pxdq] bg_data, bg_erro, bg_pxdq = np.array(bg_data), np.array(bg_erro), np.array(bg_pxdq) # If multiple files, take the median. Otherwise, carry on. if bg_data.ndim == 4: bg_data = np.nanmedian(bg_data, axis=0) # Loop over individual files for j in ww: # Read FITS file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) # Subtract the background per frame head, tail = os.path.split(fitsfile) log.info(' --> Background subtraction: ' + tail) data -= bg_data # Loop over integrations data_bg_sub = np.empty_like(data) for k in range(data.shape[0]): # Subtract median of corresponding background frame from the frame bg_submed = bg_data[k, :, :] - np.nanmedian(bg_data[k, :, :]) # Do the same for the data (that's already background subtracted) data_submed = data[k, :, :] - np.nanmedian(data[k, :, :]) # Specify sections for initial guess # sect1 = data_submed[108:118,12:62]/bg_submed[108:118, 12:62] # sect2 = data_submed[93:106,152:207]/bg_submed[93:106, 152:207] sect1 = data_submed[112:118, 4:10]/bg_submed[112:118, 4:10] sect2 = data_submed[95:101, 207:212]/bg_submed[95:101, 207:212] # Reshape into 1d arrays and concatenate s1 = sect1.reshape(1, sect1.shape[0]*sect1.shape[1]) s2 = sect2.reshape(1, sect2.shape[0]*sect2.shape[1]) s12 = np.concatenate((s1[0, :], s2[0, :])) # Take median of concatenated array cte = np.nanmedian(s12) # Use filter to determine mask for estimating BG scaling # at the moment only have it working for F1140C. filt = self.database.obs[key]['FILTER'][j] if filt not in ['F1065C', 'F1140C', 'F1550C']: raise NotImplementedError('Godoy subtraction is only supported for MIRI FQPMs at this time!') else: bgmaskbase = os.path.split(os.path.abspath(__file__))[0] bgmaskfile = os.path.join(bgmaskbase, 'resources/miri_bg_masks/godoy_mask_{}.fits'.format(filt.lower())) # Run minimisation function, 'res' will tell us if there is any residual # background that wasn't removed in the initial attempt. I.e. do we # need to subtract a little bit more or less? res = minimize(ut.bg_minimize, x0=cte*100, args=(data_submed, bg_submed, bgmaskfile), method='L-BFGS-B', tol=1e-7) # Extract scale factor for the background from res scale = res.x/100 # Scale the background, and now subtract this correction from the original # background subtracted data data_improved_bgsub = data_submed - bg_submed*scale # Subtract median of residual frame to remove any residual median offset data_bg_sub[k] = data_improved_bgsub - np.nanmedian(data_improved_bgsub) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data_bg_sub, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile) pass
[docs] def subtract_background(self, nints_per_med=None, subdir='bgsub'): """ Median subtract the corresponding background observations from the SCI and REF data in the spaceKLIP database. Parameters ---------- nints_per_med : int Number of integrations per median. For example, if you have a target + background dataset with 20 integrations each and nints_per_med is set to 5, a median of every 5 background images will be subtracted from the corresponding 5 target images. The default is None (i.e. a median across all images). subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bgsub'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Store the nints_per_med parameter orig_nints_per_med = deepcopy(nints_per_med) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science, reference, and background files. ww = np.where((self.database.obs[key]['TYPE'] == 'SCI') | (self.database.obs[key]['TYPE'] == 'REF'))[0] ww_sci_bg = np.where(self.database.obs[key]['TYPE'] == 'SCI_BG')[0] ww_ref_bg = np.where(self.database.obs[key]['TYPE'] == 'REF_BG')[0] # Loop through science background files. if len(ww_sci_bg) != 0: sci_bg_data = [] sci_bg_erro = [] sci_bg_pxdq = [] for j in ww_sci_bg: # Read science background file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) # Determine split indices nints = data.shape[0] if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) split_inds = [x+1 for x in indxs if (x+1) % nints_per_med == 0 and x < (nints-nints_per_med)] # Compute median science background. sci_bg_data += [data] sci_bg_erro += [erro] sci_bg_pxdq += [pxdq] sci_bg_data = np.concatenate(sci_bg_data) sci_bg_erro = np.concatenate(sci_bg_erro) sci_bg_pxdq = np.concatenate(sci_bg_pxdq) sci_bg_data_split = np.array_split(sci_bg_data, split_inds, axis=0) sci_bg_erro_split = np.array_split(sci_bg_erro, split_inds, axis=0) sci_bg_pxdq_split = np.array_split(sci_bg_pxdq, split_inds, axis=0) for k in range(len(split_inds)+1): sci_bg_data_split[k] = np.nanmedian(sci_bg_data_split[k], axis=0) nsample = np.sum(np.logical_not(np.isnan(sci_bg_erro_split[k])), axis=0) sci_bg_erro_split[k] = np.true_divide(np.sqrt(np.nansum(sci_bg_erro_split[k]**2, axis=0)), nsample) sci_bg_pxdq_split[k] = np.sum(sci_bg_pxdq_split[k] & 1 == 1, axis=0) != 0 else: sci_bg_data = None # Loop through reference background files. if len(ww_ref_bg) != 0: ref_bg_data = [] ref_bg_erro = [] ref_bg_pxdq = [] for j in ww_ref_bg: # Read reference background file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) # Determine split indices nints = data.shape[0] if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) split_inds = [x+1 for x in indxs if (x+1) % nints_per_med == 0 and x < (nints-nints_per_med)] # Compute median reference background. ref_bg_data += [data] ref_bg_erro += [erro] ref_bg_pxdq += [pxdq] ref_bg_data = np.concatenate(ref_bg_data) ref_bg_erro = np.concatenate(ref_bg_erro) ref_bg_pxdq = np.concatenate(ref_bg_pxdq) ref_bg_data_split = np.array_split(ref_bg_data, split_inds, axis=0) ref_bg_erro_split = np.array_split(ref_bg_erro, split_inds, axis=0) ref_bg_pxdq_split = np.array_split(ref_bg_pxdq, split_inds, axis=0) for k in range(len(split_inds)+1): ref_bg_data_split[k] = np.nanmedian(ref_bg_data_split[k], axis=0) nsample = np.sum(np.logical_not(np.isnan(ref_bg_erro_split[k])), axis=0) ref_bg_erro_split[k] = np.true_divide(np.sqrt(np.nansum(ref_bg_erro_split[k]**2, axis=0)), nsample) ref_bg_pxdq_split[k] = np.sum(ref_bg_pxdq_split[k] & 1 == 1, axis=0) != 0 else: ref_bg_data = None # Check input. if sci_bg_data is None and ref_bg_data is None: raise UserWarning('Could not find any background files') # Loop through science and reference files. for j in ww: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) wwtype = self.database.obs[key]['TYPE'][j] if wwtype == 'SCI': sci = True else: sci = False # Determine split indices nints = data.shape[0] if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) split_inds = [x+1 for x in indxs if (x+1) % nints_per_med == 0 and x < (nints-nints_per_med)] # Subtract background. head, tail = os.path.split(fitsfile) log.info(' --> Background subtraction: ' + tail) data_split = np.array_split(data, split_inds, axis=0) erro_split = np.array_split(erro, split_inds, axis=0) pxdq_split = np.array_split(pxdq, split_inds, axis=0) # For each dataset, need to decide what to use as the background and subtract for k in range(len(split_inds)+1): if (sci and sci_bg_data is not None) or (not sci and ref_bg_data is None): if not sci and ref_bg_data is None: log.warning(' --> Could not find reference background, attempting to use science background') data_split[k] = data_split[k] - sci_bg_data_split[k] erro_split[k] = np.sqrt(erro_split[k]**2 + sci_bg_erro_split[k]**2) pxdq_split[k][np.logical_not(pxdq_split[k] & 1 == 1) & (sci_bg_pxdq_split[k] != 0)] += 1 elif (not sci and ref_bg_data is not None) or (sci and sci_bg_data is None): if sci and sci_bg_data is None: log.warning(' --> Could not find science background, attempting to use reference background') data_split[k] = data_split[k] - ref_bg_data_split[k] erro_split[k] = np.sqrt(erro_split[k]**2 + ref_bg_erro_split[k]**2) pxdq_split[k][np.logical_not(pxdq_split[k] & 1 == 1) & (ref_bg_pxdq_split[k] != 0)] += 1 data = np.concatenate(data_split, axis=0) erro = np.concatenate(erro_split, axis=0) pxdq = np.concatenate(pxdq_split, axis=0) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def find_bad_pixels(self, method='dqarr', set_dq_zero=True, dqarr_kwargs={}, sigclip_kwargs={}, custom_kwargs={}, timeints_kwargs={}, gradient_kwargs={}, types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'], subdir='bpfound', restrict_to=None): """ Identify bad pixels for cleaning Parameters ---------- method : str, optional Sequence of bad pixel cleaning methods to be run on the data. Different methods must be joined by a '+' sign without whitespace. Available methods are: - dqarr: uses DQ array to identify bad pixels - sigclip: use sigma clipping to identify additional bad pixels. - custom: use a custom bad pixel map The default is 'dqarr'. set_dq_zero : bool, optional Toggle to start a new empty DQ array, or built upon the existing array. The default is True dqarr_kwargs : dict, optional Keyword arguments for the 'dqarr' identification method. Available keywords are: The default is {}. sigclip_kwargs : dict, optional Keyword arguments for the 'sigclip' identification methods. Available keywords are: - sigma: float, optional Sigma clipping threshold. The default is 5. - shift_x : list of int, optional Pixels in x-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. The default is {}. custom_kwargs : dict, optional Keyword arguments for the 'custom' method. The dictionary keys must match the keys of the observations database and the dictionary content must be binary bad pixel maps (1 = bad, 0 = good) with the same shape as the corresponding data. The default is {}. The default is {}. types : list of str, optional List of data types for which bad pixels shall be identified. The default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bpfound'. Returns ------- None """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # if we limit to only processing some concatenations, # check whether this concatenation matches the pattern if (restrict_to is not None) and (restrict_to not in key): continue log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) if set_dq_zero: # set_dq_zero # Make copy of DQ array filled with zeros, i.e. all good pixels pxdq_temp = np.zeros_like(pxdq) else: # Make copy of DQ array pxdq_temp = pxdq.copy() # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Call bad pixel identification routines. method_split = method.split('+') if method_split[0] != 'dqarr' and not set_dq_zero: # If the first methond is not dqarr and you are not using a boolean mask for pxdq_temp, # convert pxdq_temp to a boolean mask or some of the next steps won't work. # This is just a place holder. We need to think about how this mask will look like pxdq_temp = (pxdq_temp > 0) for k in range(len(method_split)): head, tail = os.path.split(fitsfile) if method_split[k] == 'dqarr': log.info(' --> Method ' + method_split[k] + ': ' + tail) # Flag any pixels marked as DO_NOT_USE that aren't NONSCIENCE pxdq_temp = (np.isnan(data) | (pxdq_temp & 1 == 1)) \ & np.logical_not(pxdq_temp & 512 == 512) elif method_split[k] == 'sigclip': log.info(' --> Method ' + method_split[k] + ': ' + tail) self.find_bad_pixels_sigclip(data, erro, pxdq_temp, pxdq & 512 == 512, sigclip_kwargs) elif method_split[k] == 'custom': log.info(' --> Method ' + method_split[k] + ': ' + tail) if self.database.obs[key]['TYPE'][j] not in ['SCI_TA', 'REF_TA']: self.find_bad_pixels_custom(data, erro, pxdq_temp, key, custom_kwargs) else: log.info(' --> Method ' + method_split[k] + ': skipped because TA file') elif method_split[k] == 'timeints': self.find_bad_pixels_timeints(data, erro, pxdq_temp, key, timeints_kwargs) elif method_split[k] == 'gradient': self.find_bad_pixels_gradient(data, erro, pxdq_temp, key, gradient_kwargs) else: log.info(' --> Unknown method ' + method_split[k] + ': skipped') if set_dq_zero: # The new DQ will just be the pxdq_temp we've been modifying new_dq = pxdq_temp.astype(np.uint32) else: # The new DQ will be the original pxdq with added the flagged pixels from the pxdq_temp we've been modifying as do_not_use new_dq = np.bitwise_or(pxdq.copy(), pxdq_temp).astype(np.uint32) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def fix_bad_pixels(self, method='timemed+localmed+medfilt', sigclip_kwargs={}, custom_kwargs={}, timemed_kwargs={}, localmed_kwargs={}, medfilt_kwargs={}, types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'], subdir='bpcleaned', restrict_to=None): """ TO BE DEPRECATED BY FIND_BAD_PIXELS() AND CLEAN_BAD_PIXELS() Identify and fix bad pixels. Parameters ---------- method : str, optional Sequence of bad pixel identification and cleaning methods to be run on the data. Different methods must be joined by a '+' sign without whitespace. Available methods are: - sigclip: use sigma clipping to identify additional bad pixels. - custom: use a custom bad pixel map. - timemed: replace pixels which are only bad in some frames with their median value from the good frames. - localmed: replace bad pixels with the median value of their surrounding good pixels. - medfilt: replace bad pixels with an image plane median filter. The default is 'timemed+localmed+medfilt'. sigclip_kwargs : dict, optional Keyword arguments for the 'sigclip' method. Available keywords are: - sigclip : float, optional Sigma clipping threshold. The default is 5. - shift_x : list of int, optional Pixels in x-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. The default is {}. custom_kwargs : dict, optional Keyword arguments for the 'custom' method. The dictionary keys must match the keys of the observations database and the dictionary content must be binary bad pixel maps (1 = bad, 0 = good) with the same shape as the corresponding data. The default is {}. timemed_kwargs : dict, optional Keyword arguments for the 'timemed' method. Available keywords are: - n/a The default is {}. localmed_kwargs : dict, optional Keyword arguments for the 'localmed' method. Available keywords are: - shift_x : list of int, optional Pixels in x-direction from which the median shall be computed. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction from which the median shall be computed. The default is [-1, 0, 1]. The default is {}. medfilt_kwargs : dict, optional Keyword arguments for the 'medfilt' method. Available keywords are: - size : int, optional Kernel size of the median filter to be used. The default is 4. The default is {}. types : list of str, optional List of data types for which bad pixels shall be identified and fixed. The default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bpcleaned'. Returns ------- None. """ # log.info('--> WARNING! The fix_bad_pixels() routine is deprecated, the ..........') # log.info('--> WARNING! find_bad_pixels() and clean_bad_pixels() are preferred!!!!') # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # if we limit to only processing some concatenations, check whether this concatenation matches the pattern if (restrict_to is not None) and (restrict_to not in key): continue log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Call bad pixel cleaning routines. pxdq_temp = pxdq.copy() # if self.database.obs[key]['TELESCOP'][j] == 'JWST' and self.database.obs[key]['INSTRUME'][j] == 'NIRCAM': # pxdq_temp = (pxdq_temp != 0) & np.logical_not(pxdq_temp & 512 == 512) # else: pxdq_temp = (np.isnan(data) | (pxdq_temp & 1 == 1)) & np.logical_not(pxdq_temp & 512 == 512) method_split = method.split('+') for k in range(len(method_split)): head, tail = os.path.split(fitsfile) if method_split[k] == 'sigclip': log.info(' --> Method ' + method_split[k] + ': ' + tail) self.find_bad_pixels_sigclip(data, erro, pxdq_temp, pxdq & 512 == 512, sigclip_kwargs) elif method_split[k] == 'custom': log.info(' --> Method ' + method_split[k] + ': ' + tail) if self.database.obs[key]['TYPE'][j] not in ['SCI_TA', 'REF_TA']: self.find_bad_pixels_custom(data, erro, pxdq_temp, key, custom_kwargs) else: log.info(' --> Method ' + method_split[k] + ': skipped because TA file') elif method_split[k] == 'timemed': log.info(' --> Method ' + method_split[k] + ': ' + tail) self.fix_bad_pixels_timemed(data, erro, pxdq_temp, timemed_kwargs) elif method_split[k] == 'localmed': log.info(' --> Method ' + method_split[k] + ': ' + tail) self.fix_bad_pixels_localmed(data, erro, pxdq_temp, localmed_kwargs) elif method_split[k] == 'medfilt': log.info(' --> Method ' + method_split[k] + ': ' + tail) self.fix_bad_pixels_medfilt(data, erro, pxdq_temp, medfilt_kwargs) else: log.info(' --> Unknown method ' + method_split[k] + ': skipped') # if self.database.obs[key]['TELESCOP'][j] == 'JWST' and self.database.obs[key]['INSTRUME'][j] == 'NIRCAM': # pxdq[(pxdq != 0) & np.logical_not(pxdq & 512 == 512) & (pxdq_temp == 0)] = 0 # else: # pxdq[(pxdq & 1 == 1) & np.logical_not(pxdq & 512 == 512) & (pxdq_temp == 0)] = 0 # update the pixel DQ bit flags for the output files. # The pxdq variable here is effectively just the DO_NOT_USE flag, discarding other bits. # We want to make a new dq which retains the other bits as much as possible. # first, retain all the other bits (bits greater than 1), then add in the new/cleaned DO_NOT_USE bit do_not_use = jwst.datamodels.dqflags.pixel['DO_NOT_USE'] new_dq = np.bitwise_and(pxdq.copy(), np.invert(do_not_use)) # retain all other bits except the do_not_use bit new_dq = np.bitwise_or(new_dq, pxdq_temp) # add in the do_not_use bit from the cleaned version new_dq = new_dq.astype(np.uint32) # ensure correct output type for saving # (the bitwise steps otherwise return np.int64 which isn't FITS compatible) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def clean_bad_pixels(self, method='timemed+localmed+medfilt', timemed_kwargs={}, localmed_kwargs={}, medfilt_kwargs={}, interp2d_kwargs={}, types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'], subdir='bpcleaned', restrict_to=None): """ Clean bad pixels. Parameters ---------- method : str, optional Sequence of bad pixel cleaning methods to be run on the data. Different methods must be joined by a '+' sign without whitespace. Available methods are: - timemed: replace pixels which are only bad in some frames with their median value from the good frames. - localmed: replace bad pixels with the median value of their surrounding good pixels. - medfilt: replace bad pixels with an image plane median filter. - interp2d: replace bad pixels with an interpolation of neighbouring pixels. The default is 'timemed+localmed+medfilt'. timemed_kwargs : dict, optional Keyword arguments for the 'timemed' method. Available keywords are: - n/a The default is {}. localmed_kwargs: dict, optional Keyword arguments for the 'localmed' method. Available keywords are: - shift_x : list of int, optional Pixels in x-direction from which the median shall be computed. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction from which the median shall be computed. The default is [-1, 0, 1]. The default is {}. medfilt_kwargs : dict, optional Keyword arguments for the 'medfilt' method. Available keywords are: - size : int, optional Kernel size of the median filter to be used. The default is 4. The default is {}. interp2d_kwargs: dict, optional Keyword arguments for the 'interp2d' method. Available keywords are: - size : int, optional Kernel size of the median filter to be used. The default is 4. The default is {}. types : list of str, optional List of data types for which bad pixels shall be identified and fixed. The default is ['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bpcleaned'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # if we limit to only processing some concatenations, check whether this concatenation matches the pattern if (restrict_to is not None) and (restrict_to not in key): continue log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) fig = plt.figure() ax = plt.gca() ax.hist(data.flatten(), bins=int(np.sqrt(len(data.flatten()))), histtype='step', label='Pre Cleaning') # Make copy of DQ array pxdq_temp = pxdq.copy() # Don't want to clean anything that isn't bad or is a non-science pixel pxdq_temp = (np.isnan(data) | (pxdq_temp & 1 == 1)) & np.logical_not(pxdq_temp & 512 == 512) # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: method_split = method.split('+') spatial = ['localmed', 'medfilt', 'interp2d'] # If localmed and medfilt in cleaning, can't run both if len(set(method_split) & set(spatial)) > 1: log.info(' --> WARNING: Multiple spatial cleaning routines detected!') log.info(' --> The localmed/medfilt/interp2d methods clean data in a similar manner!') log.info(' --> medfilt and interp2d are redundant') log.info(' --> only the first method listed will affect the data') log.info(' --> localmed is partially redundant with other methods') log.info(' --> if run first, large clusters of bad pixels may not be fully cleaned.') # Loop over methods for k in range(len(method_split)): head, tail = os.path.split(fitsfile) log.info(' --> Method ' + method_split[k] + ': ' + tail) if method_split[k] == 'timemed': self.fix_bad_pixels_timemed(data, erro, pxdq_temp, timemed_kwargs) elif method_split[k] == 'localmed': self.fix_bad_pixels_localmed(data, erro, pxdq_temp, localmed_kwargs) elif method_split[k] == 'medfilt': self.fix_bad_pixels_medfilt(data, erro, pxdq_temp, medfilt_kwargs) elif method_split[k] == 'interp2d': self.fix_bad_pixels_interp2d(data, erro, pxdq_temp, interp2d_kwargs) else: log.info(' --> Unknown method ' + method_split[k] + ': skipped') # update the pixel DQ bit flags for the output files. # The pxdq variable here is effectively just the DO_NOT_USE flag, discarding other bits. # We want to make a new dq which retains the other bits as much as possible. # first, retain all the other bits (bits greater than 1), then add in the new/cleaned DO_NOT_USE bit do_not_use = jwst.datamodels.dqflags.pixel['DO_NOT_USE'] new_dq = np.bitwise_and(pxdq.copy(), np.invert(do_not_use)) # retain all other bits except the do_not_use bit new_dq = np.bitwise_or(new_dq, pxdq_temp) # add in the do_not_use bit from the cleaned version new_dq = new_dq.astype(np.uint32) # ensure correct output type for saving # (the bitwise steps otherwise return np.int64 which isn't FITS compatible) # Finish figure for this file ax.hist(data.flatten(), bins=int(np.sqrt(len(data.flatten()))), histtype='step', label='Post Cleaning') ax.legend() # ax.set_xscale('log') ax.set_yscale('log') ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12) ax.set_xlabel("Pixel Value", fontsize=14) ax.set_ylabel("Frequency", fontsize=12) ax.set_title(f"{os.path.basename(fitsfile)} \n Original vs. Cleaned Data", fontsize=16) output_file = os.path.join(output_dir, tail.replace('.fits', '_hist.png')) plt.savefig(output_file) plt.close(fig) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def find_bad_pixels_sigclip(self, data, erro, pxdq, NON_SCIENCE, sigclip_kwargs={}): """ Use an iterative sigma clipping algorithm to identify additional bad pixels in the data. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to include the newly identified bad pixels. NON_SCIENCE : 3D-array Input binary non-science pixel maps (1 = bad, 0 = good). Will not be modified by the routine. sigclip_kwargs : dict, optional Keyword arguments for the 'sigclip' method. Available keywords are: - sigma : float, optional Sigma clipping threshold. The default is 5. - neg_sigma : float, optional Sigma clipping threshold for negative outliers. The default is 1. - shift_x : list of int, optional Pixels in x-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction to which each pixel shall be compared to. The default is [-1, 0, 1]. The default is {}. Returns ------- None. """ # Check input. if 'sigma' not in sigclip_kwargs.keys(): sigclip_kwargs['sigma'] = 5. if 'neg_sigma' not in sigclip_kwargs.keys(): sigclip_kwargs['neg_sigma'] = 1. if 'shift_x' not in sigclip_kwargs.keys(): sigclip_kwargs['shift_x'] = [-1, 0, 1] if 'shift_y' not in sigclip_kwargs.keys(): sigclip_kwargs['shift_y'] = [-1, 0, 1] if 0 not in sigclip_kwargs['shift_x']: sigclip_kwargs['shift_x'] += [0] if 0 not in sigclip_kwargs['shift_y']: sigclip_kwargs['shift_y'] += [0] # Pad data. pad_left = np.abs(np.min(sigclip_kwargs['shift_x'])) pad_right = np.abs(np.max(sigclip_kwargs['shift_x'])) if pad_right == 0: right = None else: right = -pad_right pad_bottom = np.abs(np.min(sigclip_kwargs['shift_y'])) pad_top = np.abs(np.max(sigclip_kwargs['shift_y'])) if pad_top == 0: top = None else: top = -pad_top pad_vals = ((pad_bottom, pad_top), (pad_left, pad_right)) # Find bad pixels using median of neighbors. pxdq_orig = pxdq.copy() ww = pxdq != 0 data_temp = data.copy() data_temp[ww] = np.nan erro_temp = erro.copy() erro_temp[ww] = np.nan for i in range(ww.shape[0]): # Get median background and standard deviation. bg_med = np.nanmedian(data_temp[i]) bg_std = robust.medabsdev(data_temp[i]) bg_ind = data[i] < (bg_med + 10. * bg_std) # clip bright PSFs for final calculation bg_med = np.nanmedian(data_temp[i][bg_ind]) bg_std = robust.medabsdev(data_temp[i][bg_ind]) # Create initial mask of large negative values. ww[i] = ww[i] | (data[i] < bg_med - sigclip_kwargs['neg_sigma'] * bg_std) ww[i][NON_SCIENCE[i]] = 0 # Loop through max 10 iterations. for it in range(10): data_temp[i][ww[i]] = np.nan erro_temp[i][ww[i]] = np.nan # Shift data and calculate median and standard deviation of neighbours pad_data = np.pad(data_temp[i], pad_vals, mode='edge') pad_erro = np.pad(erro_temp[i], pad_vals, mode='edge') data_arr = [] erro_arr = [] for ix in sigclip_kwargs['shift_x']: for iy in sigclip_kwargs['shift_y']: if ix != 0 or iy != 0: data_arr += [np.roll(pad_data, (iy, ix), axis=(0, 1))] erro_arr += [np.roll(pad_erro, (iy, ix), axis=(0, 1))] data_arr = np.array(data_arr) data_arr_trim = data_arr[:, pad_bottom:top, pad_left:right] data_med = np.nanmedian(data_arr_trim, axis=0) diff = data[i] - data_med data_std = np.nanstd(data_arr_trim, axis=0) # # Do the same for the diff array we just made # pad_diff = np.pad(diff, pad_vals, mode='edge') # diff_arr = [] # for ix in sigclip_kwargs['shift_x']: # for iy in sigclip_kwargs['shift_y']: # if ix != 0 or iy != 0: # diff_arr += [np.roll(pad_diff, (iy, ix), axis=(0, 1))] # diff_arr = np.array(diff_arr) # diff_arr = diff_arr[:, pad_bottom:top, pad_left:right] # diff_med = np.nanmedian(diff_arr, axis=0) # doublediff = data[i] - data_med - diff_med # diff_std = np.nanstd(diff_arr, axis=0) # Find values N standard deviations above the mean of neighbors threshold = sigclip_kwargs['sigma'] * data_std mask_new = diff > threshold data_temp[i][mask_new] = np.nan # fig, ax = plt.subplots(1, 2) # ax[0].imshow(data_temp[i]) # ax[1].imshow(data_std) # plt.show() nmask_new = np.sum(mask_new & np.logical_not(ww[i])) # print('Iteration %.0f: %.0f bad pixels identified, %.0f are new' % (it + 1, np.sum(mask_new), nmask_new)) sys.stdout.write('\rFrame %.0f/%.0f, iteration %.0f' % (i + 1, ww.shape[0], it + 1)) sys.stdout.flush() if it > 0 and nmask_new == 0: break ww[i] = ww[i] | mask_new ww[i][NON_SCIENCE[i]] = 0 pxdq[i][ww[i]] = 1 print('') log.info(' --> Method sigclip: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) pass
[docs] def find_bad_pixels_timeints(self, data, erro, pxdq, NON_SCIENCE, timeints_kwargs={}): """ Identify bad pixels from temporal variations across integrations. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to include the newly identified bad pixels. NON_SCIENCE : 3D-array Input binary non-science pixel maps (1 = bad, 0 = good). Will not be modified by the routine. timeints_kwargs : dict, optional Keyword arguments for the 'timeints' method. Available keywords are: - sigma : float, optional Sigma clipping threshold. The default is 5. The default is {}. Returns ------- None. """ # Check input. if 'sigma' not in timeints_kwargs.keys(): timeints_kwargs['sigma'] = 10. pxdq_orig = pxdq.copy() ww = pxdq != 0 data_temp = data.copy() data_temp[ww] = np.nan # Find bad pixels across the cube med_ints = np.nanmedian(data_temp, axis=0) std_ints = np.nanstd(data_temp, axis=0) std2_ints = robust.medabsdev(data_temp, axis=0) diff = np.abs((data_temp - med_ints)) / std2_ints mask_new = diff > timeints_kwargs['sigma'] # data_temp[mask_new] = 9999 # plt.imshow(data_temp[1]) # plt.show() # plt.hist(diff.flatten(), # bins=int(np.sqrt(len(diff.flatten()))), # histtype='step', # label='Pre Cleaning') # plt.yscale('log') # plt.show() ww = ww | mask_new pxdq[ww] = 1 print('') log.info(' --> Method timeints: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) pass
[docs] def find_bad_pixels_gradient(self, data, erro, pxdq, key, gradient_kwargs={}): print('') log.info(' --> Warning!: This routine has not been thoroughly tested and requires further development') # Check input. if 'sigma' not in gradient_kwargs.keys(): gradient_kwargs['sigma'] = 0.5 if 'threshold' not in gradient_kwargs.keys(): gradient_kwargs['threshold'] = 0.05 if 'negative' not in gradient_kwargs.keys(): gradient_kwargs['negative'] = True sig = gradient_kwargs['sigma'] threshold = gradient_kwargs['threshold'] negative = gradient_kwargs['negative'] pxdq_orig = pxdq.copy() ww = pxdq != 0 data_temp = data.copy() data_temp[ww] = np.nan # Loop over the images for i in range(ww.shape[0]): image = data_temp[i] # remove nans x = np.arange(0, image.shape[1]) y = np.arange(0, image.shape[0]) xx, yy = np.meshgrid(x, y) # mask nans image = np.ma.masked_invalid(image) xvalid = xx[~image.mask] yvalid = yy[~image.mask] newimage = image[~image.mask] image_no_nans = griddata((xvalid, yvalid), newimage.ravel(), (xx, yy), method='linear') # get smooth image smimage=gaussian_filter(image_no_nans, sigma=sig) # get sharp image shimage = image_no_nans-smimage # get gradients image_to_gradient = shimage/smimage gr = np.gradient((image_to_gradient)) gr_dx = gr[1] gr_dy = gr[0] # pad gradient adding 1 extra pixel at beginning and end gr_dxp = np.pad(gr_dx, (1, 1)) gr_dyp = np.pad(gr_dy, (1, 1)) # identify bad pixels # positive bad_pixels = (gr_dxp[1:-1, 2:] < -threshold) & (gr_dxp[1:-1, :-2] > threshold) & (gr_dyp[2:, 1:-1] < -threshold) & (gr_dyp[:-2, 1:-1] > threshold) # negative if negative: bad_pixels_n = (gr_dxp[1:-1, 2:] > threshold) & (gr_dxp[1:-1, :-2] < -threshold) & (gr_dyp[2:, 1:-1] > threshold) & (gr_dyp[:-2, 1:-1] < threshold) bad_pixels = bad_pixels | bad_pixels_n image[bad_pixels] = np.nan # Flag DQ array ww[i] = ww[i] | bad_pixels pxdq[i][ww[i]] = 1 print('') log.info(' --> Method gradient: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) pass
[docs] def find_bad_pixels_custom(self, data, erro, pxdq, key, custom_kwargs={}): """ Use a custom bad pixel map to flag additional bad pixels in the data. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to include the newly flagged bad pixels. key : str Database key of the observation to be updated. custom_kwargs : dict, optional Keyword arguments for the 'custom' method. The dictionary keys must match the keys of the observations database and the dictionary content must be binary bad pixel maps (1 = bad, 0 = good) with the same shape as the corresponding data. The default is {}. Returns ------- None. """ # Find bad pixels using median of neighbors. pxdq_orig = pxdq.copy() pxdq_custom = custom_kwargs[key] != 0 if pxdq_custom.ndim == pxdq.ndim - 1: # Enable 3D bad pixel map to flag individual frames pxdq_custom = np.array([pxdq_custom] * pxdq.shape[0]) pxdq[pxdq_custom] = 1 log.info(' --> Method custom: flagged %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) pass
[docs] def fix_bad_pixels_timemed(self, data, erro, pxdq, timemed_kwargs={}): """ Replace pixels which are only bad in some frames with their median value from the good frames. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to exclude the fixed bad pixels. timemed_kwargs : dict, optional Keyword arguments for the 'timemed' method. Available keywords are: - n/a The default is {}. Returns ------- None. """ # Fix bad pixels using time median. ww = pxdq != 0 ww_all_bad = np.array([np.sum(ww, axis=0) == ww.shape[0]] * ww.shape[0]) ww_not_all_bad = ww & np.logical_not(ww_all_bad) log.info(' --> Method timemed: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww_not_all_bad), 100. * np.sum(ww_not_all_bad) / np.prod(ww_not_all_bad.shape))) data[ww_not_all_bad] = np.nan data[ww_not_all_bad] = np.array([np.nanmedian(data, axis=0)] * data.shape[0])[ww_not_all_bad] erro[ww_not_all_bad] = np.nan erro[ww_not_all_bad] = np.array([np.nanmedian(erro, axis=0)] * erro.shape[0])[ww_not_all_bad] pxdq[ww_not_all_bad] = 0 pass
[docs] def fix_bad_pixels_localmed(self, data, erro, pxdq, localmed_kwargs={}): """ Replace bad pixels with the median value of their surrounding good pixels. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to exclude the fixed bad pixels. localmed_kwargs : dict, optional Keyword arguments for the 'localmed' method. Available keywords are: - shift_x : list of int, optional Pixels in x-direction from which the median shall be computed. The default is [-1, 0, 1]. - shift_y : list of int, optional Pixels in y-direction from which the median shall be computed. The default is [-1, 0, 1]. The default is {}. Returns ------- None. """ # Check input. if 'shift_x' not in localmed_kwargs.keys(): localmed_kwargs['shift_x'] = [-1, 0, 1] if 'shift_y' not in localmed_kwargs.keys(): localmed_kwargs['shift_y'] = [-1, 0, 1] if 0 not in localmed_kwargs['shift_x']: localmed_kwargs['shift_x'] += [0] if 0 not in localmed_kwargs['shift_y']: localmed_kwargs['shift_y'] += [0] # Pad data. pad_left = np.abs(np.min(localmed_kwargs['shift_x'])) pad_right = np.abs(np.max(localmed_kwargs['shift_x'])) if pad_right == 0: right = None else: right = -pad_right pad_bottom = np.abs(np.min(localmed_kwargs['shift_y'])) pad_top = np.abs(np.max(localmed_kwargs['shift_y'])) if pad_top == 0: top = None else: top = -pad_top pad_vals = ((0, 0), (pad_bottom, pad_top), (pad_left, pad_right)) # Fix bad pixels using median of neighbors. ww = pxdq != 0 data_temp = data.copy() data_temp[ww] = np.nan pad_data = np.pad(data_temp, pad_vals, mode='edge') erro_temp = erro.copy() erro_temp[ww] = np.nan pad_erro = np.pad(erro_temp, pad_vals, mode='edge') for i in range(ww.shape[0]): data_arr = [] erro_arr = [] for ix in localmed_kwargs['shift_x']: for iy in localmed_kwargs['shift_y']: if ix != 0 or iy != 0: data_arr += [np.roll(pad_data[i], (iy, ix), axis=(0, 1))] erro_arr += [np.roll(pad_erro[i], (iy, ix), axis=(0, 1))] data_arr = np.array(data_arr) data_arr = data_arr[:, pad_bottom:top, pad_left:right] data_med = np.nanmedian(data_arr, axis=0) ww[i][np.isnan(data_med)] = 0 data[i][ww[i]] = data_med[ww[i]] erro_arr = np.array(erro_arr) erro_arr = erro_arr[:, pad_bottom:top, pad_left:right] erro_med = np.nanmedian(erro_arr, axis=0) erro[i][ww[i]] = erro_med[ww[i]] pxdq[i][ww[i]] = 0 log.info(' --> Method localmed: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape))) pass
[docs] def fix_bad_pixels_medfilt(self, data, erro, pxdq, medfilt_kwargs={}): """ Replace bad pixels with an image plane median filter. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to exclude the fixed bad pixels. medfilt_kwargs : dict, optional Keyword arguments for the 'medfilt' method. Available keywords are: - size : int, optional Kernel size of the median filter to be used. The default is 4. The default is {}. Returns ------- None. """ # Check input. if 'size' not in medfilt_kwargs.keys(): medfilt_kwargs['size'] = 4 # Fix bad pixels using median filter. ww = pxdq != 0 log.info(' --> Method medfilt: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape))) data_temp = data.copy() data_temp[np.isnan(data_temp)] = 0. erro_temp = erro.copy() erro_temp[np.isnan(erro_temp)] = 0. for i in range(ww.shape[0]): data[i][ww[i]] = median_filter(data_temp[i], **medfilt_kwargs)[ww[i]] erro[i][ww[i]] = median_filter(erro_temp[i], **medfilt_kwargs)[ww[i]] pxdq[i][ww[i]] = 0 pass
[docs] def fix_bad_pixels_interp2d(self, data, erro, pxdq, interp2d_kwargs={}): """ Replace bad pixels with an interpolation of neighbouring pixels. Parameters ---------- data : 3D-array Input images. erro : 3D-array Input image uncertainties. pxdq : 3D-array Input binary bad pixel maps (1 = bad, 0 = good). Will be updated by the routine to exclude the fixed bad pixels. interp2d_kwargs : dict, optional Keyword arguments for the 'interp2d' method. Available keywords are: - size : int, optional Kernel size of the median filter to be used. The default is 4. The default is {}. Returns ------- None. """ # Check input. if 'size' not in interp2d_kwargs.keys(): interp2d_kwargs['size'] = 5 # Fix bad pixels using interpolation of neighbors. ww = (pxdq != 0) & np.logical_not(pxdq & 512 == 512) log.info(' --> Method interp2d: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape))) # NaN pixels to be replaced with interpolation data_temp = data.copy() data_temp[np.where(np.isnan(data_temp))] = 0 data_temp[ww] = np.nan erro_temp = erro.copy() erro_temp[np.where(np.isnan(erro_temp))] = 0 erro_temp[ww] = np.nan rows, cols = data_temp[0].shape half_box = interp2d_kwargs['size'] // 2 for i in range(ww.shape[0]): for ri in range(rows): for ci in range(cols): if np.isnan(data_temp[i][ri, ci]): # Calculate the indices of the NxN box centered around the NaN pixel x_min = max(0, ci - half_box) x_max = min(cols, ci + half_box + 1) y_min = max(0, ri - half_box) y_max = min(rows, ri + half_box + 1) # Extract a NxN box within the valid range box = data_temp[i][y_min:y_max, x_min:x_max] ebox = erro_temp[i][y_min:y_max, x_min:x_max] # Extract coordinates and values from the box box_coords = np.array(np.where(~np.isnan(box))).T \ + np.array([[x_min, y_min]]) box_values = box[~np.isnan(box)] ebox_coords = np.array(np.where(~np.isnan(ebox))).T \ + np.array([[x_min, y_min]]) ebox_values = ebox[~np.isnan(ebox)] # Perform interpolation if there are valid values in the box if len(box_values) > interp2d_kwargs['size'] \ and len(ebox_values) > interp2d_kwargs['size']: # Extract x and y coordinates of valid values, same coords for # data and err x_coords = box_coords[:, 0] y_coords = box_coords[:, 1] ex_coords = ebox_coords[:, 0] ey_coords = ebox_coords[:, 1] # Perform interpolation of data data_interp = griddata((x_coords, y_coords), box_values, (ci, ri), method='linear', fill_value=np.nan) # Replace data pixel with interpolated value data[i][ri, ci] = data_interp # Perform interpolation of error err_interp = griddata((ex_coords, ey_coords), ebox_values, (ci, ri), method='linear', fill_value=np.nan) # Replace error pixel erro[i][ri, ci] = err_interp pxdq[i][ww[i]] = 0 pass
[docs] def replace_nans(self, cval=0., types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='nanreplaced'): """ Replace all nans in the data with a constant value. Parameters ---------- cval : float, optional Fill value for the nan pixels. The default is 0. types : list of str, optional List of data types for which nans shall be replaced. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'nanreplaced'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: # Replace nans. head, tail = os.path.split(fitsfile) log.info(' --> Nan replacement: ' + tail) ww = np.isnan(data) data[ww] = cval log.info(' --> Nan replacement: replaced %.0f nan pixel(s) with value ' % (np.sum(ww)) + str(cval) + ' -- %.2f%%' % (100. * np.sum(ww)/np.prod(ww.shape))) # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def blur_frames(self, fact='auto', types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='blurred'): """ Blur frames with a Gaussian filter. Parameters ---------- fact : 'auto' or 'fix23' or float or dict of list of float or None, optional FWHM (pix) of the Gaussian filter. If 'auto', will compute the FWHM automatically based on the Nyquist sampling criterion for discrete data, which is FWHM = lambda / 2.3D, where D = 5.2 m for NIRCam coronagraphy and D = 6.5 m otherwise. If 'fix23', will always blur the data with a Gaussian kernel of FWHM = 2.3 pix, so that even bad pixels cause no more Fourier ripples. If dict of list of float, then the dictionary keys must match the keys of the observations database, and the number of entries in the lists must match the number of observations in the corresponding concatenation. Then, a different FWHM can be used for each observation. If None, the corresponding observation will be skipped. The default is 'auto'. types : list of str, optional List of data types for which the frames shall be blurred. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'blurred'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. Nfitsfiles = len(self.database.obs[key]) for j in range(Nfitsfiles): # Read FITS file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Skip file types that are not in the list of types. fact_temp = None if self.database.obs[key]['TYPE'][j] in types: # Blur frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame blurring: ' + tail) try: fact_temp = fact[key][j] except: fact_temp = fact if self.database.obs[key]['TELESCOP'][j] == 'JWST': if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON']: diam = 5.2 else: diam = JWST_CIRCUMSCRIBED_DIAMETER else: raise UserWarning('Data originates from unknown telescope') if fact_temp is not None: if str(fact_temp) == 'auto': wave_min = self.database.obs[key]['CWAVEL'][j] - self.database.obs[key]['DWAVEL'][j] # micron fwhm_current = wave_min * 1e-6 / diam * 180. / np.pi * 3600. / self.database.obs[key]['PIXSCALE'][j] # pix fwhm_desired = 2.3 # pix; see, e.g., Pawley 2006 fwhm_desired *= 1.5 # go to 1.5 times the theoretically required bluring to safely avoid numerical ringing effects fact_temp = np.sqrt(fwhm_desired**2 - fwhm_current**2) fact_temp /= np.sqrt(8. * np.log(2.)) # fix from Marshall if str(fact_temp) == 'fix23': fwhm_current = 1. # pix fwhm_desired = 2.3 # pix; see, e.g., Pawley 2006 fact_temp = np.sqrt(fwhm_desired**2 - fwhm_current**2) fact_temp /= np.sqrt(8. * np.log(2.)) # fix from Marshall if np.isnan(fact_temp): fact_temp = None log.info(' --> Frame blurring: skipped') continue log.info(' --> Frame blurring: factor = %.3f' % fact_temp) for k in range(data.shape[0]): data[k] = gaussian_filter(data[k], fact_temp) erro[k] = gaussian_filter(erro[k], fact_temp) if mask is not None: mask = gaussian_filter(mask, fact_temp) else: log.info(' --> Frame blurring: skipped') # Write FITS file. if fact_temp is None: pass else: head_pri['BLURFWHM'] = fact_temp * np.sqrt(8. * np.log(2.)) # Factor to convert from sigma to FWHM fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. if fact_temp is None: self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=np.nan) else: self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=fact_temp * np.sqrt(8. * np.log(2.))) pass
[docs] def hpf(self, size='auto', types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='filtered'): """ Blur frames with a Gaussian filter. Parameters ---------- size : 'auto' or float or dict of list of float or None, optional FWHM (pix) of the Gaussian filter. If 'auto', will compute the FWHM automatically based on the Nyquist sampling criterion for discrete data, which is FWHM = lambda / 2.3D, where D = 5.2 m for NIRCam coronagraphy and D = 6.5 m otherwise. If dict of list of float, then the dictionary keys must match the keys of the observations database, and the number of entries in the lists must match the number of observations in the corresponding concatenation. Then, a different FWHM can be used for each observation. If None, the corresponding observation will be skipped. The default is 'auto'. types : list of str, optional List of data types for which the frames shall be blurred. The default is ['SCI', 'SCI_BG', 'REF', 'REF_BG']. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'blurred'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. Nfitsfiles = len(self.database.obs[key]) for j in range(Nfitsfiles): # Read FITS file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Skip file types that are not in the list of types. size_temp = None if self.database.obs[key]['TYPE'][j] in types: # High-pass filter frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame filtering: ' + tail) try: size_temp = size[key] except: size_temp = float(size) if size_temp is not None: log.info(' --> Frame filtering: HPF FWHM = %.2f pix' % size_temp) fourier_sigma_size = (data.shape[1] / size_temp) / (2. * np.sqrt(2. * np.log(2.))) data = parallelized.high_pass_filter_imgs(data, numthreads=None, filtersize=fourier_sigma_size) erro = parallelized.high_pass_filter_imgs(erro, numthreads=None, filtersize=fourier_sigma_size) else: log.info(' --> Frame filtering: skipped') # Write FITS file. if size_temp is None: pass else: head_pri['HPFSIZE'] = size_temp fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] def inject_companions(self, companions, starfile, spectral_type, output_dir, highpass=False, subdir='test', date='auto', kwargs={}): """ Function to inject synthetic PSFs into a set of frames loaded from a dataset, and save the new frames with the injected companion. Parameters ---------- companions : list of list of three float, optional List of companions to be injected. For each companion, there should be a three element list containing [RA offset (arcsec), Dec offset (arcsec), contrast]. raw_dataset : pyKLIP dataset A pyKLIP dataset which companions will be injected into and KLIP will be performed on. injection_psf : 2D-array The PSF of the companion to be injected. injection_seps : 1D-array List of separations to inject companions at (pixels). injection_pas : 1D-array List of position angles to inject companions at (degrees). injection_spacing : int, None Spacing between companions injected in a single image. If companions are too close then it can pollute the recovered flux. Set to 'None' to inject only one companion at a time (pixels). injection_fluxes : 1D-array Same size as injection_seps, units should correspond to the image units. This is the *peak* flux of the injection. true_companions : list of list of three float, optional List of real companions to be masked before computing the raw contrast. For each companion, there should be a three element list containing [RA offset (pixels), Dec offset (pixels), mask radius (pixels)]. The default is None. Returns ------- None """ # Check input. if not isinstance(companions[0], list): if len(companions) == 3: companions = [companions] for i in range(len(companions)): if len(companions[i]) != 3: raise UserWarning('There should be three elements for each companion in the companions list') Ncompanions = len(companions) for _, key in enumerate(self.database.obs.keys()): ww_type = list(self.database.obs[key]['TYPE']) list_of_injected = [] all_injected = False log.info('--> Concatenation ' + key) ####################################################################################################################### filepaths, psflib_filepaths = get_pyklip_filepaths(self.database, key) raw_dataset = JWSTData(filepaths, psflib_filepaths, center_keywords=['STARCENX', 'STARCENY']) for ww in range(len(ww_type)): # Read input files and store values that we just want to save in the output_dir fitsfile = self.database.obs[key]['FITSFILE'][ww] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][ww] mask = ut.read_msk(maskfile) crpix1 = self.database.obs[key]['CRPIX1'][ww] crpix2 = self.database.obs[key]['CRPIX2'][ww] maskcenx = self.database.obs[key]['MASKCENX'][ww] maskceny = self.database.obs[key]['MASKCENY'][ww] starcenx = self.database.obs[key]['STARCENX'][ww] starceny = self.database.obs[key]['STARCENY'][ww] head, tail = os.path.split(fitsfile) # Write FITS file and PSF mask. head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny # Inject only into SCI type data if ww_type[ww] == 'SCI': # Convert the host star brightness from vegamag to MJy. Use an # unocculted model PSF whose integrated flux is normalized to # one in order to obtain the theoretical peak count of the # star. filt = self.database.obs[key]['FILTER'][ww] # Get stellar magnitudes and filter zero points. mstar, fzero, fzero_si = get_stellar_magnitudes(starfile, spectral_type, self.database.obs[key]['INSTRUME'][ww], return_si=True, output_dir=output_dir, **kwargs) # vegamag, Jy, erg/cm^2/s/A # Compute the pixel area in steradian. pxsc_arcsec = self.database.obs[key]['PIXSCALE'][ww] # arcsec pxsc_rad = pxsc_arcsec / 3600. / 180. * np.pi # rad pxar = pxsc_rad ** 2 # sr # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Make a copy of the dataset dataset = copy.deepcopy(raw_dataset) apername = self.database.obs[key]['APERNAME'][ww] if 'planetfile' not in kwargs.keys() or kwargs['planetfile'] is None: if starfile is not None and starfile.endswith('.txt'): sed = read_spec_file(starfile) else: sed = None else: sed = read_spec_file(kwargs['planetfile']) ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] if date is not None: if date == 'auto': date = fits.getheader(self.database.obs[key]['FITSFILE'][ww_sci[0]], 0)['DATE-BEG'] offsetpsf_func = JWST_PSF(apername, filt, date=date, fov_pix=65, oversample=2, sp=sed, use_coeff=False) # Offset PSF that is not affected by the coronagraphic # mask, but only the Lyot stop. psf_no_coronmsk = offsetpsf_func.psf_off log.info('--> Injecting companions, writing FITS files and updating spaceKLIP database: ') # Loop over companions for i in trange(Ncompanions): # Initial guesses for the fit parameters. guess_dx = companions[i][0] / pxsc_arcsec # pix guess_dy = companions[i][1] / pxsc_arcsec # pix guess_sep = np.sqrt(guess_dx ** 2 + guess_dy ** 2) # pix guess_pa = np.rad2deg(np.arctan2(guess_dx, guess_dy)) # deg guess_flux = companions[i][2] # contrast roll_ref = self.database.obs[key]['ROLL_REF'][ww] # deg # Get shift between star and coronagraphic mask # position. If positive, the coronagraphic mask center # is to the left/bottom of the star position. _, _, _, _, _, _, _, maskoffs = ut.read_obs(self.database.obs[key]['FITSFILE'][ww]) # NIRCam. if maskoffs is not None: mask_xoff = -maskoffs[:, 0] # pix mask_yoff = -maskoffs[:, 1] # pix # Need to rotate by the roll angle (CCW) and flip # the x-axis so that positive RA is to the left. mask_raoff = -(mask_xoff * np.cos(np.deg2rad(roll_ref)) - mask_yoff * np.sin( np.deg2rad(roll_ref))) # pix mask_deoff = mask_xoff * np.sin(np.deg2rad(roll_ref)) + mask_yoff * np.cos( np.deg2rad(roll_ref)) # pix # Compute the true offset between the companion and # the coronagraphic mask center. sim_dx = guess_dx - mask_raoff # pix sim_dy = guess_dy - mask_deoff # pix sim_sep = np.sqrt(sim_dx ** 2 + sim_dy ** 2) * pxsc_arcsec # arcsec sim_pa = np.rad2deg(np.arctan2(sim_dx, sim_dy)) # deg # Take median of observation. Typically, each # dither position is a separate observation. sim_sep = np.median(sim_sep) sim_pa = np.median(sim_pa) # Otherwise. else: sim_sep = np.sqrt(guess_dx ** 2 + guess_dy ** 2) * pxsc_arcsec # arcsec sim_pa = np.rad2deg(np.arctan2(guess_dx, guess_dy)) # deg # Generate offset PSF for this roll angle. Do not add # the V3Yidl angle as it has already been added to the # roll angle by spaceKLIP. This is only for estimating # the coronagraphic mask throughput! offsetpsf_coronmsk = offsetpsf_func.gen_psf([sim_sep, sim_pa], mode='rth', PA_V3=roll_ref, do_shift=False, quick=True, addV3Yidl=False) # Coronagraphic mask throughput is not incorporated # into the flux calibration of the JWST pipeline so # that the companion flux from the detector pixels will # be underestimated. Therefore, we need to scale the # model offset PSF to account for the coronagraphic # mask throughput (it becomes fainter). Compute scale # factor by comparing a model PSF with and without # coronagraphic mask. scale_factor = np.sum(offsetpsf_coronmsk) / np.sum(psf_no_coronmsk) # Normalize model offset PSF to a total integrated flux # of 1 at infinity. Generates a new webbpsf model with # PSF normalization set to 'exit_pupil'. offsetpsf = offsetpsf_func.gen_psf([sim_sep, sim_pa], mode='rth', PA_V3=roll_ref, do_shift=False, quick=False, addV3Yidl=False, normalize='exit_pupil') # Normalize model offset PSF by the flux of the star. mcomp = mstar[filt] - 2.5*np.log10(guess_flux) offsetpsf *= fzero[filt] / 10 ** (mcomp / 2.5) / 1e6 / pxar # MJy/sr # Apply scale factor to incorporate the coronagraphic # mask througput. offsetpsf *= scale_factor # For Test only, we apply a gaussian kernel to the psf we want to inject to test if we are able # to recover it later when using Analysis.extract_companions if 'sigma_xy' in kwargs.keys(): if 'theta_degrees' not in kwargs.keys(): kwargs['theta_degrees'] = 0 sigma_xy = kwargs['sigma_xy'] theta_degrees = kwargs['theta_degrees'] kernel = gaussian_kernel(sigma_x=sigma_xy[0], sigma_y=sigma_xy[1], theta_degrees=theta_degrees, n=6) offsetpsf = scipy.ndimage.convolve(offsetpsf, kernel) # Injected PSF needs to be a 3D array that matches dataset inj_psf_3d = np.array([offsetpsf for k in range(dataset.input.shape[0])]) # Inject the PSF fakes.inject_planet(frames=dataset.input, centers=dataset.centers, inputflux=inj_psf_3d, astr_hdrs=dataset.wcs, radius=guess_sep, pa=guess_pa, stampsize=65) data = dataset.input # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, ww, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2, starcenx=starcenx, starceny=starceny,maskcenx=maskcenx, maskceny=maskceny)
[docs] def update_nircam_centers(self, force_siaf_center=False, force_db_center=False): """ Checks SIAF PRD version against FITS header PRD version and updates CRPIX if SIAF version is newer. Also accounts for filter-dependent distortion. Might not be required for simulated data. This step uses lookup tables of information derived from NIRCam commissioning activities CAR-30 and CAR-31, by J. Leisenring and J. Girard, and subsequent reanalyses using additional data from PSF SGD observations. Parameters ---------- force_siaf_center : bool, optional Force the use of the SIAF reference pixel position irrespective of versions. The default is False force_db_center : bool, optional Force the use of the database reference pixel position irrespective of versions. The default is False Returns ------- None. """ if force_siaf_center and force_db_center: raise UserWarning('Both force_siaf_center and force_db_center are set to True. Only one can be True.') if not force_db_center: siaf = pysiaf.Siaf('NIRCAM') # Use same RegEx as pysiaf to get sorted PRDS for later comparison. prds = [prd for i, prd in enumerate(pysiaf.prd_list) if bool(re.match(r"^[A-Z]-\d+", pysiaf.prd_list[i].split("PRDOPSSOC-")[1])) is False] # PRDs matching format: PRODOSSOC-### # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Loop through FITS files. for j in range(len(self.database.obs[key])): # Skip file types that are not NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] == 'NRC_CORON': # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Update current reference pixel position. head, tail = os.path.split(fitsfile) log.info(' --> Update NIRCam coronagraphy centers: ' + tail) # Get PRD version used for the current file. file_prd_ver = head_pri['PRD_VER'] # use SIAF for reference pixel positions if its PRD is # newer or if force_siaf_center unless force_db_center. if (not force_db_center) and ((np.searchsorted(prds, file_prd_ver) < np.searchsorted(prds, pysiaf.JWST_PRD_VERSION)) or force_siaf_center): log.info(' --> Update NIRCam coronagraphy centers: using MASKCEN from pysiaf') apsiaf = siaf[self.database.obs[key]['APERNAME'][j]] maskcenx = apsiaf.XSciRef maskceny = apsiaf.YSciRef else: log.info(' --> Update NIRCam coronagraphy centers: using MASKCEN from database') maskcenx = self.database.obs[key]['MASKCENX'][j] maskceny = self.database.obs[key]['MASKCENY'][j] # Get filter shift from Jarron. try: xshift_jarron, yshift_jarron = filter_shifts_jarron[self.database.obs[key]['FILTER'][j]] except KeyError: log.warning(' --> Update NIRCam coronagraphy centers: no filter shift found for ' + self.database.obs[key]['FILTER'][j]) xshift_jarron, yshift_jarron = 0., 0. xoff, yoff = xshift_jarron, yshift_jarron log.info(' --> Update NIRCam coronagraphy centers: old = (%.2f, %.2f), new = (%.2f, %.2f)' % (maskcenx, maskceny, maskcenx + xoff, maskceny + yoff)) maskcenx += xoff maskceny += yoff # Update spaceKLIP database. # Change also CRPIX to be the same as maskcen. self.database.update_obs(key, j, fitsfile, maskfile, crpix1=maskcenx, crpix2=maskceny, maskcenx=maskcenx, maskceny=maskceny) pass
[docs] def update_miri_offsets(self): """ Updates SCI frame X/Y offsets to zero for older MIRI coronagraphy datasets, and updates REF frame offsets accordingly in FITS headers and the database. """ # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_all = np.append(ww_sci, ww_ref) # Set the SCI offsets to 0. for j in ww_sci: xoffset_fix = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset_fix = self.database.obs[key]['YOFFSET'][j] # arcsec continue # Loop through FITS files. for j in ww_all: # Skip file types that are not MIRI coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] # Update MIRI mask offsets. head, tail = os.path.split(fitsfile) log.info(' --> Update MIRI coronagraphy offsets: ' + tail) # Apply offset fixes to SCI and REF data. xoffset_old = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset_old = self.database.obs[key]['YOFFSET'][j] # arcsec xoffset_new = xoffset_old - xoffset_fix # arcsec yoffset_new = yoffset_old - yoffset_fix # arcsec log.info(f' --> Update MIRI coronagraphy offsets: old = ({xoffset_old:.3g}, {yoffset_old:.3g}), new = ({xoffset_new:.3g}, {yoffset_new:.3g})') head_pri['XOFFSET'] = xoffset_new head_pri['YOFFSET'] = yoffset_new output_dir = os.path.dirname(self.database.obs[key]['FITSFILE'][j]) fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset_new, yoffset=yoffset_new) pass
[docs] def recenter_frames(self, method='fourier', subpix_first_sci_only=False, first_sci_only=True, spectral_type='G2V', shft_exp=1, kwargs={}, highpass=False, subdir='recentered'): """ Recenter frames so that the host star position is data.shape // 2. For NIRCam coronagraphy, use a WebbPSF model to determine the star position behind the coronagraphic mask for the first SCI frame. Then, shift all other SCI and REF frames by the same amount. For MIRI coronagraphy, do nothing. For all other data types, simply recenter the host star PSF. Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. subpix_first_sci_only : bool, optional By default, all frames will be recentered to subpixel precision. If 'subpix_first_sci_only' is True, then only the first SCI frame will be recentered to subpixel precision and all other SCI and REF frames will only be recentered to integer pixel precision by rolling the image. Can be helpful when working with poorly sampled data to avoid another interpolation step if the 'align_frames' routine is run subsequently. Only applicable to non-coronagraphic data. The default is False. first_sci_only : bool, optional Recenter all files and not just the first SCI file in each concate- nation. Only applicable to NIRCam coronagraphy. The default is True. spectral_type : str, optional Host star spectral type for the WebbPSF model used to determine the star position behind the coronagraphic mask. The default is 'G2V'. shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles). kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'recentered'. Returns ------- None. """ # DEPRECATION WARNING log.warning('This function is deprecated. Use `calculate_centers` and `shift_frames` instead.') # Update NIRCam coronagraphy centers, i.e., change SIAF CRPIX position # to true mask center determined by Jarron # self.update_nircam_centers() # Shall be run purposely by the user. # Update MIRI coronagraphy offsets, of older data. # self.update_miri_offsets() # Shall be run purposely by the user. # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_sci_ta = np.where(self.database.obs[key]['TYPE'] == 'SCI_TA')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_ref_ta = np.where(self.database.obs[key]['TYPE'] == 'REF_TA')[0] # Loop through FITS files. ww_all = np.append(ww_sci, ww_ref) ww_all = np.append(ww_all, ww_sci_ta) ww_all = np.append(ww_all, ww_ref_ta) shifts_all = [] for j in ww_all: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] (data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Recenter frames. Use different algorithms based on data type. head, tail = os.path.split(fitsfile) log.info(' --> Recenter frames: ' + tail) if np.sum(np.isnan(data)) != 0: raise UserWarning('Please replace nan pixels before attempting to recenter frames') shifts = [] # Shift between star position and image center (data.shape // 2) maskoffs_temp = [] # Shift between star and coronagraphic mask position # SCI and REF data. if j in ww_sci or j in ww_ref: # NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON']: for k in range(data.shape[0]): # For the first SCI frame, get the star position # and the shift between the star and coronagraphic # mask position. if (not first_sci_only or j == ww_sci[0]) and k == 0: xc, yc, xshift, yshift = self.find_nircam_centers(data0=data.copy(), key=key, j=j, shft_exp=shft_exp, spectral_type=spectral_type, date=head_pri['DATE-BEG'], output_dir=output_dir, highpass=highpass) # Apply the same shift to all SCI and REF frames. shifts += [np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])] maskoffs_temp += [np.array([xshift, yshift])] data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) if mask is not None: # mask = ut.imshift(mask, [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) mask = spline_shift(mask, [shifts[k][1], shifts[k][0]], order=0, mode='constant', cval=np.nanmedian(mask)) xoffset = self.database.obs[key]['XOFFSET'][j] - self.database.obs[key]['XOFFSET'][ww_sci[0]] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] - self.database.obs[key]['YOFFSET'][ww_sci[0]] # arcsec # Update star center starcenx = (data.shape[-1] - 1.) / 2. + 1. # 1-indexed starceny = (data.shape[-2] - 1.) / 2. + 1. # 1-indexed # Update mask center (using the shift of the first frame) maskcenx = starcenx - xshift # 1-indexed maskceny = starceny - yshift # 1-indexed # MIRI coronagraphy. elif self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']: log.warning(' --> Recenter frames: not implemented for MIRI coronagraphy, skipped') for k in range(data.shape[0]): # Do nothing. shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec # Star center and mask center stay the same starcenx = self.database.obs[key]['STARCENX'][j] # 1-indexed starceny = self.database.obs[key]['STARCENY'][j] # 1-indexed maskcenx = self.database.obs[key]['MASKCENX'][j] # 1-indexed maskceny = self.database.obs[key]['MASKCENY'][j] # 1-indexed # Other data types. else: for k in range(data.shape[0]): # Recenter SCI and REF frames to subpixel precision # using the 'BCEN' routine from XARA. # https://github.com/fmartinache/xara if subpix_first_sci_only == False or (j == ww_sci[0] and k == 0): pp = core.determine_origin(data[k], algo='BCEN') shifts += [np.array([-(pp[0] - data.shape[-1]//2), -(pp[1] - data.shape[-2]//2)])] maskoffs_temp += [np.array([0., 0.])] data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) else: shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] # Recenter SCI and REF frames to integer pixel # precision by rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) if ww_max != (data.shape[-2]//2, data.shape[-1]//2): dx, dy = data.shape[-1]//2 - ww_max[1], data.shape[-2]//2 - ww_max[0] shifts[-1][0] += dx shifts[-1][1] += dy data[k] = np.roll(np.roll(data[k], dx, axis=1), dy, axis=0) erro[k] = np.roll(np.roll(erro[k], dx, axis=1), dy, axis=0) xoffset = 0. # arcsec yoffset = 0. # arcsec # Update star center starcenx = data.shape[-1]//2 + 1 # 1-indexed starceny = data.shape[-2]//2 + 1 # 1-indexed maskcenx = None maskceny = None # TA data. if j in ww_sci_ta or j in ww_ref_ta: for k in range(data.shape[0]): # Center TA frames on the nearest pixel center. This # pixel center is not necessarily the image center, # which is why a subsequent integer pixel recentering # is required. p0 = np.array([0., 0.]) pp = minimize(ut.recenterlsq, p0, args=(data[k], method, kwargs))['x'] shifts += [np.array([pp[0], pp[1]])] maskoffs_temp += [np.array([0., 0.])] data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) # Recenter TA frames to integer pixel precision by # rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) if ww_max != (data.shape[-2]//2, data.shape[-1]//2): dx, dy = data.shape[-1]//2 - ww_max[1], data.shape[-2]//2 - ww_max[0] shifts[-1][0] += dx shifts[-1][1] += dy data[k] = np.roll(np.roll(data[k], dx, axis=1), dy, axis=0) erro[k] = np.roll(np.roll(erro[k], dx, axis=1), dy, axis=0) xoffset = 0. # arcsec yoffset = 0. # arcsec starcenx = data.shape[-1]//2 + 1 # 1-indexed starceny = data.shape[-2]//2 + 1 # 1-indexed maskcenx = None maskceny = None shifts = np.array(shifts) shifts_all += [shifts] maskoffs_temp = np.array(maskoffs_temp) if imshifts is not None: imshifts += shifts else: imshifts = shifts if maskoffs is not None: maskoffs += maskoffs_temp else: maskoffs = maskoffs_temp # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j] * 1000 # mas head, tail = os.path.split(self.database.obs[key]['FITSFILE'][j]) log.info(' --> Recenter frames: ' + tail) log.info(' --> Recenter frames: median required shift = %.2f mas' % np.median(dist)) # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset # arcsec head_pri['YOFFSET'] = yoffset # arcsec head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny if maskcenx is not None: head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny) pass
[docs] def calculate_centers(self, method='fourier', use_ta=False, subpix_first_sci_only=False, first_sci_only=True, spectral_type='G2V', shft_exp=1, kwargs={}, highpass=False, subdir='recentered'): """ Calculate shifts necessary to recenter frames so that the host star position is data.shape // 2. For NIRCam coronagraphy, use either a WebbPSF model or target acquisition data (if available) to determine the star's position behind the coronagraphic mask in the first SCI frame. For MIRI coronagraphy, use target acquisition data (if available) to determine the star's position behind the coronagraphic mask in the first SCI frame; otherwise, no shift is applied. Then, shift all other SCI and REF frames by the same amount. For all other data types, simply recenter the host star PSF. Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. use_ta: bool, optional Use target acquisition data to determine the star's position behind the coronagraphic mask in the first SCI frame? The default is False. subpix_first_sci_only : bool, optional By default, all frames will be recentered to subpixel precision. If 'subpix_first_sci_only' is True, then only the first SCI frame will be recentered to subpixel precision and all other SCI and REF frames will only be recentered to integer pixel precision by rolling the image. Can be helpful when working with poorly sampled data to avoid another interpolation step if the 'align_frames' routine is run subsequently. Only applicable to non-coronagraphic data. The default is False. first_sci_only : bool, optional Recenter all files and not just the first SCI file in each concate- nation. Only applicable to NIRCam/MIRI coronagraphy. The default is True. spectral_type : str, optional Host star spectral type for the WebbPSF model used to determine the star position behind the coronagraphic mask. The default is 'G2V'. shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align NIRCam bar/narrow data (or other data with weird speckles). kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'recentered'. Returns ------- None. """ # Update NIRCam coronagraphy centers, i.e., change SIAF CRPIX/MASKCEN position # to true mask center determined by Jarron. # self.update_nircam_centers() # Shall be run purposely by the user. # Update MIRI coronagraphy offsets, of older data. # self.update_miri_offsets() # Shall be run purposely by the user. # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_sci_ta = np.where(self.database.obs[key]['TYPE'] == 'SCI_TA')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_ref_ta = np.where(self.database.obs[key]['TYPE'] == 'REF_TA')[0] # Loop through FITS files. ww_all = np.append(ww_sci, ww_ref) ww_all = np.append(ww_all, ww_sci_ta) ww_all = np.append(ww_all, ww_ref_ta) shifts_all = [] for j in ww_all: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] (data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Find center of frames. Use different algorithms based on data type. head, tail = os.path.split(fitsfile) log.info(' --> Calculating centers: ' + tail) if np.sum(np.isnan(data)) != 0: raise UserWarning('Please replace nan pixels before attempting to calculate centers') shifts = [] # Shift between star position and image center (data.shape // 2). maskoffs_temp = [] # Offset between star and coronagraphic mask center. mask_shifts = [] # Shift applied to the PSF mask to match the recentered image. # SCI and REF data. if j in ww_sci or j in ww_ref: # NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON']: for k in range(data.shape[0]): # Mask center remains the same. maskcenx = self.database.obs[key]['MASKCENX'][j] # 1-indexed maskceny = self.database.obs[key]['MASKCENY'][j] # 1-indexed xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec # For the first SCI/REF frame, get the star position # and the shift between the star and coronagraphic # mask position. if (not first_sci_only or j == ww_sci[0]) and k == 0: # Use TA data to determine the star position behind the coronagraphic mask. if use_ta: xc, yc = ta_analysis(self.database.obs[key][j]['FITSFILE'], plot=True, verbose=True, output_dir=output_dir) # 0 indexed. # Calculate the star - mask offset. xshift, yshift = (xc - (maskcenx - 1), yc - (maskceny - 1)) # Use WebbPSF model to determine the star position behind the coronagraphic mask. else: xc, yc, xshift, yshift = self.find_nircam_centers(data0=data.copy(), key=key, j=j, shft_exp=shft_exp, spectral_type=spectral_type, date=head_pri['DATE-BEG'], output_dir=output_dir, highpass=highpass) # Apply the same shift to all SCI and REF frames and masks. shifts += [np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])] mask_shifts += [np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])] maskoffs_temp += [np.array([xshift, yshift])] # pixels if first_sci_only: xoffset -= self.database.obs[key]['XOFFSET'][ww_sci[0]] # arcsec yoffset -= self.database.obs[key]['YOFFSET'][ww_sci[0]] # arcsec log.info(' --> Calculate centers: adjusted XOFFSET/YOFFSET relative to first SCI frame.') else: # XOFFSET/YOFFSET remain the same. log.info(' --> Calculate centers: no adjustment made to XOFFSET/YOFFSET, all SCI/REF files recenetered individually.') # Set star center (image center - shift). starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1 # 1-indexed starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1 # 1-indexed # MIRI coronagraphy. elif self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']: for k in range(data.shape[0]): # Mask center remains the same. maskcenx = self.database.obs[key]['MASKCENX'][j] # 1-indexed maskceny = self.database.obs[key]['MASKCENY'][j] # 1-indexed xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec # For the first SCI/REF frame, get the star position # and the shift between the star and coronagraphic # mask position (if TA data avaliable else None). if (not first_sci_only or j == ww_sci[0]) and k == 0: # Use TA data to determine the star position behind the coronagraphic mask. if use_ta: xc, yc = ta_analysis(self.database.obs[key][j]['FITSFILE'], plot=True, verbose=True, output_dir=output_dir) # 0 indexed. # Adjust for any cropping done prior. xc -= self.database.obs[key]['CROP_SHIFTX'][j] # Adjust X for left cropping; 0 indexed. yc -= self.database.obs[key]['CROP_SHIFTY'][j] # Adjust Y for bottom cropping; 0 indexed. # Calculate the star - mask offset. xshift, yshift = (xc - (maskcenx - 1), yc - (maskceny - 1)) # Do nothing. else: log.warning(' --> Calculate centers: not implemented for MIRI coronagraphy unless use_ta=True, skipped') xc, yc = (data.shape[-1] - 1.) / 2., (data.shape[-2] - 1.) / 2. # Shift will be 0. xshift, yshift = 0., 0. # Apply the same shift to all SCI and REF frames and masks. shifts += [np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])] mask_shifts += [np.array([-(xc - (data.shape[-1] - 1.) / 2.), -(yc - (data.shape[-2] - 1.) / 2.)])] maskoffs_temp += [np.array([xshift, yshift])] if first_sci_only: xoffset -= self.database.obs[key]['XOFFSET'][ww_sci[0]] # arcsec yoffset -= self.database.obs[key]['YOFFSET'][ww_sci[0]] # arcsec log.info(' --> Calculate centers: adjusted XOFFSET/YOFFSET relative to first SCI frame.') else: # XOFFSET/YOFFSET remain the same. log.info(' --> Calculate centers: no adjustment made to XOFFSET/YOFFSET, all SCI/REF files measured individually.') # Star center. # If TA was used, shift from image center; otherwise, use provided header values. starcenx = ((data.shape[-1] - 1) / 2. - shifts[0][0] + 1) if use_ta else self.database.obs[key]['STARCENX'][j] # 1-indexed starceny = ((data.shape[-2] - 1) / 2. - shifts[0][1] + 1) if use_ta else self.database.obs[key]['STARCENY'][j] # 1-indexed # Other data types. else: for k in range(data.shape[0]): # Recenter SCI and REF frames to subpixel precision # using the 'BCEN' routine from XARA. # https://github.com/fmartinache/xara if subpix_first_sci_only == False or (j == ww_sci[0] and k == 0): pp = core.determine_origin(data[k], algo='BCEN') shifts += [np.array([-(pp[0] - data.shape[-1]//2), -(pp[1] - data.shape[-2]//2)])] mask_shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] else: shifts += [np.array([0., 0.])] mask_shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] xoffset = 0. # arcsec yoffset = 0. # arcsec # Update star center (image center - shift). starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1 # 1-indexed starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1 # 1-indexed maskcenx = None maskceny = None # TA data. if j in ww_sci_ta or j in ww_ref_ta: for k in range(data.shape[0]): # Center TA frames on the nearest pixel center. This # pixel center is not necessarily the image center, # which is why a subsequent integer pixel recentering # is required. p0 = np.array([0., 0.]) pp = minimize(ut.recenterlsq, p0, args=(data[k], method, kwargs))['x'] shifts += [np.array([pp[0], pp[1]])] mask_shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] xoffset = 0. # arcsec yoffset = 0. # arcsec # Update star center (image center - shift). starcenx = (data.shape[-1] - 1) / 2. - shifts[0][0] + 1 # 1-indexed starceny = (data.shape[-2] - 1) / 2. - shifts[0][1] + 1 # 1-indexed maskcenx = None maskceny = None shifts = np.array(shifts) shifts_all += [shifts] maskoffs_temp = np.array(maskoffs_temp) maskshifts_temp = np.median(mask_shifts, axis=0) if center_shift is not None: center_shift += shifts else: center_shift = shifts if center_mask is not None: center_mask += maskshifts_temp else: center_mask = maskshifts_temp if maskoffs is not None: maskoffs += maskoffs_temp else: maskoffs = maskoffs_temp # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j] * 1000 # mas head, tail = os.path.split(self.database.obs[key]['FITSFILE'][j]) log.info(' --> Calculate centers: median measured shift = %.2f mas' % np.median(dist)) # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset # arcsec head_pri['YOFFSET'] = yoffset # arcsec head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny if maskcenx is not None: head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny # Reading in CRPIX1/2 from database for updates from update_nircam_centers. head_sci['CRPIX1'] = self.database.obs[key]['CRPIX1'][j] head_sci['CRPIX2'] = self.database.obs[key]['CRPIX2'][j] fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny, center_shift=center_shift, center_mask=center_mask) pass
[docs] def resample_frames(self, subdir='resampled'): ''' Resample frames applying distortion correction to the data Parameters ---------- subdir : str, optional Name of the directory where the data products shall be saved. The default is 'recentered'. Returns ------- None. ''' def to_container(model, target=True): """Convert to a ModelContainer of ImageModels for each plane""" container = ModelContainer() if target: attr_list = [ 'data', 'dq', 'err', 'zeroframe', 'area', 'var_poisson', 'var_rnoise', 'var_flat' ] else: # model = model attr_list = [ 'data' ] for plane in range(model.data.shape[0]): image = datamodels.ImageModel() for attribute in attr_list: try: setattr(image, attribute, model.getarray_noinit(attribute)[plane]) except AttributeError: pass image.update(model) try: image.meta.wcs = model.meta.wcs except AttributeError: pass container.append(image) return container # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # Read FITS file and PSF mask. log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] if len(ww_sci) == 0: raise UserWarning('Could not find any science files') ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_all = np.append(ww_sci, ww_ref) # Make ASN file. log.info(' --> Resampling: ' + key) resample_step.ResampleStep.blendheaders = False for j in ww_all: target_file = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(target_file) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) with datamodels.open(target_file) as target: # #storing wcs information for mask resampling # wcs = target.meta.wcs # wcsinfo = target.meta.wcsinfo data_list = [] dq_list = [] err_list = [] for model in to_container(target): resample_input = ModelContainer() resample_input.append(model) # Call the resample step to combine all psf-subtracted target images # for compatibility with image3 pipeline use of ModelLibrary, # convert ModelContainer to ModelLibrary resample_library = ModelLibrary(resample_input, on_disk=False) # Output is a single datamodel result = resample_step.ResampleStep.call(resample_library) data_list.append(result.data) dq_list.append(result.dq) err_list.append(result.err) target.data = np.array([mask]) for model in to_container(target, target=False): resample_input = ModelContainer() resample_input.append(model) # Call the resample step to combine all psf-subtracted target images # for compatibility with image3 pipeline use of ModelLibrary, # convert ModelContainer to ModelLibrary resample_library = ModelLibrary(resample_input, on_disk=False) # Output is a single datamodel result = resample_step.ResampleStep.call(resample_library) mask = result.data # Write FITS file and PSF mask. fitsfile = ut.write_obs(target_file, output_dir, data_list, err_list, dq_list, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) pass
[docs] @plt.style.context('spaceKLIP.sk_style') def find_nircam_centers(self, data0, key, j, spectral_type='G2V', shft_exp=1, date=None, output_dir=None, fov_pix=65, oversample=2, use_coeff=False, highpass=False, save_figures=True): """ Find the star position behind the coronagraphic mask using a WebbPSF model. Parameters ---------- data0 : list List of frame for which the star position shall be determined. key : str Database key of the observation containing the first frame in data0. j : int Database index of the observation containing the first frame in data0. spectral_type : str, optional Host star spectral type for the WebbPSF model used to determine the star position behind the coronagraphic mask. The default is 'G2V'. shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. date : str, optional Observation date in the format 'YYYY-MM-DDTHH:MM:SS.MMM'. The default is None. output_dir : path, optional Path of the directory where the data products shall be saved. The default is None. oversample : int, optional Factor by which the WebbPSF model shall be oversampled. The default is 2. use_coeff : bool, optional Use pre-computed coefficients to generate the WebbPSF model. The default is False. save_figures : bool, optional Save the plots in a PDF? Returns ------- xc : float Star x-position (pix, 0-indexed). yc : float Star y-position (pix, 0-indexed). xshift : float X-shift between star and coronagraphic mask position (pix). yshift : float Y-shift between star and coronagraphic mask position (pix). """ # Generate host star spectrum. spectrum = webbpsf_ext.stellar_spectrum(spectral_type) # Get true mask center. maskcenx = self.database.obs[key]['MASKCENX'][j] - 1 # 0-indexed maskceny = self.database.obs[key]['MASKCENY'][j] - 1 # 0-indexed # Initialize JWST_PSF object. Use odd image size so that PSF is # centered in pixel center. log.info(' --> Recenter frames: generating WebbPSF image for absolute centering (this might take a while)') FILTER = self.database.obs[key]['FILTER'][j] APERNAME = self.database.obs[key]['APERNAME'][j] kwargs = { 'fov_pix': fov_pix, 'oversample': oversample, 'date': date, 'use_coeff': use_coeff, 'sp': spectrum } psf = JWST_PSF(APERNAME, FILTER, **kwargs) # Get SIAF reference pixel position. apsiaf = psf.inst_on.siaf_ap xsciref, ysciref = (apsiaf.XSciRef, apsiaf.YSciRef) # Generate model PSF. Apply offset between SIAF reference pixel # position and true mask center. xoff = (maskcenx + 1) - xsciref yoff = (maskceny + 1) - ysciref model_psf = psf.gen_psf_idl((0, 0), coord_frame='idl', return_oversample=False, quick=True) if not isinstance(highpass, bool): highpass = float(highpass) fourier_sigma_size = (model_psf.shape[0] / highpass) / (2. * np.sqrt(2. * np.log(2.))) model_psf = parallelized.high_pass_filter_imgs(np.array([model_psf]), numthreads=None, filtersize=fourier_sigma_size)[0] else: if highpass: raise NotImplementedError() if not np.isnan(self.database.obs[key]['BLURFWHM'][j]): gauss_sigma = self.database.obs[key]['BLURFWHM'][j] / np.sqrt(8. * np.log(2.)) model_psf = gaussian_filter(model_psf, gauss_sigma) shift_list = [] count = 0 for data in data0: # Get transmission mask. yi, xi = np.indices(data.shape) xidl, yidl = apsiaf.sci_to_idl(xi + 1 - xoff, yi + 1 - yoff) mask = psf.inst_on.gen_mask_transmission_map((xidl, yidl), 'idl') # Determine relative shift between data and model PSF. Iterate 3 times # to improve precision. xc, yc = (maskcenx, maskceny) # Start assuming that star is exactly at the center of the coronagraph. for i in range(3): # Crop data and transmission mask. datasub, xsub_indarr, ysub_indarr = ut.crop_image(image=data, xycen=(xc, yc), npix=fov_pix, return_indices=True) masksub = ut.crop_image(image=mask, xycen=(xc, yc), npix=fov_pix) if shft_exp == 1: img1 = datasub * masksub img2 = model_psf * masksub else: img1 = np.power(np.abs(datasub), shft_exp) * masksub img2 = np.power(np.abs(model_psf), shft_exp) * masksub # Determine relative shift between data and model PSF. shift, error, phasediff = phase_cross_correlation(img1, img2, upsample_factor=1000, normalization=None) yshift, xshift = shift # Update star position. xc = np.mean(xsub_indarr) + xshift yc = np.mean(ysub_indarr) + yshift xshift, yshift = (xc - maskcenx, yc - maskceny) shift_list.append([xshift, yshift]) log.info(' --> Recenter frames: star offset between frame %i and coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (count, xshift, yshift)) count += 1 median_xshift, median_yshift = np.median(np.array(shift_list), axis=0) std_xshift, std_yshift = np.std(np.array(shift_list), axis=0) log.info(' --> Recenter frames: median star offset from coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (median_xshift, median_yshift)) log.info(' --> Recenter frames: std for the star offset from coronagraph center (dx, dy) = (%.3f, %.3f) pix' % (std_xshift, std_yshift)) # Plot data, model PSF, and scene overview. if output_dir is not None: fig, ax = plt.subplots(1, 3, figsize=(3 * 6.4, 1 * 4.8)) ax[0].imshow(datasub, origin='lower', cmap='Reds') ax[0].contourf(masksub, levels=[0.00, 0.25, 0.50, 0.75], cmap='Greys_r', vmin=0., vmax=2., alpha=0.5) ax[0].set_title('1. SCI frame & transmission mask') ax[1].imshow(model_psf, origin='lower', cmap='Reds') ax[1].contourf(masksub, levels=[0.00, 0.25, 0.50, 0.75], cmap='Greys_r', vmin=0., vmax=2., alpha=0.5) ax[1].set_title('Model PSF & transmission mask') ax[2].scatter((xsciref), (ysciref), marker='+', color='black', label='SIAF reference point') ax[2].scatter((maskcenx + 1), (maskceny + 1), marker='x', color='skyblue', label='True mask center') ax[2].scatter((xc + 1), (yc + 1), marker='*', color='red', label='Computed star position') ax[2].set_aspect('equal') xlim = ax[2].get_xlim() ylim = ax[2].get_ylim() xrng = xlim[1]-xlim[0] yrng = ylim[1]-ylim[0] if xrng > yrng: ax[2].set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng) ax[2].set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng) else: ax[2].set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng) ax[2].set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng) ax[2].set_xlabel('x-position [pix]') ax[2].set_ylabel('y-position [pix]') ax[2].legend(loc='upper right', fontsize=12) ax[2].set_title('Scene overview (1-indexed)') plt.tight_layout() if save_figures: output_file = os.path.split(self.database.obs[key]['FITSFILE'][j])[1] output_file = output_file.replace('.fits', '.pdf') output_file = os.path.join(output_dir, output_file) plt.savefig(output_file) log.info(f" Plot saved in {output_file}") plt.show() plt.close(fig) # Return star position. return xc, yc, median_xshift, median_yshift
[docs] @plt.style.context('spaceKLIP.sk_style') def align_frames(self, method='fourier', align_algo='leastsq', mask_override=None, msk_shp=8, shft_exp=1, align_to_file=None, scale_prior=False, kwargs={}, subdir='aligned', save_figures=True): """ Align all SCI and REF frames to the first SCI frame. Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. align_algo : 'leastsq' or 'header' Algorithm to determine the alignment offsets. Default is 'leastsq', 'header' assumes perfect header offsets. mask_override : str, optional Mask some pixels when cross correlating for shifts msk_shp : int, optional Shape (height or radius, or [inner radius, outer radius]) for custom mask invoked by "mask_override" shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles) align_to_file : str, optional Path to FITS file to which all images shall be aligned. Needs to be a file with the same observational setup as all concatenations in the spaceKLIP database. Hence, this can only be applied to one observational setup at a time. The default is None. scale_prior : bool, optional If True, tries to find a better prior for the scale factor instead of simply using 1. The default is False. kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'aligned'. save_figures : bool, optional Save the plots in a PDF? Returns ------- None. """ #### DEPRECATION WARNING #### log.warning('This function is deprecated. Use `calculate_alignment` and `shift_frames` instead.') # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Useful masks for computing shifts: def create_annulus_mask(h, w, center=None, radius=None): if center is None: # use the middle of the image center = (int(w/2), int(h/2)) if radius is None: # use the smallest distance between the center and image walls radius = min(center[0], center[1], w-center[0], h-center[1]) Y, X = np.ogrid[:h, :w] dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2) mask = (dist_from_center <= radius[0]) | (dist_from_center >= radius[1]) return mask def create_circular_mask(h, w, center=None, radius=None): if center is None: # use the middle of the image center = (int(w/2), int(h/2)) if radius is None: # use the smallest distance between the center and image walls radius = min(center[0], center[1], w-center[0], h-center[1]) Y, X = np.ogrid[:h, :w] dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2) mask = dist_from_center <= radius return mask def create_rec_mask(h, w, center=None, z=None): if center is None: # use the middle of the image center = (int(w/2), int(h/2)) if z is None: z = h//4 mask = np.zeros((h, w), dtype=bool) mask[center[1]-z:center[1]+z, :] = True return mask # Loop through concatenations. database_temp = deepcopy(self.database.obs) for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] if len(ww_sci) == 0: raise UserWarning('Could not find any science files') ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_all = np.append(ww_sci, ww_ref) # Loop through FITS files. if align_to_file is not None: try: ref_image = pyfits.getdata(align_to_file, 'SCI') except: ref_image = pyfits.getdata(align_to_file, 0) if ref_image.ndim == 3: ref_image = np.nanmedian(ref_image, axis=0) shifts_all = [] for j in ww_all: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) if mask_override is not None: if mask_override == 'ann': mask_circ = create_annulus_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp) elif mask_override == 'circ': mask_circ = create_circular_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp) elif mask_override == 'rec': mask_circ = create_rec_mask(data[0].shape[0], data[0].shape[1], z=msk_shp) else: raise ValueError('There are `circ` and `rec` custom masks available') mask_temp = data[0].copy() mask_temp[~mask_circ] = 1 mask_temp[mask_circ] = 0 elif mask is None: mask_temp = np.ones_like(data[0]) else: mask_temp = mask.copy() # Align frames. head, tail = os.path.split(fitsfile) log.info(' --> Align frames: ' + tail) if np.sum(np.isnan(data)) != 0: raise UserWarning('Please replace nan pixels before attempting to align frames') shifts = [] maskcen = [] crpix = [] for k in range(data.shape[0]): # Take the first science frame as reference frame. if j == ww_sci[0] and k == 0: if align_to_file is None: ref_image = data[k].copy() pp = np.array([0., 0., 1.]) xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec starcenx = self.database.obs[key]['STARCENX'][j] # pixels starceny = self.database.obs[key]['STARCENY'][j] # pixels pxsc = self.database.obs[key]['PIXSCALE'][j] # arcsec # Align all other SCI and REF frames to the first science # frame. if align_to_file is not None or j != ww_sci[0] or k != 0: # Calculate shifts relative to first frame, work in pixels. xfirst = crpix1 + (xoffset/pxsc) xoff_curr_pix = self.database.obs[key]['XOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j] xcurrent = self.database.obs[key]['CRPIX1'][j] + xoff_curr_pix xshift = xfirst - xcurrent yfirst = crpix2 + (yoffset/pxsc) yoff_curr_pix = self.database.obs[key]['YOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j] ycurrent = self.database.obs[key]['CRPIX2'][j] + yoff_curr_pix yshift = yfirst - ycurrent # Get mask center and crpix to also register the shift in their locations. maskcenx = self.database.obs[key]['MASKCENX'][j] # pixels maskceny = self.database.obs[key]['MASKCENY'][j] # pixels crpix1 = self.database.obs[key]['CRPIX1'][j] # pixels crpix2 = self.database.obs[key]['CRPIX2'][j] # pixels maskcen += [np.array([maskcenx, maskceny])] crpix += [np.array([crpix1, crpix2])] if scale_prior: ww = mask < 0.5 sh = mask.shape bw = 100 ww[:bw, :] = 0. ww[:, :bw] = 0. ww[sh[0] - bw:, :] = 0. ww[:, sh[1] - bw:] = 0. # plt.imshow(ww, origin='lower') # plt.show() scale = np.nanmedian(np.true_divide(ref_image, data[k])[ww]) if shft_exp != 1: scale = np.power(np.abs(scale), shft_exp) p0 = np.array([xshift, yshift, scale]) else: p0 = np.array([xshift, yshift, 1.]) # Fix for weird numerical behaviour if shifts are small # but not exactly zero. if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3): p0 = np.array([0., 0., p0[-1]]) if align_algo == 'leastsq': if shft_exp != 1: args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs) else: args = (data[k], ref_image, mask_temp, method, kwargs) # Use header values to initiate least squares fit pp = leastsq(ut.alignlsq, p0, args=args)[0] elif align_algo == 'header': # Just assume the header values are correct pp = p0 # Append shifts to array and apply shift to image # using defined method. shifts += [np.array([pp[0], pp[1], pp[2]])] if align_to_file is not None or j != ww_sci[0] or k != 0: data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) shifts = np.array(shifts) maskcen = np.array(maskcen) crpix = np.array(crpix) if mask is not None: if align_to_file is not None or j != ww_sci[0]: temp = np.median(shifts, axis=0) mask = spline_shift(mask, [temp[1], temp[0]], order=0, mode='constant', cval=np.nanmedian(mask)) shifts_all += [shifts] if imshifts is not None: imshifts += shifts[:, :-1] else: imshifts = shifts[:, :-1] if maskoffs is not None: maskoffs -= shifts[:, :-1] else: maskoffs = -shifts[:, :-1] # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j]*1000 # mas if j == ww_sci[0]: dist = dist[1:] log.info(' --> Align frames: median required shift = %.2f mas' % np.median(dist)) if self.database.obs[key]['TELESCOP'][j] == 'JWST': ww = (dist < 1e-5) | (dist > 100.) else: ww = (dist < 1e-5) if np.sum(ww) != 0: if j == ww_sci[0]: ww = np.append(np.array([False]), ww) ww = np.where(ww == True)[0] if align_algo != 'header': log.warning(' --> The following frames might not be properly aligned: '+str(ww)) # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset # arcseconds head_pri['YOFFSET'] = yoffset # arcseconds head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny # Change mask location too: take first values of the shifts of the current fits file for now. maskcenx = maskcen[0, 0] + shifts[0, 0] maskceny = maskcen[0, 1] + shifts[0, 1] crpix1 = crpix[0, 0] + shifts[0, 0] crpix2 = crpix[0, 1] + shifts[0, 1] head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, starcenx=starcenx, starceny=starceny, maskcenx=maskcenx, maskceny=maskceny, crpix1=crpix1, crpix2=crpix2) # Plot science frame alignment. colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] fig = plt.figure(figsize=(6.4, 4.8)) ax = plt.gca() for index, j in enumerate(ww_sci): ax.scatter(shifts_all[index][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[index % len(colors)], marker='o', label='PA = %.0f deg' % self.database.obs[key]['ROLL_REF'][j]) ax.axhline(0., color='gray', lw=1, zorder=-1) # set zorder to ensure lines are drawn behind all the scatter points ax.axvline(0., color='gray', lw=1, zorder=-1) ax.set_aspect('equal') xlim = ax.get_xlim() ylim = ax.get_ylim() xrng = xlim[1]-xlim[0] yrng = ylim[1]-ylim[0] if xrng > yrng: ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng) ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng) else: ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng) ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng) ax.set_xlabel('x-shift [mas]') ax.set_ylabel('y-shift [mas]') ax.legend(loc='upper right') ax.set_title(f'Science frame alignment\nfor {self.database.obs[key]["TARGPROP"][ww_sci[0]]}, {self.database.obs[key]["FILTER"][ww_sci[0]]}') if save_figures: output_file = os.path.join(output_dir, key + '_align_sci.pdf') plt.savefig(output_file) log.info(f" Plot saved in {output_file}") plt.show() plt.close(fig) # Plot reference frame alignment. if len(ww_ref) > 0: colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] fig = plt.figure(figsize=(6.4, 4.8)) ax = plt.gca() seen = [] reps = [] syms = ['o', 'v', '^', '<', '>'] * (1 + len(ww_ref) // 5) add = len(ww_sci) for index, j in enumerate(ww_ref): this = '%.3f_%.3f' % (database_temp[key]['XOFFSET'][j], database_temp[key]['YOFFSET'][j]) if this not in seen: ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[len(seen) % len(colors)], marker=syms[0], label='dither %.0f' % (len(seen) + 1)) ax.hlines((-database_temp[key]['YOFFSET'][j] + yoffset) * 1000, (-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 - 4., (-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 + 4., color=colors[len(seen) % len(colors)], lw=1) ax.vlines((-database_temp[key]['XOFFSET'][j] + xoffset) * 1000, (-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 - 4., (-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 + 4., color=colors[len(seen) % len(colors)], lw=1) seen += [this] reps += [1] else: ww = np.where(np.array(seen) == this)[0][0] ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[ww % len(colors)], marker=syms[reps[ww]]) reps[ww] += 1 ax.set_aspect('equal') xlim = ax.get_xlim() ylim = ax.get_ylim() xrng = xlim[1]-xlim[0] yrng = ylim[1]-ylim[0] if xrng > yrng: ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng) ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng) else: ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng) ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng) ax.set_xlabel('x-shift [mas]') ax.set_ylabel('y-shift [mas]') ax.legend(loc='upper right', fontsize='small') ax.set_title(f'Reference frame alignment\n showing {len(ww_ref)} PSF refs for {self.database.obs[key]["FILTER"][ww_ref[0]]}') if save_figures: output_file = os.path.join(output_dir, key + '_align_ref.pdf') plt.savefig(output_file) log.info(f" Plot saved in {output_file}") plt.show() plt.close(fig)
[docs] @plt.style.context('spaceKLIP.sk_style') def calculate_alignment(self, method='fourier', align_algo='leastsq', mask_override=None, msk_shp=8, shft_exp=1, align_to_file=None, scale_prior=False, kwargs={}, subdir='aligned', save_figures=True): """ Calculate shifts necessary to align all SCI and REF frames to the first SCI frame. Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. align_algo : 'leastsq' or 'header' Algorithm to determine the alignment offsets. Default is 'leastsq', 'header' assumes perfect header offsets. mask_override : str, optional Mask some pixels when cross correlating for shifts msk_shp : int, optional Shape (height or radius, or [inner radius, outer radius]) for custom mask invoked by "mask_override" shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles) align_to_file : str, optional Path to FITS file to which all images shall be aligned. Needs to be a file with the same observational setup as all concatenations in the spaceKLIP database. Hence, this can only be applied to one observational setup at a time. The default is None. scale_prior : bool, optional If True, tries to find a better prior for the scale factor instead of simply using 1. The default is False. kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'aligned'. save_figures : bool, optional Save the plots in a PDF? Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Useful masks for computing shifts: def create_annulus_mask(h, w, center=None, radius=None): if center is None: # Use the middle of the image. center = (int(w/2), int(h/2)) if radius is None: # Use the smallest distance between the center and image walls. radius = min(center[0], center[1], w-center[0], h-center[1]) Y, X = np.ogrid[:h, :w] dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) mask = (dist_from_center <= radius[0]) | (dist_from_center >= radius[1]) return mask def create_circular_mask(h, w, center=None, radius=None): if center is None: # Use the middle of the image. center = (int(w/2), int(h/2)) if radius is None: # Use the smallest distance between the center and image walls. radius = min(center[0], center[1], w-center[0], h-center[1]) Y, X = np.ogrid[:h, :w] dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) mask = dist_from_center <= radius return mask def create_rec_mask(h, w, center=None, z=None): if center is None: # Use the middle of the image. center = (int(w/2), int(h/2)) if z is None: z = h//4 mask = np.zeros((h, w), dtype=bool) mask[center[1]-z:center[1]+z, :] = True return mask # Loop through concatenations. database_temp = deepcopy(self.database.obs) for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] if len(ww_sci) == 0: raise UserWarning('Could not find any science files') ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_all = np.append(ww_sci, ww_ref) # Loop through FITS files. if align_to_file is not None: try: ref_image = pyfits.getdata(align_to_file, 'SCI') except: ref_image = pyfits.getdata(align_to_file, 0) if ref_image.ndim == 3: ref_image = np.nanmedian(ref_image, axis=0) shifts_all = [] for j in ww_all: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) if mask_override is not None: if mask_override == 'ann': mask_circ = create_annulus_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp) elif mask_override == 'circ': mask_circ = create_circular_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp) elif mask_override == 'rec': mask_circ = create_rec_mask(data[0].shape[0], data[0].shape[1], z=msk_shp) else: raise ValueError('There are `circ` and `rec` custom masks available') mask_temp = data[0].copy() mask_temp[~mask_circ] = 1 mask_temp[mask_circ] = 0 elif mask is None: mask_temp = np.ones_like(data[0]) else: mask_temp = mask.copy() # Align frames. head, tail = os.path.split(fitsfile) log.info(' --> Align frames: ' + tail) if np.sum(np.isnan(data)) != 0: raise UserWarning('Please replace nan pixels before attempting to align frames') shifts = [] mask_shifts = [] for k in range(data.shape[0]): # Take the first science frame as reference frame. if j == ww_sci[0] and k == 0: if align_to_file is None: ref_image = data[k].copy() pp = np.array([0., 0., 1.]) xoffset = self.database.obs[key]['XOFFSET'][j] # arcsec yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec starcenx = self.database.obs[key]['STARCENX'][j] # pixels, 1 indexed starceny = self.database.obs[key]['STARCENY'][j] # pixels, 1 indexed pxsc = self.database.obs[key]['PIXSCALE'][j] # arcsec # Align all other SCI and REF frames to the first science # frame. if align_to_file is not None or j != ww_sci[0] or k != 0: # Calculate shifts relative to first frame, work in pixels. xfirst = starcenx + (xoffset/pxsc) xoff_curr_pix = self.database.obs[key]['XOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j] xcurrent = self.database.obs[key]['STARCENX'][j] + xoff_curr_pix xshift = xfirst - xcurrent yfirst = starceny + (yoffset/pxsc) yoff_curr_pix = self.database.obs[key]['YOFFSET'][j]/self.database.obs[key]['PIXSCALE'][j] ycurrent = self.database.obs[key]['STARCENY'][j] + yoff_curr_pix yshift = yfirst - ycurrent if scale_prior: ww = mask < 0.5 sh = mask.shape bw = 100 ww[:bw, :] = 0. ww[:, :bw] = 0. ww[sh[0] - bw:, :] = 0. ww[:, sh[1] - bw:] = 0. # plt.imshow(ww, origin='lower') # plt.show() scale = np.nanmedian(np.true_divide(ref_image, data[k])[ww]) if shft_exp != 1: scale = np.power(np.abs(scale), shft_exp) p0 = np.array([xshift, yshift, scale]) else: p0 = np.array([xshift, yshift, 1.]) # Fix for weird numerical behaviour if shifts are small # but not exactly zero. if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3): p0 = np.array([0., 0., p0[-1]]) if align_algo == 'leastsq': if shft_exp != 1: args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs) else: args = (data[k], ref_image, mask_temp, method, kwargs) # Use header values to initiate least squares fit. pp = leastsq(ut.alignlsq, p0, args=args)[0] elif align_algo == 'header': # Just assume the header values are correct. pp = p0 # Append shifts to array. if align_to_file is not None or j != ww_sci[0] or k != 0: shifts += [np.array([pp[0], pp[1], pp[2]])] else: shifts += [np.array([0, 0, 0])] # Do the same for the mask if mask is not None: if align_to_file is not None or j != ww_sci[0]: temp = np.median(shifts, axis=0) mask_shifts = np.array([temp[1], temp[0]]) else: mask_shifts = np.array([0, 0]) shifts = np.array(shifts) mask_shifts = np.array(mask_shifts) shifts_all += [shifts] if align_shift is not None: align_shift += shifts[:, :-1] else: align_shift = shifts[:, :-1] if align_mask is not None: align_mask += mask_shifts else: align_mask = mask_shifts if maskoffs is not None: maskoffs -= shifts[:, :-1] else: maskoffs = -shifts[:, :-1] # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j] * 1000 # mas if j == ww_sci[0]: dist = dist[1:] log.info(' --> Align frames: median required shift = %.2f mas' % np.median(dist)) if self.database.obs[key]['TELESCOP'][j] == 'JWST': ww = (dist < 1e-5) | (dist > 100.) else: ww = (dist < 1e-5) if np.sum(ww) != 0: if j == ww_sci[0]: ww = np.append(np.array([False]), ww) ww = np.where(ww == True)[0] if align_algo != 'header': log.warning(' --> The following frames might not be properly aligned: '+str(ww)) # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset # arcseconds head_pri['YOFFSET'] = yoffset # arcseconds # Use alignment shifts to update STARCENX/Y with accurate position. if not (j == ww_sci[0]): # Skip the very first science frame. median_x_shift = np.nanmedian(shifts[:, 0]) # Use median shift of all frames. median_y_shift = np.nanmedian(shifts[:, 1]) starcenx = self.database.obs[key]['STARCENX'][j] - median_x_shift starceny = self.database.obs[key]['STARCENY'][j] - median_y_shift head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. if not (j == ww_sci[0]): # Skip updating the STARCENX/Y for very first science frame. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, starcenx=starcenx, starceny=starceny, align_shift=align_shift, align_mask=align_mask) else: self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, align_shift=align_shift, align_mask=align_mask) # Plot science frame alignment. colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] fig = plt.figure(figsize=(6.4, 4.8)) ax = plt.gca() for index, j in enumerate(ww_sci): ax.scatter(shifts_all[index][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[index % len(colors)], marker='o', label='PA = %.0f deg' % self.database.obs[key]['ROLL_REF'][j]) ax.axhline(0., color='gray', lw=1, zorder=-1) # set zorder to ensure lines are drawn behind all the scatter points ax.axvline(0., color='gray', lw=1, zorder=-1) ax.set_aspect('equal') xlim = ax.get_xlim() ylim = ax.get_ylim() xrng = xlim[1]-xlim[0] yrng = ylim[1]-ylim[0] if xrng > yrng: ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng) ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng) else: ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng) ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng) ax.set_xlabel('x-shift [mas]') ax.set_ylabel('y-shift [mas]') ax.legend(loc='upper right') ax.set_title(f'Science frame alignment\nfor {self.database.obs[key]["TARGPROP"][ww_sci[0]]}, {self.database.obs[key]["FILTER"][ww_sci[0]]}') if save_figures: output_file = os.path.join(output_dir, key + '_align_sci.pdf') plt.savefig(output_file) log.info(f" Plot saved in {output_file}") plt.show() plt.close(fig) # Plot reference frame alignment. if len(ww_ref) > 0: colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] fig = plt.figure(figsize=(6.4, 4.8)) ax = plt.gca() seen = [] reps = [] syms = ['o', 'v', '^', '<', '>'] * (1 + len(ww_ref) // 5) add = len(ww_sci) for index, j in enumerate(ww_ref): this = '%.3f_%.3f' % (database_temp[key]['XOFFSET'][j], database_temp[key]['YOFFSET'][j]) if this not in seen: ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[len(seen) % len(colors)], marker=syms[0], label='dither %.0f' % (len(seen) + 1)) ax.hlines((-database_temp[key]['YOFFSET'][j] + yoffset) * 1000, (-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 - 4., (-database_temp[key]['XOFFSET'][j] + xoffset) * 1000 + 4., color=colors[len(seen) % len(colors)], lw=1) ax.vlines((-database_temp[key]['XOFFSET'][j] + xoffset) * 1000, (-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 - 4., (-database_temp[key]['YOFFSET'][j] + yoffset) * 1000 + 4., color=colors[len(seen) % len(colors)], lw=1) seen += [this] reps += [1] else: ww = np.where(np.array(seen) == this)[0][0] ax.scatter(shifts_all[index + add][:, 0] * self.database.obs[key]['PIXSCALE'][j] * 1000, shifts_all[index + add][:, 1] * self.database.obs[key]['PIXSCALE'][j] * 1000, s=5, color=colors[ww % len(colors)], marker=syms[reps[ww]]) reps[ww] += 1 ax.set_aspect('equal') xlim = ax.get_xlim() ylim = ax.get_ylim() xrng = xlim[1]-xlim[0] yrng = ylim[1]-ylim[0] if xrng > yrng: ax.set_xlim(np.mean(xlim) - xrng, np.mean(xlim) + xrng) ax.set_ylim(np.mean(ylim) - xrng, np.mean(ylim) + xrng) else: ax.set_xlim(np.mean(xlim) - yrng, np.mean(xlim) + yrng) ax.set_ylim(np.mean(ylim) - yrng, np.mean(ylim) + yrng) ax.set_xlabel('x-shift [mas]') ax.set_ylabel('y-shift [mas]') ax.legend(loc='upper right', fontsize='small') ax.set_title(f'Reference frame alignment\n showing {len(ww_ref)} PSF refs for {self.database.obs[key]["FILTER"][ww_ref[0]]}') if save_figures: output_file = os.path.join(output_dir, key + '_align_ref.pdf') plt.savefig(output_file) log.info(f" Plot saved in {output_file}") plt.show() plt.close(fig)
[docs] def shift_frames(self, method='fourier', kwargs={}, subdir='shifted'): """ Calculate shifts necessary to align all SCI and REF frames to the first SCI frame. Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. subdir : str, optional Name of the directory where the data products shall be saved. The default is 'shifted'. Returns ------- None. """ # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) # Loop through concatenations. database_temp = deepcopy(self.database.obs) for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_sci_ta = np.where(self.database.obs[key]['TYPE'] == 'SCI_TA')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_ref_ta = np.where(self.database.obs[key]['TYPE'] == 'REF_TA')[0] # Loop through FITS files. ww_all = np.append(ww_sci, ww_ref) ww_all = np.append(ww_all, ww_sci_ta) ww_all = np.append(ww_all, ww_ref_ta) # Load in previously calculated image shifts. align_shift_star = self.database.obs[key]['ALIGN_SHIFT'] align_shift_mask = self.database.obs[key]['ALIGN_MASK'] center_shift_star = self.database.obs[key]['CENTER_SHIFT'] center_shift_mask = self.database.obs[key]['CENTER_MASK'] # Need to determine largest potential shift for padding purposes shiftpad = ut.estimate_padding_for_shift(align_shift_star, center_shift_star) log.info(f' --> Estimated padding for shifting: {shiftpad} pixels') shifts_all = [] for j in ww_all: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) # Shift frames. head, tail = os.path.split(fitsfile) log.info(' --> Shift frames: ' + tail) # TO IMPROVE DQ DATA # The DQ also needs to be padded so we don't have array mismatches further down the line. # However, it's not really meaningful to shift the DQ array by sub-pixel amounts. # For now we will pad the array with zeros and assume that DQ is not important following shift_frames. pxdq = np.pad(pxdq, pad_width=((0, 0), (shiftpad, shiftpad), (shiftpad, shiftpad)), mode='constant', constant_values=0) # SCI and REF data. data_shift, erro_shift = [], [] if j in ww_sci or j in ww_ref: maskcenx = self.database.obs[key]['MASKCENX'][j] # 1 indexed maskceny = self.database.obs[key]['MASKCENY'][j] # 1 indexed # NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON']: shifts = [] # Store recenter + align shifts. for k in range(data.shape[0]): xshift = (align_shift_star[j][k][0] if not isinstance(align_shift_star[j], types.BuiltinFunctionType) else 0.0) + \ (center_shift_star[j][k][0] if not isinstance(center_shift_star[j], types.BuiltinFunctionType) else 0.0) yshift = (align_shift_star[j][k][1] if not isinstance(align_shift_star[j], types.BuiltinFunctionType) else 0.0) + \ (center_shift_star[j][k][1] if not isinstance(center_shift_star[j], types.BuiltinFunctionType) else 0.0) shifts += [np.array([xshift, yshift])] # Recenter and align the SCI and REF frames. data_shift += [ut.imshift(data[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs)] erro_shift += [ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs)] data = np.array(data_shift) erro = np.array(erro_shift) if mask is not None: mask_shift = center_shift_mask[j] + align_shift_mask[j] mask = ut.imshift(mask, [mask_shift[0], mask_shift[1]], method='spline', pad_amount=shiftpad, kwargs={'mode':'constant'}) # Update mask center. maskcenx = self.database.obs[key]['MASKCENX'][j] + shifts[0][0] + shiftpad maskceny = self.database.obs[key]['MASKCENY'][j] + shifts[0][1] + shiftpad # Update star center. starcenx = self.database.obs[key]['STARCENX'][j] + shifts[0][0] + shiftpad starceny = self.database.obs[key]['STARCENY'][j] + shifts[0][1] + shiftpad # Update CRPIX values. crpix1 = self.database.obs[key]['CRPIX1'][j] + shifts[0][0] + shiftpad crpix2 = self.database.obs[key]['CRPIX2'][j] + shifts[0][1] + shiftpad # MIRI coronagraphy. elif self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']: shifts = [] # Store recenter + align shifts. for k in range(data.shape[0]): xshift = (align_shift_star[j][k][0] if not isinstance(align_shift_star[j], types.BuiltinFunctionType) else 0.0) + \ (center_shift_star[j][k][0] if not isinstance(center_shift_star[j], types.BuiltinFunctionType) else 0.0) yshift = (align_shift_star[j][k][1] if not isinstance(align_shift_star[j], types.BuiltinFunctionType) else 0.0) + \ (center_shift_star[j][k][1] if not isinstance(center_shift_star[j], types.BuiltinFunctionType) else 0.0) shifts += [np.array([xshift, yshift])] data_shift += [ut.imshift(data[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs)] erro_shift += [ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs)] data = np.array(data_shift) erro = np.array(erro_shift) if mask is not None: mask_shift = center_shift_mask[j] + align_shift_mask[j] mask = ut.imshift(mask, [mask_shift[0], mask_shift[1]], method='spline', pad_amount=shiftpad, kwargs={'mode':'constant'}) # Update mask center. maskcenx = self.database.obs[key]['MASKCENX'][j] + shifts[0][0] + shiftpad maskceny = self.database.obs[key]['MASKCENY'][j] + shifts[0][1] + shiftpad # Update star center. starcenx = self.database.obs[key]['STARCENX'][j] + shifts[0][0] + shiftpad starceny = self.database.obs[key]['STARCENY'][j] + shifts[0][1] + shiftpad # Update CRPIX values. crpix1 = self.database.obs[key]['CRPIX1'][j] + shifts[0][0] + shiftpad crpix2 = self.database.obs[key]['CRPIX2'][j] + shifts[0][1] + shiftpad # Other data types. else: shifts = [] for k in range(data.shape[0]): xshift = align_shift_star[j][k][0] + center_shift_star[j][k][0] yshift = align_shift_star[j][k][1] + center_shift_star[j][k][1] shifts += [np.array([xshift, yshift])] this_data = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs) this_erro = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs) # Recenter SCI and REF frames to integer pixel # precision by rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) if ww_max != (data.shape[-2] // 2, data.shape[-1] // 2): dx, dy = data.shape[-1] // 2 - ww_max[1], data.shape[-2] // 2 - ww_max[0] shifts[-1][0] += dx shifts[-1][1] += dy data_shift += [np.roll(np.roll(this_data, dx, axis=1), dy, axis=0)] erro_shift += [np.roll(np.roll(this_erro, dx, axis=1), dy, axis=0)] data = np.array(data_shift) erro = np.array(erro_shift) # Update star center. starcenx = self.database.obs[key]['STARCENX'][j] + shifts[0][0] + shiftpad starceny = self.database.obs[key]['STARCENY'][j] + shifts[0][1] + shiftpad # Update CRPIX values. crpix1 = self.database.obs[key]['CRPIX1'][j] + shifts[0][0] + shiftpad crpix2 = self.database.obs[key]['CRPIX2'][j] + shifts[0][1] + shiftpad # TA data. if j in ww_sci_ta or j in ww_ref_ta: shifts = [] for k in range(data.shape[0]): xshift = align_shift_star[j][k][0] + center_shift_star[j][k][0] yshift = align_shift_star[j][k][1] + center_shift_star[j][k][1] shifts += [np.array([xshift, yshift])] this_data = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs) this_erro = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], pad_amount=shiftpad, method=method, kwargs=kwargs) # Recenter TA frames to integer pixel precision by # rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) if ww_max != (data.shape[-2] // 2, data.shape[-1] // 2): dx, dy = data.shape[-1] // 2 - ww_max[1], data.shape[-2] // 2 - ww_max[0] shifts[-1][0] += dx shifts[-1][1] += dy data_shift += [np.roll(np.roll(this_data, dx, axis=1), dy, axis=0)] erro_shift += [np.roll(np.roll(this_erro, dx, axis=1), dy, axis=0)] data = np.array(data_shift) erro = np.array(erro_shift) # Update star center. starcenx = self.database.obs[key]['STARCENX'][j] + shifts[0][0] + shiftpad starceny = self.database.obs[key]['STARCENY'][j] + shifts[0][1] + shiftpad # Update CRPIX values. crpix1 = self.database.obs[key]['CRPIX1'][j] + shifts[0][0] + shiftpad crpix2 = self.database.obs[key]['CRPIX2'][j] + shifts[0][1] + shiftpad shifts_all += [shifts] # Write FITS file and PSF mask. head_sci['STARCENX'] = starcenx head_sci['STARCENY'] = starceny head_sci['MASKCENX'] = maskcenx head_sci['MASKCENY'] = maskceny head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 # Save fits file. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, align_shift, center_shift, align_mask, center_mask, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, maskcenx=maskcenx, maskceny=maskceny, starcenx=starcenx, starceny=starceny, crpix1=crpix1, crpix2=crpix2)
[docs] @plt.style.context('spaceKLIP.sk_style') def subtract_nircam_coron_background(self, subdir='bgsub', mask_snr_threshold=2, r_excl_nfwhm=40, q_clip=5., align_wrapped=True, include_global_offset=True, include_stellar_psf_component=True, generate_plot=True, save_model=False, use_jbt_background=False, bgmodel_dir=None, background_sb={}, restrict_to=None): """ Fits and subtracts the astrophysical background in NIRCam coronagraphic data following the procedure described in Lawson et al. (2024). Note: This step should only be applied to data that has already been aligned. Otherwise, it will crash. For SW filters using a LW coronagraph, the field of view excludes the neutral density squares. In this case, the astrophysical background and the artificial background offset that we fit are fully degenerate in the regions we consider (away from the coronagraph). Since SW backgrounds should be low anyway, the default is to assume an astrophysical background of zero here. Alternatively, the JWST Backgrounds Tool can be used to estimate the background for affected data instead (if use_jbt_background=True and the JWST Backgrounds Tool is installed). Parameters ---------- subdir : str, optional Name of the sub-directory where the data products will be saved. The default is 'bgsub'. mask_snr_threshold : float, optional SNR threshold for features to be masked during fitting of the background. SNR is estimated using the ERR FITS extension. The default is 2. r_excl_nfwhm : float, optional Radius (in units of the effective PSF FWHM) of the region around the star to exclude from the fit. The default is 50. q_clip : float, optional After computing BG model residuals, exclude q_clip% of pixels from both ends of the residual distribution before computing chisq. This is intended to avoid over-/under-estimation of the background due to unmasked sources or artifacts. Default is 5. align_wrapped : bool, optional Whether input data were aligned using a Fourier shift without padding first (such that values wrapped at the edges). Default is True. include_global_offset : bool, optional Whether to fit a uniform background offset along with the astrophysical background model. This corrects for offsets induced by ramp fitting or use of the median subtraction step. Default is True. include_stellar_psf_component : bool, optional Whether to include a stellar PSF model component when optimizing the background model. Default is True. generate_plot : bool, optional Whether to generate a plot showing the data before and after subtraction along with the model and masked residuals. Default is True. save_model : bool, optional Whether to save the optimized background model. Default is False. use_jbt_background : bool, optional Whether to use the JWST Backgrounds Tool to estimate the background surface brightness for data without coverage of ND squares. Requires that jwst_backgrounds is installed. Default is False. bgmodel_dir : str, optional Path to the directory containing the normalized background model component FITS files to use (or to which they should be downloaded). If None, uses spaceKLIP/resources/nircam_bg_models/. Default is None. background_sb : dict, optional A dictionary of fixed background surface brightness (SB) values (in the same units as the data) to adopt for any included concatenation keys. For each key in database.obs, if background_sb[key] is None or if the key is not in background_sb, the background SB will be fit for all observations of that concatenation if possible. Otherwise, background_sb[key] should be an array-like of float or None having the same length as database.obs[key]. If background_sb[key][j] is None, the jth frame's background SB will be fit, otherwise it will be fixed to the value background_sb[key][j]. Default is {}. restrict_to : str or None, optional Restrict the background subtraction to a specific key in the database. Default is None. """ def get_jbt_background_est(t, ra, dec, wavelength): """ Uses the JWST Backgrounds tool to estimate the background surface brightness at a given time, position, and wavelength. """ from jwst_backgrounds import jbt from astropy.time import Time tobs = Time(t, format='mjd') bkg = jbt.background(ra, dec, wavelength) calendar = bkg.bkg_data['calendar'] tobs0 = Time(f'{tobs.datetime.year}-01-01T00:00:00') thisday = int(np.round((tobs.mjd-tobs0.mjd)+1)) Fbg = bkg.bathtub['total_thiswave'][np.where(thisday == calendar)[0][0]] return Fbg def background_objective(p, im, bg0, psf0, optmask, q=5): """ Objective function for fitting the multi-component background model using LMFit. """ fbg, bg_offset, fpsf = [p[key] for key in p] res = (im - fbg*bg0 - bg_offset - fpsf*psf0)[optmask] low, upp = np.nanpercentile(res, [q, 100.-q]) return np.abs(res[(res >= low) & (res <= upp)]) def get_stellar_model_path(key, bgmodel_dir): """ Searches for the correct stellar model component on disk and fetches from an online repository if needed. Returns the path to the FITS file. """ psffile = f'{bgmodel_dir}{key}_psf0.fits' if not os.path.exists(psffile): with fits.open(f'https://github.com/kdlawson/nircam_bgsub_go4050/raw/main/nominal_bgmodels/{key}_psf0.fits') as hdul: hdul.writeto(psffile) return psffile def get_background_model_path(key, bgmodel_dir): """ Searches for the correct background model component on disk and fetches from an online repository if needed. Returns the path to the FITS file. """ bgfile = f'{bgmodel_dir}{key}_background0.fits' if not os.path.exists(bgfile): with fits.open(f'https://github.com/kdlawson/nircam_bgsub_go4050/raw/main/nominal_bgmodels/{key}_background0.fits') as hdul: hdul.writeto(bgfile) return bgfile output_dir = os.path.join(self.database.output_dir, subdir+'/') if not os.path.isdir(output_dir): os.makedirs(output_dir) if bgmodel_dir is None: bgmodel_dir = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'resources/nircam_bg_models/') if not os.path.exists(bgmodel_dir): os.makedirs(bgmodel_dir) # Copy input SB dictionary so we can fill out / change values as needed. bg_sb_dict = background_sb.copy() for key in self.database.obs: if ((restrict_to is not None) and (restrict_to not in key)) or not np.any(np.isin(self.database.obs[key]['TYPE'], ['SCI', 'REF'])): continue log.info('--> Concatenation ' + key) # Fill out dictionary of SB values, replacing None with a fixed value where fitting is not possible (SW filters with LW coronagraphs). if (key not in bg_sb_dict) or (bg_sb_dict[key] is None): bg_sb_dict[key] = np.repeat(None, len(self.database.obs[key])) for j, entry in enumerate(self.database.obs[key]): if (entry['DETECTOR'] == 'NRCA2') and (entry['CORONMSK'] in ['MASK335R', 'MASK430R', 'MASKLWB']) and (entry['SUBARRAY'] != 'FULL') and (bg_sb_dict[key][j] is None): if use_jbt_background: try: bg_sb_dict[key][j] = get_jbt_background_est(entry['EXPSTART'], entry['TARG_RA'], entry['TARG_DEC'], entry['CWAVEL']) except ModuleNotFoundError: raise ModuleNotFoundError(""" JBT background estimation requires the jwst_backgrounds package. Either install jwst_backgrounds or rerun with use_jbt_background=False. """) else: bg_sb_dict[key][j] = 0. db_tab = self.database.obs[key] fwhm = db_tab['CWAVEL'][0] * 1e-6 / 5.2 * 180. / np.pi * 3600. / db_tab['PIXSCALE'][0] blur_fwhm = db_tab['BLURFWHM'][0] if np.isfinite(blur_fwhm): blur_sigma = blur_fwhm/np.sqrt(8.*np.log(2.)) else: blur_sigma = None # Load the normalized stellar PSF model component if include_stellar_psf_component: psffile = get_stellar_model_path(key, bgmodel_dir=bgmodel_dir) with fits.open(psffile) as hdul: psf0_osamp, hpsf0 = hdul['OVERSAMP'].data, hdul['OVERSAMP'].header osamp = hpsf0['OSAMP'] c_psf0_osamp = np.array([hpsf0['CRPIX1'], hpsf0['CRPIX2']])-1 if blur_sigma is not None: psf0_osamp = gaussian_filter(psf0_osamp, blur_sigma*osamp) # Load the normalized BG model component if not np.all(bg_sb_dict[key] == 0): bgfile = get_background_model_path(key, bgmodel_dir=bgmodel_dir) with fits.open(bgfile) as hdul: bg0_osamp, hbg0 = hdul['OVERSAMP'].data, hdul['OVERSAMP'].header c_coron_bg0 = np.array([hbg0['CRPIX1'], hbg0['CRPIX2']])-1 osamp = hbg0['OSAMP'] # Apply any blurring used for the data: if blur_sigma is not None: bg0_osamp = gaussian_filter(bg0_osamp, blur_sigma*osamp) files = db_tab['FITSFILE'] c_star = np.array([db_tab['STARCENX'][0], db_tab['STARCENY'][0]])-1 h1 = fits.getheader(files[0], ext=1) ny, nx = h1['NAXIS2'], h1['NAXIS1'] rmap = np.sqrt((np.arange(0, nx, dtype=np.float32)-c_star[0])**2 + (np.arange(0, ny, dtype=np.float32)-c_star[1])[:, np.newaxis]**2) rmap_nfwhm = rmap/fwhm # Map of each pixel's distance from the location of the star in units of the effective PSF FWHM # With no alignment wrapping, the stellar PSF model should be the same for all frames, so we'll just set it up once outside the loop over files. if not align_wrapped: if include_stellar_psf_component: c_star_osamp = c_star*osamp + 0.5*(osamp-1) psf0_osamp_crop = webbpsf_ext.image_manip.crop_image(psf0_osamp, [ny*osamp, nx*osamp], xyloc=c_psf0_osamp, delx=c_star_osamp[0]-(nx*osamp-1)/2., dely=c_star_osamp[1]-(ny*osamp-1)/2.) psf0_crop = webbpsf_ext.image_manip.frebin(psf0_osamp_crop, scale=1/osamp, total=False) else: psf0_crop = np.zeros((ny, nx), dtype=np.float32) for j,f in enumerate(files): if db_tab[j]['TYPE'] not in ['SCI', 'REF']: continue head, tail = os.path.split(f) log.info(' --> NIRCam Background Subtraction: ' + tail) # Assume alignment and background differences between integrations are negligible, so we can use the higher SNR coadded exposure with fits.open(f) as hdul: ints = hdul['SCI'].data errs = hdul['ERR'].data h1 = hdul['SCI'].header mask_offset = np.nanmedian(hdul['MASKOFFS'].data, axis=0) imshift = np.nanmedian(hdul['ALIGN_SHIFT'].data + hdul['CENTER_SHIFT'].data, axis=0) if np.ndim(ints) == 3: med = np.nanmedian(ints, axis=0) nsample = np.sum(np.isfinite(ints), axis=0) err = np.sqrt(np.nansum(errs**2, axis=0))/nsample else: # In case using coadded data saved with only two dims med, err = ints, errs c_coron = c_star - mask_offset # post-alignment mask center position if align_wrapped: if bg_sb_dict[key][j] == 0: bg0_crop = np.zeros_like(med) else: c_coron_osamp_preshift = (c_coron-imshift)*osamp + 0.5*(osamp-1) bg0_osamp_crop = webbpsf_ext.image_manip.crop_image(bg0_osamp, [ny*osamp, nx*osamp], xyloc=c_coron_bg0, delx=c_coron_osamp_preshift[0]-(nx*osamp-1)/2., dely=c_coron_osamp_preshift[1]-(ny*osamp-1)/2.) bg0_osamp_crop = ut.imshift(bg0_osamp_crop, imshift*osamp, pad=False) bg0_crop = webbpsf_ext.image_manip.frebin(bg0_osamp_crop, scale=1/osamp, total=False) # Downsample to detector resolution if include_stellar_psf_component: c_star_osamp_preshift = (c_star-imshift)*osamp + 0.5*(osamp-1) psf0_osamp_crop = webbpsf_ext.image_manip.crop_image(psf0_osamp, [ny*osamp, nx*osamp], xyloc=c_psf0_osamp, delx=c_star_osamp_preshift[0]-(nx*osamp-1)/2., dely=c_star_osamp_preshift[1]-(ny*osamp-1)/2.) psf0_osamp_crop = ut.imshift(psf0_osamp_crop, imshift*osamp, pad=False) psf0_crop = webbpsf_ext.image_manip.frebin(psf0_osamp_crop, scale=1/osamp, total=False) else: psf0_crop = np.zeros_like(med) else: if bg_sb_dict[key][j] == 0: bg0_crop = np.zeros_like(med) else: c_coron_osamp = c_coron*osamp + 0.5*(osamp-1) bg0_osamp_crop = webbpsf_ext.image_manip.crop_image(bg0_osamp, [ny*osamp, nx*osamp], xyloc=c_coron_bg0, delx=c_coron_osamp[0]-(nx*osamp-1)/2., dely=c_coron_osamp[1]-(ny*osamp-1)/2.) bg0_crop = webbpsf_ext.image_manip.frebin(bg0_osamp_crop, scale=1/osamp, total=False) # Downsample to detector resolution if bg_sb_dict[key][j] is None: fbg0 = 1 fbg_vary = True else: fbg0 = bg_sb_dict[key][j] fbg_vary = False optmask = rmap_nfwhm > r_excl_nfwhm snr = med/err # SNR estimate using FITS ERR extension med_snr = np.nanmedian(snr[optmask]) # Median SNR in the nominal background area low_snr = snr <= (med_snr+mask_snr_threshold) # High SNR features are those more than mask_snr_threshold sigma above the approximate BG SNR optmask = optmask & low_snr bg_offset0 = 0 fpsf0 = 1. if not include_stellar_psf_component else np.nansum((med-np.nanmedian(med[optmask])) * psf0_crop) / np.nansum((psf0_crop ** 2)) # Prepare the lmfit Parameters object with default values and sensible bounds p = lmfit.Parameters() p.add('fbg', value=fbg0, min=0, max=np.inf, vary=fbg_vary) p.add('bg_offset', value=bg_offset0, min=-np.inf, max=0, vary=include_global_offset) p.add('fpsf', value=fpsf0, min=0, max=fpsf0*10, vary=include_stellar_psf_component) # Optimize the background model result = lmfit.minimize(background_objective, p, args=(med, bg0_crop, psf0_crop, optmask, q_clip), method='powell') pfin = result.params.valuesdict() logstr = ', '.join([f"{key}:{value:.3f}" for key, value in pfin.items()]) log.info(' --> NIRCam Background Subtraction: ' + logstr) # Compute the final background model and stellar PSF component: bg = pfin['fbg']*bg0_crop + pfin['bg_offset'] psf = psf0_crop*pfin['fpsf'] f_out = output_dir+os.path.basename(os.path.normpath(f)) with fits.open(f) as hdul: hdul[1].data -= bg # Subtract the BG model from the original file hdul.writeto(f_out, overwrite=True) # Save to disk in the output directory if save_model: f_model = f_out.replace('.fits', '_background_model.fits') hdu1 = fits.ImageHDU(bg, name='BG') hdul_model = fits.HDUList([hdul[0], hdu1]) if include_stellar_psf_component: hdul_model.append(fits.ImageHDU(psf, name='STELLAR_PSF')) # Add fit params to header hdul_model[0].header.update(pfin) # Add all relevant settings to the header hdul_model[0].header.update(dict(include_global_offset=include_global_offset, mask_snr_threshold=mask_snr_threshold, r_excl_nfwhm=r_excl_nfwhm, q=q_clip, include_stellar_psf_component=include_stellar_psf_component)) hdul_model.writeto(f_model, overwrite=True) hdul_model.close() if generate_plot: res = med - bg res_psfsub = res - psf low, upp = np.nanpercentile((res_psfsub)[optmask], [q_clip, 100.-q_clip]) res_inliers = np.where((res_psfsub >= low) & (res_psfsub <= upp) & optmask, res_psfsub, np.nan) cmap = copy.copy(mpl.colormaps.get_cmap('RdBu_r')) cmap.set_bad('white') clim = np.array([-1, 1])*np.nanpercentile(np.abs(med), 90) plot_mask = np.isfinite(med) plot_ims = np.where(plot_mask, [med, bg, res_inliers, med-bg], np.nan) fig,axes = plt.subplots(1, 4,figsize=(15, 3.5), sharex=True, sharey=True) labels = ['Data', 'BG Model', 'Masked Residuals', 'Data (BG-subtracted)'] norm = mpl.colors.Normalize(*clim) for ind,ax in enumerate(axes): ax.imshow(plot_ims[ind], norm=norm, interpolation='None', origin='lower', cmap=cmap) ax.set_title(labels[ind], pad=10) fig.tight_layout(w_pad=1.00) fig.colorbar(mpl.cm.ScalarMappable(norm, cmap), ax=axes, pad=0.015, label='[MJy / Sr]') plt.savefig(output_dir+os.path.basename(os.path.normpath(f)).replace('.fits', '_background_model.pdf'), bbox_inches='tight') plt.close(fig) mask_in = db_tab['MASKFILE'][j] mask = ut.read_msk(mask_in) mask_out = ut.write_msk(mask_in, mask, f_out) self.database.update_obs(key, j, f_out, mask_out)