from __future__ import division
import matplotlib
# =============================================================================
# IMPORTS
# =============================================================================
import os
import pdb
import sys
from astropy.io import fits
import numpy as np
import json
import pyklip.klip
from astropy import wcs
from jwst.pipeline import Detector1Pipeline, Image2Pipeline, Coron3Pipeline
from pyklip import parallelized, rdi
from pyklip.instruments.JWST import JWSTData
from pyklip.klip import _rotate_wcs_hdr
from spaceKLIP.psf import get_transmission
from spaceKLIP.utils import pop_pxar_kw
import logging
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# =============================================================================
# MAIN
# =============================================================================
[docs]
def run_obs(database,
restrict_to=None,
kwargs={},
subdir='klipsub'):
"""
Run pyKLIP on the input observations database.
Parameters
----------
database : spaceKLIP.Database
SpaceKLIP database on which pyKLIP shall be run.
kwargs : dict, optional
Keyword arguments for the pyklip.parallelized.klip_dataset method.
Available keywords are:
- mode : list of str, optional
Subtraction modes that shall be looped over. Possible values are
'ADI', 'RDI', and 'ADI+RDI'. The default is ['ADI+RDI'].
- annuli : list of int, optional
Numbers of subtraction annuli that shall be looped over. The
default is [1].
- subsections : list of int, optional
Numbers of subtraction subsections that shall be looped over. The
default is [1].
- numbasis : list of int, optional
Number of KL modes that shall be looped over. The default is [1, 2,
5, 10, 20, 50, 100].
- movement : float, optional
Minimum amount of movement (pix) of an astrophysical source to
consider using that image as a reference PSF. The default is 1.
- verbose : bool, optional
Verbose mode? The default is False.
- save_rolls : bool, optional
Save each processed roll separately? The default is False.
- save_full_output : bool, optional
Save the full output cube before flattening? The default is False.
The default is {}.
subdir : str, optional
Name of the directory where the data products shall be saved. The
default is 'klipsub'.
Returns
-------
None.
"""
# Check input.
if 'mode' not in kwargs.keys():
kwargs['mode'] = ['ADI+RDI']
if not isinstance(kwargs['mode'], list):
kwargs['mode'] = [kwargs['mode']]
if 'annuli' not in kwargs.keys():
kwargs['annuli'] = [1]
if not isinstance(kwargs['annuli'], list):
kwargs['annuli'] = [kwargs['annuli']]
if 'subsections' not in kwargs.keys():
kwargs['subsections'] = [1]
if not isinstance(kwargs['subsections'], list):
kwargs['subsections'] = [kwargs['subsections']]
if 'numbasis' not in kwargs.keys():
kwargs['numbasis'] = [1, 2, 5, 10, 20, 50, 100]
if not isinstance(kwargs['numbasis'], list):
kwargs['numbasis'] = [kwargs['numbasis']]
if 'IWA' not in kwargs.keys():
kwargs['IWA'] = 1.
kwargs_temp = kwargs.copy()
if 'movement' not in kwargs_temp.keys():
kwargs_temp['movement'] = 1.
kwargs_temp['calibrate_flux'] = False
if 'verbose' not in kwargs_temp.keys():
kwargs_temp['verbose'] = database.verbose
if 'save_rolls' not in kwargs_temp.keys():
kwargs_temp['save_ints'] = False
kwargs_temp['save_rolls'] = False
else:
# Note pyKLIP uses save_ints as the keyword for this, but we want to use save_rolls in our pipeline for clarity,
# so we need to set save_ints to the same value as save_rolls.
kwargs_temp['save_ints'] = kwargs_temp['save_rolls']
if 'save_full_output' not in kwargs_temp.keys():
kwargs_temp['save_full_output'] = False
if 'highpass' not in kwargs_temp.keys():
kwargs_temp['highpass'] = False
# Set output directory.
output_dir = os.path.join(database.output_dir, subdir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
kwargs_temp['outputdir'] = output_dir
# Loop through concatenations.
datapaths = []
for i, key in enumerate(database.obs.keys()):
# if we limit to only processing some concatenations, check whether this concatenation matches the pattern
if (restrict_to is not None) and (restrict_to not in key):
continue
log.info('--> Concatenation ' + key)
filepaths, psflib_filepaths, maxnumbasis = get_pyklip_filepaths(database, key, return_maxbasis=True)
if 'maxnumbasis' not in kwargs_temp.keys() or kwargs_temp['maxnumbasis'] is None:
kwargs_temp['maxnumbasis'] = maxnumbasis
# Run KLIP subtraction.
for mode in kwargs['mode']:
# Initialize pyKLIP dataset.
pop_pxar_kw(np.append(filepaths, psflib_filepaths))
dataset = JWSTData(filepaths, psflib_filepaths,
highpass=kwargs_temp['highpass'],
center_include_offset=False,
center_keywords=['STARCENX','STARCENY'])
dataset.IWA = kwargs['IWA']
kwargs_temp['dataset'] = dataset
kwargs_temp['aligned_center'] = dataset.psflib.aligned_center
kwargs_temp['psf_library'] = dataset.psflib
kwargs_temp['mode'] = mode
# Can run pyKLIP multiple times on the same dataset with different
# annuli and subsections.
for annu in kwargs['annuli']:
for subs in kwargs['subsections']:
log.info(' --> pyKLIP: mode = ' + mode + ', annuli = ' + str(annu) + ', subsections = ' + str(subs))
fileprefix = mode + '_NANNU' + str(annu) + '_NSUBS' + str(subs) + '_' + key
# Add/update ramaining keywords.
kwargs_temp['fileprefix'] = fileprefix
kwargs_temp['annuli'] = annu
kwargs_temp['subsections'] = subs
kwargs_temp_temp = kwargs_temp.copy()
# Need to cleanup some kwargs that we're using but pyKLIP doesn't
del kwargs_temp_temp['save_full_output']
del kwargs_temp_temp['save_rolls']
del kwargs_temp_temp['IWA']
parallelized.klip_dataset(**kwargs_temp_temp)
# Get reduction path.
datapath = os.path.join(output_dir, fileprefix + '-KLmodes-all.fits')
datapaths += [datapath]
# Update reduction header.
ww_sci = np.where(database.obs[key]['TYPE'] == 'SCI')[0]
head_sci = fits.getheader(database.obs[key]['FITSFILE'][ww_sci[0]], 'SCI')
head_sci['NAXIS'] = 2
hdul = fits.open(datapath)
hdul[0].header['TELESCOP'] = database.obs[key]['TELESCOP'][ww_sci[0]]
hdul[0].header['TARGPROP'] = database.obs[key]['TARGPROP'][ww_sci[0]]
hdul[0].header['TARG_RA'] = database.obs[key]['TARG_RA'][ww_sci[0]]
hdul[0].header['TARG_DEC'] = database.obs[key]['TARG_DEC'][ww_sci[0]]
hdul[0].header['INSTRUME'] = database.obs[key]['INSTRUME'][ww_sci[0]]
hdul[0].header['DETECTOR'] = database.obs[key]['DETECTOR'][ww_sci[0]]
hdul[0].header['FILTER'] = database.obs[key]['FILTER'][ww_sci[0]]
hdul[0].header['CWAVEL'] = database.obs[key]['CWAVEL'][ww_sci[0]]
hdul[0].header['DWAVEL'] = database.obs[key]['DWAVEL'][ww_sci[0]]
hdul[0].header['PUPIL'] = database.obs[key]['PUPIL'][ww_sci[0]]
hdul[0].header['CORONMSK'] = database.obs[key]['CORONMSK'][ww_sci[0]]
hdul[0].header['EXP_TYPE'] = database.obs[key]['EXP_TYPE'][ww_sci[0]]
hdul[0].header['EXPSTART'] = np.min(database.obs[key]['EXPSTART'][ww_sci])
hdul[0].header['NINTS'] = np.sum(database.obs[key]['NINTS'][ww_sci])
hdul[0].header['EFFINTTM'] = database.obs[key]['EFFINTTM'][ww_sci[0]]
hdul[0].header['SUBARRAY'] = database.obs[key]['SUBARRAY'][ww_sci[0]]
hdul[0].header['APERNAME'] = database.obs[key]['APERNAME'][ww_sci[0]]
hdul[0].header['PPS_APER'] = database.obs[key]['PPS_APER'][ww_sci[0]]
hdul[0].header['PIXSCALE'] = database.obs[key]['PIXSCALE'][ww_sci[0]]
try:
hdul[0].header['PIXAR_SR'] = database.obs[key]['PIXAR_SR'][ww_sci[0]]
except:
pass
hdul[0].header['MODE'] = mode
hdul[0].header['ANNULI'] = annu
hdul[0].header['SUBSECTS'] = subs
hdul[0].header['BUNIT'] = database.obs[key]['BUNIT'][ww_sci[0]]
w = wcs.WCS(head_sci)
_rotate_wcs_hdr(w, database.obs[key]['ROLL_REF'][ww_sci[0]])
hdul[0].header['WCSAXES'] = head_sci['WCSAXES']
hdul[0].header['CRPIX1'] = head_sci['STARCENX']
hdul[0].header['CRPIX2'] = head_sci['STARCENY']
hdul[0].header['CRVAL1'] = head_sci['CRVAL1']
hdul[0].header['CRVAL2'] = head_sci['CRVAL2']
hdul[0].header['CTYPE1'] = head_sci['CTYPE1']
hdul[0].header['CTYPE2'] = head_sci['CTYPE2']
hdul[0].header['CUNIT1'] = head_sci['CUNIT1']
hdul[0].header['CUNIT2'] = head_sci['CUNIT2']
hdul[0].header['CD1_1'] = w.wcs.cd[0, 0]
hdul[0].header['CD1_2'] = w.wcs.cd[0, 1]
hdul[0].header['CD2_1'] = w.wcs.cd[1, 0]
hdul[0].header['CD2_2'] = w.wcs.cd[1, 1]
if not np.isnan(database.obs[key]['BLURFWHM'][ww_sci[0]]):
hdul[0].header['BLURFWHM'] = database.obs[key]['BLURFWHM'][ww_sci[0]]
hdul.writeto(datapath, output_verify='fix', overwrite=True)
hdul.close()
# If requested, save the full cube before flattening
if kwargs_temp['save_full_output']:
# Cube is held in dataset.output
intsfile = os.path.join(output_dir, fileprefix + '-KLmodes-all_fulloutput.fits')
hdul_full = fits.HDUList([fits.PrimaryHDU(data=dataset.output)])
# Set header keywords for the full cube as well, using the same header as the final output
# but with updated NAXIS and NINTS.
hdul_full[0].header = hdul[0].header.copy()
hdul_full[0].header['NAXIS'] = len(dataset.output.shape)
hdul_full[0].header['NINTS'] = dataset.output.shape[0]
hdul_full.writeto(intsfile, output_verify='fix', overwrite=True)
# If requested, save each roll separately.
if kwargs_temp['save_ints']:
n_roll = 1
for j in ww_sci:
fitsfile = os.path.split(database.obs[key]['FITSFILE'][j])[1]
head_sci = fits.getheader(database.obs[key]['FITSFILE'][j], 'SCI')
ww = [k for k in range(len(dataset._filenames)) if fitsfile in dataset._filenames[k]]
hdul = fits.open(datapath)
if dataset.allints.shape[1] == 1:
hdul[0].data = np.nanmedian(dataset.allints[:, :, ww, :, :], axis=(1, 2))
else:
hdul[0].data = np.nanmedian(dataset.allints[:, :, ww, :, :], axis=2)
hdul[0].header['NINTS'] = database.obs[key]['NINTS'][j]
hdul[0].header['WCSAXES'] = head_sci['WCSAXES']
hdul[0].header['CRPIX1'] = head_sci['STARCENX']
hdul[0].header['CRPIX2'] = head_sci['STARCENY']
hdul[0].header['CRVAL1'] = head_sci['CRVAL1']
hdul[0].header['CRVAL2'] = head_sci['CRVAL2']
hdul[0].header['CTYPE1'] = head_sci['CTYPE1']
hdul[0].header['CTYPE2'] = head_sci['CTYPE2']
hdul[0].header['CUNIT1'] = head_sci['CUNIT1']
hdul[0].header['CUNIT2'] = head_sci['CUNIT2']
hdul[0].header['CD1_1'] = head_sci['CD1_1']
hdul[0].header['CD1_2'] = head_sci['CD1_2']
hdul[0].header['CD2_1'] = head_sci['CD2_1']
hdul[0].header['CD2_2'] = head_sci['CD2_2']
hdul.writeto(datapath.replace('-KLmodes-all.fits', '-KLmodes-all_roll%.0f.fits' % n_roll), output_verify='fix', overwrite=True)
hdul.close()
n_roll += 1
# Save corresponding observations database.
file = os.path.join(output_dir, key + '.dat')
for col in ['CENTER_MASK', 'ALIGN_MASK', 'CENTER_SHIFT', 'ALIGN_SHIFT']:
if col in database.obs[key].colnames:
database.obs[key][col] = [
str([x.tolist() if hasattr(x, "tolist") else x for x in row])
if row is not None and hasattr(row, '__iter__')
else str(row)
for row in database.obs[key][col]
]
database.obs[key].write(file, format='ascii', overwrite=True)
# Compute and save corresponding transmission mask.
file = os.path.join(output_dir, key + '_psfmask.fits')
mask = get_transmission(database.obs[key])
ww_sci = np.where(database.obs[key]['TYPE'] == 'SCI')[0]
if mask is not None:
hdul = fits.open(database.obs[key]['MASKFILE'][ww_sci[0]])
hdul[0].data = None
hdul['SCI'].data = mask
hdul.writeto(file, output_verify='fix', overwrite=True)
# Read reductions into database.
database.read_jwst_s3_data(datapaths)
pass
[docs]
def get_pyklip_filepaths(database, key, return_maxbasis=False):
'''
Quick wrapper function to get the filepath information (in addition
to the maxnumbasis) for pyKLIP from a spaceKLIP database.
Parameters
----------
database : spaceKLIP.Database
SpaceKLIP database on which pyKLIP shall be run.
key : str
Key for the concatenation of interest in the spaceKLIP database
return_maxbasis : bool, optional
Toggle for whether to additionally return the
maximum number of basis vectors.
Returns
-------
filepaths : 1D-array
List of science image file names
psflib_filepaths : 1D-array
List of reference image file names
maxnumbasis : int, optional
The maximum number of basis vectors available.
'''
filepaths = []
psflib_filepaths = []
first_sci = True
nints = []
nfitsfiles = len(database.obs[key])
for j in range(nfitsfiles):
if database.obs[key]['TYPE'][j] == 'SCI':
filepaths += [database.obs[key]['FITSFILE'][j]]
if first_sci:
first_sci = False
else:
nints += [database.obs[key]['NINTS'][j]]
elif database.obs[key]['TYPE'][j] == 'REF':
psflib_filepaths += [database.obs[key]['FITSFILE'][j]]
nints += [database.obs[key]['NINTS'][j]]
filepaths = np.array(filepaths)
psflib_filepaths = np.array(psflib_filepaths)
nints = np.array(nints)
maxnumbasis = np.sum(nints)
if return_maxbasis:
return filepaths, psflib_filepaths, maxnumbasis
else:
return filepaths, psflib_filepaths