Source code for spaceKLIP.plotting

# =============================================================================
# IMPORTS
# =============================================================================

import os
from itertools import chain

import numpy as np
import warnings

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import matplotlib.patheffects as patheffects
from matplotlib.patches import Rectangle
from matplotlib import font_manager
import scipy.ndimage as ndi

import astropy
import astropy.io.fits as fits
import astropy.units as u
import astropy.visualization as v

from . import wcs_utils
import jwst.datamodels

import shutil
import tempfile
import ipywidgets as widgets
import matplotlib.patches as patches
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset

import plotly.graph_objects as go
from plotly.subplots import make_subplots

import logging
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

SPACEKLIP_PLOT_STYLE = os.path.join(os.path.dirname(__file__),'sk_style.mplstyle')
# =============================================================================
# MAIN
# =============================================================================

[docs] def set_default_style(style): " Adjust the default matplotlib style for spaceKLIP plots. " global SPACEKLIP_PLOT_STYLE SPACEKLIP_PLOT_STYLE = style
[docs] def load_plt_style(style=None): """ Load the matplotlib style for spaceKLIP plots. Load the style sheet (default of `sk_style.mplstyle`), which is a modified version of the style sheet from the `webbpsf_ext` package. """ # Wrap in an is None to ensure it uses the current value of SPACEKLIP_PLOT_STYLE if style is None: style = SPACEKLIP_PLOT_STYLE plt.style.use(style)
[docs] def annotate_compass(ax, image, wcs, xf=0.9, yf=0.1, length_fraction=0.07, color='white', bbox_color='#4B0082', fontsize=12): """ Plot a compass annotation onto an image, indicating the directions of North and East. Makes use of the methods from jdaviz, but positions the compass differently: jdaviz defaults to putting it in the center, and filling most of the image. Here we want to make a small compass in the corner. Parameters ---------- ax : matplotlib.Axes The axis on which to draw the compass annotation. image : ndarray A 2D image to be annotated (used just to get the image dimensions). wcs : astropy.wcs.WCS World Coordinate System (WCS) information. xf, yf : float, optional The horizontal (xf) and vertical (yf) positions of the compass as fractions of the image width and height, respectively. Values should be between 0 and 1, where 0 places the compass at the left or bottom edge, and 1 places it at the right or top edge. Default is (0.9, 0.1), placing the compass near the bottom-right corner. length_fraction : float, optional Length of the compass, as a fraction of the size of the entire image color : str, optional The color of the compass arrows and labels. Default is 'white'. bbox_color : str, optional The background color for the text labels ('N' and 'E'). Default is '#4B0082'. If set to None, no background box will be drawn. fontsize : int, optional The font size for the compass labels. Default is 12. Returns ------- None. """ # Use wcs_utils from jdaviz to compute arrow positions. x, y, xn, yn, xe, ye, degn, dege, xflip = wcs_utils.get_compass_info(wcs, image.shape, r_fac=length_fraction) # Calculate the offsets needed to reposition the compass to the desired location. # `xo` and `yo` are the offsets in the x and y directions, respectively. xo = image.shape[1] * xf - x yo = image.shape[0] * yf - y x, xn, xe = [coord + xo for coord in (x, xn, xe)] # Adjust x-coordinates. y, yn, ye = [coord + yo for coord in (y, yn, ye)] # Adjust y-coordinates. # Plot the compass base point on the axis as a small circle. ax.plot(x, y, marker='o', color=color, markersize=4) # Annotate the North ('N') and East ('E') directions on the image. # Use arrows to point from the base of the compass to the respective endpoints. for label, (x_end, y_end) in zip(['N', 'E'], [(xn, yn), (xe, ye)]): bbox = dict(facecolor=bbox_color, alpha=0.5, edgecolor='none') if bbox_color else None ax.annotate(label, xy=(x, y), xytext=(x_end, y_end), arrowprops={'arrowstyle': '<-', 'color': color, 'lw': 1.5}, bbox=bbox, # Semi-transparent background for text. color=color, fontsize=fontsize, va='center', ha='center')
[docs] def annotate_scale_bar(ax, image, wcs, length=1 * u.arcsec, xf=0.1, yf=0.1, color='white', bbox_color='#4B0082', lw=3, fontsize=10): """ Plot a scale bar on an image. Parameters ---------- ax : matplotlib.Axes The axis on which to draw the scale bar. image : ndarray The 2D image to be annotated (used just to get the image dimensions). wcs : astropy.wcs.WCS World Coordinate System (WCS) information. length : astropy.Quantity, optional Length of the scale bar, in arcsec or equivalent unit xf, yf : float, optional The horizontal (xf) and vertical (yf) positions of the compass as fractions of the image width and height, respectively. Values should be between 0 and 1, where 0 places the compass at the left or bottom edge, and 1 places it at the right or top edge. Default is (0.9, 0.1), placing the compass near the bottom-right corner. color : str, optional The color of the scale bar and the text label. Default is 'white'. bbox_color : str, optional The background color for the text label. Default is '#4B0082'. If set to None, no background box will be drawn. lw : float, optional The line width of the scale bar. Default is 3. fontsize : int, optional The font size of the text label displaying the scale bar length. Default is 10. Returns ------- None. """ # Calculate the pixel scale in arcseconds per pixel. pixelscale = astropy.wcs.utils.proj_plane_pixel_scales(wcs).mean() * u.deg sb_length = (length / pixelscale).decompose() # Calculate the position of the scale bar based on image dimensions. xo = image.shape[1] * xf yo = image.shape[0] * yf # Draw the scale bar on the image. ax.plot([xo, xo + sb_length], [yo, yo], color=color, lw=lw) bbox = dict(facecolor=bbox_color, alpha=0.3, edgecolor='none') if bbox_color else None ax.text(xo + sb_length / 2, yo + 0.02 * image.shape[0], length, color=color, bbox=bbox, horizontalalignment='center', fontsize=fontsize)
[docs] def annotate_secondary_axes_arcsec(ax, image, wcs): """ Update an image display to add secondary axes labels in an arcsec. Parameters ---------- ax : matplotlib.Axes The axis on which to add secondary axes labels. image : ndarray The 2D image to be annotated (used just to get the image dimensions). wcs : astropy.wcs.WCS World Coordinate System (WCS) information. Returns ------- None. """ # Define the forward and inverse transforms needed for secondary_axes. # see https://matplotlib.org/3.1.0/gallery/subplots_axes_and_figures/secondary_axis.html pixelscale = astropy.wcs.utils.proj_plane_pixel_scales(wcs) pix2as_x = lambda x: (x - wcs.wcs.crpix[0]) * pixelscale[0] * 3600 pix2as_y = lambda y: (y - wcs.wcs.crpix[1]) * pixelscale[1] * 3600 as2pix_x = lambda x: x / pixelscale[0] / 3600 + wcs.wcs.crpix[0] as2pix_y = lambda y: y / pixelscale[1] / 3600 + wcs.wcs.crpix[1] secax = ax.secondary_xaxis('top', functions=(pix2as_x, as2pix_x)) secax.set_xlabel('Offset [arcsec]', fontsize='small') secay = ax.secondary_yaxis('right', functions=(pix2as_y, as2pix_y)) secay.set_ylabel('Offset [arcsec]', fontsize='small') secax.tick_params(labelsize='small', color='white', which='both') secay.tick_params(labelsize='small', color='white', which='both')
[docs] def compare_clean_methods(files, kwargs_list=None, zoom_region=None, interactive=False): """ Plot and compare spaceKLIP clean methods. For the Bad Pixel Tutorial Notebook. Parameters ---------- files : list of str FITS files containing cleaned products. kwargs_list : list of dict List of parameter dictionaries corresponding to each run. zoom_region : tuple, optional (x0, x1, y0, y1) zoom region. Returns ------- None. """ # Load all runs. runs = [] all_data = [] for i, file in enumerate(files): # Get the cleaned data. data = fits.getdata(file, ext=1) data = data if data.ndim == 2 else data[-1] data = np.asarray(data, dtype=np.float32) all_data.append(data) # Build annotation. text = [] if kwargs_list: method = kwargs_list[i].get("method", f"Run {i+1}") text.append(f"method = {method}") if method != "custom": flat = [] for k, v in kwargs_list[i].items(): if isinstance(v, dict): flat.extend( f"{k2} = {v2}" for k2, v2 in v.items() ) else: flat.append(f"{k} = {v}") text.extend( ", ".join(flat[j:j+2]) for j in range(0, len(flat), 2) ) runs.append(dict(data=data, txt="\n".join(text), html="<br>".join(text))) # Shared scaling. flat = np.concatenate([d.ravel() for d in all_data]) zmin, zmax = np.nanpercentile(flat, [1, 98]) unit = fits.getval(files[0], "BUNIT", ext=1) # STATIC PLOTS. if not interactive: # Setup the figure. n = len(runs) ncols = min(2, n) nrows = (n + ncols - 1) // ncols fig, axes = plt.subplots(nrows * 2, ncols, figsize=(5 * ncols, 7 * nrows), gridspec_kw={"height_ratios": [4, 1] * nrows}, squeeze=False) fig.suptitle(r"Bad Pixel $\it{Clean}$ Method Comparison") for i, run in enumerate(runs): row = (i // ncols) * 2 col = i % ncols ax = axes[row, col] txt = axes[row + 1, col] im = ax.imshow(run["data"], origin="lower", cmap="viridis", vmin=zmin, vmax=zmax) ax.set_title(f"Run {i+1}") ax.set_xlabel("Pixel X") ax.set_ylabel("Pixel Y") if zoom_region: x0, x1, y0, y1 = zoom_region ax.set_xlim(x0, x1) ax.set_ylim(y0, y1) txt.text(0.02, 0.95, run["txt"], va="top", fontsize=8) txt.set_facecolor("#f7f7f7") txt.axis("off") # Hide unused panels. total = nrows * ncols for j in range(n, total): row = (j // ncols) * 2 col = j % ncols axes[row, col].axis("off") axes[row + 1, col].axis("off") cax = fig.add_axes([0.93, 0.15, 0.015, 0.7]) fig.colorbar(im, cax=cax, label=unit) plt.tight_layout(rect=[0.02, 0.02, 0.90, 0.97]) plt.show() # INTERACTIVE PLOTS else: # Setup figure. ncols = min(2, len(runs)) nrows = int(np.ceil(len(runs) / 2)) fig = make_subplots(rows=nrows * 2, cols=ncols, row_heights=[0.8, 0.2] * nrows, vertical_spacing=0.06) for i, run in enumerate(runs): row = 2 * (i // 2) + 1 col = (i % 2) + 1 fig.add_trace(go.Heatmap(z=run["data"], zmin=zmin, zmax=zmax, colorscale="Viridis", showscale=(i == 0), colorbar=dict(title=unit) if i == 0 else None, hovertemplate=("x: %{x}<br>y: %{y}<br>value: %{z:.2f}<extra></extra>")), row=row, col=col) fig.add_trace(go.Scatter(x=[0.5], y=[0.5], text=[run["html"]], mode="text", textfont=dict(size=11), showlegend=False), row=row + 1, col=col) fig.update_xaxes(title="Pixel X", matches="x", row=row, col=col) fig.update_yaxes(title="Pixel Y", matches="y", scaleanchor="x", scaleratio=1, row=row, col=col) fig.update_xaxes(row=row + 1, col=col, visible=False) fig.update_yaxes(row=row + 1, col=col, visible=False) if zoom_region: x0, x1, y0, y1 = zoom_region fig.update_xaxes(range=[x0, x1], row=row, col=col) fig.update_yaxes(range=[y0, y1], row=row, col=col) fig.update_layout( title="Bad Pixel <i>Clean</i> Method Comparison", height=450 * nrows, width=900, margin=dict(l=50, r=50, t=100, b=75)) fig.show()
[docs] def compare_find_methods(files, kwargs_list=None, zoom_region=None, interactive=True): """ Plot and compare bad-pixel masks from multiple runs on a given image. For the Bad Pixel Tutorial Notebook. Parameters ---------- files : list of str FITS files containing DQ and science data. kwargs_list : list of dict, optional Parameter dictionaries corresponding to each run. zoom_region : tuple, optional (x0, x1, y0, y1) zoom region. interactive : bool, optional Interactive Plotly plot or static Matplotlib plot. Default is True. Returns ------- None """ # Use first file as reference image. reference_file = files[0] # Load reference image. data = fits.getdata(reference_file, ext=1) data = data if data.ndim == 2 else data[-1] data = np.asarray(data, dtype=np.float32) unit = fits.getval(reference_file, "BUNIT", ext=1) zmin, zmax = np.nanpercentile(data, [1, 98]) # Load each run. runs = [] for i, file in enumerate(files): # Load the DQ data. dq = fits.getdata(file, extname="DQ") dq = dq if dq.ndim == 2 else dq[-1] # Create DO_NOT_USE mask. mask = (np.asarray(dq, dtype=np.int32) & 1) != 0 y, x = np.where(mask) text = [r"DO_NOT_USE = " + f"{mask.sum():,}"] # Annotate the kwargs for each run. if kwargs_list and kwargs_list[i].get("method") != "custom": flat = [] for k, v in kwargs_list[i].items(): if isinstance(v, dict): flat.extend(f"{k2} = {v2}" for k2, v2 in v.items()) else: flat.append(f"{k} = {v}") text.extend(", ".join(flat[j:j+2]) for j in range(0, len(flat), 2)) runs.append(dict(x=x, y=y, txt="\n".join(text), html="<br>".join(text))) # STATIC PLOT if not interactive: # Setup figure. n = len(runs) ncols = min(2, n) nrows = (n + ncols - 1) // ncols fig, axes = plt.subplots( nrows * 2, ncols, figsize=(6 * ncols, 7 * nrows), squeeze=False, gridspec_kw={"height_ratios": [4, 1] * nrows}) #axes = np.atleast_2d(axes) fig.suptitle(r"Bad Pixel $\it{Find}$ Method Comparison") # Plot each run. for i, run in enumerate(runs): row = (i // ncols) * 2 col = i % ncols ax = axes[row, col] txt = axes[row + 1, col] im = ax.imshow(data, origin="lower", cmap="viridis", vmin=zmin, vmax=zmax) ax.scatter(run["x"], run["y"], s=10, marker="x", color="red", linewidths=0.5) ax.set_title(f"Run {i+1}") ax.set_xlabel("Pixel X") ax.set_ylabel("Pixel Y") if zoom_region: x0, x1, y0, y1 = zoom_region ax.set_xlim(x0, x1) ax.set_ylim(y0, y1) txt.text(0.01, 0.95, r"$\times$", color="red", va="top", fontsize=10) txt.text(0.06, 0.95, run["txt"], va="top", fontsize=8) txt.set_facecolor("#f7f7f7") txt.axis("off") cax = fig.add_axes([0.93, 0.15, 0.015, 0.7]) fig.colorbar(im, cax=cax, label=unit) plt.tight_layout(rect=[0.02, 0.02, 0.90, 0.97]) plt.show() # INTERACTIVE PLOT. else: fig = make_subplots(rows=2, cols=1, row_heights=[0.8, 0.2], vertical_spacing=0.06) # Plot the reference. fig.add_trace(go.Heatmap(z=data, zmin=zmin, zmax=zmax, colorscale="Viridis", colorbar=dict(title=unit)), row=1, col=1) # Overlay the mask. for i, run in enumerate(runs): fig.add_trace(go.Scattergl(x=run["x"], y=run["y"], visible=(i == 0), mode="markers", showlegend=False, marker=dict(symbol="x", size=5, color="red")), row=1, col=1) # Anootation. fig.add_trace(go.Scatter(x=[0.5], y=[0.5], text=[runs[0]["html"]], mode="text", showlegend=False), row=2, col=1) if zoom_region: x0, x1, y0, y1 = zoom_region fig.update_xaxes(range=[x0, x1], row=1, col=1) fig.update_yaxes(range=[y0, y1], row=1, col=1) fig.update_layout( title="Bad Pixel <i>Find</i> Method Comparison", height=900, width=1000, margin=dict(l=50, r=200, t=100, b=50), updatemenus=[dict( x=1.25, y=1, buttons=[dict( label=f"Run {i+1}", method="restyle", args=[{ "visible": [True] + [j == i for j in range(len(runs))] + [True], "text": [None] * (len(fig.data)-1) + [[runs[i]["html"]]] }] ) for i in range(len(runs))] )] ) fig.update_xaxes(title="Pixel X", row=1, col=1) fig.update_xaxes(row=2, col=1, visible=False) fig.update_yaxes(row=2, col=1, visible=False) fig.update_yaxes(title="Pixel Y", scaleanchor="x", scaleratio=1, row=1, col=1) fig.show()
[docs] def display_coron_image(filename, vmin=None, vmax=None, stretch=0.0001, bbox_color='#4B0082', dq_only=False, zoom_center=3, ax=None, plot_style=None): """ Display and annotate a coronagraphic image. Shows image on asinh scale along with some basic metadata, scale bar, and compass. Optionally, a zoomed inset around the image center can be displayed. This display function is designed to flexibly adapt to several different kinds of input data, including rate, rateints, cal, calints files. And pyKLIP's KLmode cubes. Parameters ---------- filename : str The path to the file containing the image data. vmin, vmax : float, optional The minimum/maximum data value to use for scaling the image. If None, determined automatically. stretch : float, optional The stretch factor for the asinh normalization. If None, defaults to 0.0001. bbox_color : str, optional The background color for the text label. Default is '#4B0082'. If set to None, no background box will be drawn. dq_only : bool, optional If True, only the DO_NOT_USE DQ flags are displayed, not the image data itself. zoom_center : int, optional The zoom factor for the inset axis centered on the image's center. Set to None to disable. Returns ------- None. """ # Intialize the matplotlib style. load_plt_style(plot_style) # Early exit for unsupported file types. if 'uncal' in filename: raise RuntimeError("Display code does not support stage 0 uncal files. Reduce the data further before trying to display it.") # Determine the input file type and set corresponding flags. is_pyklip = 'KLmodes' in filename cube_ints = 'rateints' in filename or 'calints' in filename if is_pyklip: # Handle pyKLIP output. # PyKLIP output, we have to open this differently, can't use JWST datamodels. image = fits.getdata(filename) header = fits.getheader(filename) #display(header) try: center_x, center_y = header['STARCENX'], header['STARCENY'] except: warnings.warn(f'Keyword STARCENX/Y missing from fits header. Falling back to CRPIX1/2') center_x, center_y = header['CRPIX1'], header['CRPIX2'] bunit = header['BUNIT'] wcs = astropy.wcs.WCS(header) if image.ndim == 3: image = image[-1] # Select the last KL mode. wcs = wcs.dropaxis(2) # Drop the nints axis. annotation_text = f"pyKLIP results for {header['TARGPROP']}\n{header['FILTER']}" is_psf = False else: # Handle JWST pipeline outputs. # Load in JWST pipeline outputs using jwst.datamodel. header = fits.open(filename)[1].header #display(header) modeltype = jwst.datamodels.CubeModel if cube_ints else jwst.datamodels.ImageModel model = modeltype(filename) image = np.nanmean(model.data, axis=0) if cube_ints else model.data dq = model.dq[0] if cube_ints else model.dq nints = model.meta.exposure.nints if cube_ints else None #center_x, center_y = model.meta.wcsinfo.crpix1, model.meta.wcsinfo.crpix2 try: center_x, center_y = header['STARCENX'], header['STARCENY'] except: warnings.warn(f'Keyword STARCENX/Y missing from fits header. Falling back to CRPIX1/2') center_x, center_y = header['CRPIX1'], header['CRPIX2'] bunit = model.meta.bunit_data is_psf = model.meta.exposure.psf_reference annotation_text = ( f"{model.meta.target.proposer_name}\n" f"{model.meta.instrument.filter}, {model.meta.exposure.readpatt}:" f"{model.meta.exposure.ngroups}:{model.meta.exposure.nints}\n" f"{model.meta.exposure.effective_exposure_time:.2f} s" ) try: # I don't know how to deal with the slightly different API of the GWCS class # so, this is crude, just cast it to a regular WCS and drop the high order distortion stuff # This suffices for our purposes in plotting compass annotations etc. # (There is almost certainly a better way to do this...) wcs = astropy.wcs.WCS(model.meta.wcs.to_fits()[0]) except: wcs = model.get_fits_wcs() if cube_ints: wcs = wcs.dropaxis(2) # Create a bad pixel mask. # Does this file have DQ extension or not? PyKLIP outputs do not. bpmask = np.zeros_like(image) + np.nan bpmask[np.isnan(image) | ((dq & 1) == 1) if not is_pyklip else np.isnan(image)] = 1 # Set up the asinh normalization for image display. stats = astropy.stats.sigma_clipped_stats(image) vmin = vmin if vmin is not None else stats[0] - stats[2] # 1 sigma below mean. vmax = vmax if vmax is not None else np.nanmax(image) # Max value in image. stretch = v.AsinhStretch(a=stretch) norm = v.ImageNormalize(image, interval=v.ManualInterval(vmin, vmax), stretch=stretch) # Create the figure and axis or use provided. if ax is None: fig, ax = plt.subplots(figsize=(16, 9)) else: fig = ax.figure # Get the figure from the provided axes. if not dq_only: im = ax.imshow(image, norm=norm) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.95) cb = fig.colorbar(im, ax=ax, cax=cax, label=bunit) cb.ax.set_yscale('asinh') cb.ax.tick_params(labelsize='small') # Display the bad pixel mask (DO_NOT_USE flags). ax.set_facecolor('#4B0082') imdq = ax.imshow(bpmask, vmin=0, vmax=1.5, cmap=matplotlib.cm.inferno) # Add image annotations. bbox=dict(facecolor=bbox_color, alpha=0.5, edgecolor='none') if bbox_color else None ax.text(0.01, 0.99, annotation_text, bbox=bbox, transform=ax.transAxes, color='white', verticalalignment='top', fontsize=10) ax.set_title(f"{os.path.basename(filename)}\n", fontsize=14) ax.set_xlabel("Pixels", fontsize='small') ax.set_ylabel("Pixels", fontsize='small') ax.tick_params(labelsize='small', color='white', which='both') labelstr = 'PSF Reference' if is_psf else 'Science target after pyKLIP PSF sub.' if is_pyklip else 'Science target' ax.text(0.5, 0.99, labelstr, bbox=bbox, style='italic', fontsize=10, color='white', horizontalalignment='center', verticalalignment='top', transform=ax.transAxes) if cube_ints: ax.text(0.99, 0.99, f"Showing average of {nints} ints", style='italic', fontsize=10, color='white', bbox=bbox, horizontalalignment='right', verticalalignment='top', transform=ax.transAxes) # Annotate compass, scale bar, and secondary axes. annotate_compass(ax, image, wcs, yf=0.07, bbox_color=bbox_color) annotate_scale_bar(ax, image, wcs, yf=0.07, bbox_color=bbox_color) # Annotate secondary axes in arcseconds relative to coronagraph mask center. ax.xaxis.set_tick_params(which='both', bottom=True, top=False) ax.yaxis.set_tick_params(which='both', left=True, right=False) annotate_secondary_axes_arcsec(ax, image, wcs) if zoom_center: # Add a zoomed-in inset of the center. zoom_ax = zoomed_inset_axes(ax, zoom=zoom_center/2, loc=2, bbox_to_anchor=[-0.65, 0.91, 0.1, 0.1], bbox_transform=ax.transAxes) zoom_ax.imshow(image, norm=norm) zoom_ax.imshow(bpmask, vmin=0, vmax=1.5, cmap=matplotlib.cm.inferno) zoom_size = min(center_x, center_y) // zoom_center zoom_ax.set_xlim(center_x - zoom_size, center_x + zoom_size) zoom_ax.set_ylim(center_y - zoom_size, center_y + zoom_size) zoom_ax.tick_params(labelsize='small', color='white', which='both') zoom_ax.set_xlabel("Pixels", fontsize='small') zoom_ax.set_ylabel("Pixels", fontsize='small') # Annotate secondary axes in the zoomed image. zoom_ax.xaxis.set_tick_params(which='both', bottom=True, top=False) zoom_ax.yaxis.set_tick_params(which='both', left=True, right=False) annotate_secondary_axes_arcsec(zoom_ax, image, wcs) mark_inset(ax, zoom_ax, loc1=3, loc2=1, fc="none", ec="0.6", alpha=0.5) fig.tight_layout(rect=[0.3, 0, 1, 1]) fig.show() return ax
[docs] def display_coron_dataset(database, restrict_to=None, save_filename=None, stage3=None, vmin=None, vmax=None, stretch=0.0001, zoom_center=3, dq_only=False, interactive=False, bbox_color='#4B0082', plot_style=None): """ Display multiple files in a coronagraphic dataset. Parameters ---------- database : spaceklip.Database Database of files to plot. restrict_to : str or dict, optional Optional query to filter and display specific data. - `None`: No filtering; all tables are processed. - `str`: Only datasets whose database concatenation (file group) name includes this string will be shown. Most simply, set this to a filter name to only plot images with that filter. - `dict`: Filters tables based on database column values, where keys are column names and values are filter criteria. stage3 : str, optional Specify if data is stage 3. save_filename : str If provided, the plots will be saved to a PDF file with this name. vmin, vmax : float, optional The minimum/maximum data value to use for scaling the image. If None, determined automatically. stretch : float, optional The stretch factor for the asinh normalization. If None, defaults to 0.0001. zoom_center : int, optional The zoom factor for the inset axis centered on the image's center. Set to None to disable. dq_only : bool, optional If True, only the DO_NOT_USE DQ flags are displayed, not the image data itself. interactive : bool, optional If `True`, the plots will be displayed interactively. bbox_color : str, optional The background color for the text label. Default is '#4B0082'. If set to None, no background box will be drawn. Returns ------- None. """ # Intialize the matplotlib style. load_plt_style(plot_style) # Initialize PDF saving if a filename is provided. pdf = PdfPages(save_filename) if save_filename else None # Infer stage3 based on db contents. if stage3 is None: stage3 = hasattr(database, 'red') and len(database.red) > 0 # Select the appropriate dataset and file types. dataset = database.red if stage3 else database.obs types = ['PYKLIP', 'STAGE3'] if stage3 else ['SCI', 'REF'] # Filter files based on the 'restrict_to' criteria provided. filtered_files = [] for key, table in dataset.items(): if isinstance(restrict_to, dict): for col, val in restrict_to.items(): if col not in table.colnames: print(f"Warning: Column '{col}' not found in the observation table for key '{key}'. Skipping over this filtering criteria.") continue # Skip the current filter if the column doesn't exist. vals = val if isinstance(val, list) else [val] table = table[[str(cell) in map(str, vals) for cell in table[col]]] elif isinstance(restrict_to, str) and restrict_to not in key: continue # Skip this key if it doesn't match the 'restrict_to' string. filtered_files += [row['FITSFILE'] for row in table if row['TYPE'] in types] # Loop through each file that matches the specified types. for fn in filtered_files: display_coron_image(fn, vmin=vmin, vmax=vmax, stretch=stretch, zoom_center=zoom_center) if pdf: pdf.savefig(plt.gcf()) if interactive: plt.close() # Close the figure to avoid displaying it. # Optional: interactively slide through the files in the database. if interactive: slider = widgets.IntSlider(value=0, min=0, max=len(filtered_files) - 1, step=1, description='Image Index:') def update_image(index): #plt.clf() display_coron_image(filtered_files[index], vmin=vmin, vmax=vmax, stretch=stretch, zoom_center=zoom_center, dq_only=dq_only, bbox_color=bbox_color) plt.show() out = widgets.interactive_output(update_image, {'index': slider}) display(slider, out) if pdf: pdf.close()
[docs] def display_image_comparisons(database, base_dirs, restrict_to=None, save_filename=None, vmin=None, vmax=None, stretch=0.0001, zoom_center=None, interactive=False, dq_only=False, subtract_first=False, plot_style=None): """ Compare images before and after processing. Parameters ---------- database : spaceklip.Database Database of files to plot. base_dirs : list of str List of base directory names. restrict_to : str or dict, optional Optional query to filter and display specific data. - `None`: No filtering; all tables are processed. - `str`: Only datasets whose database concatenation (file group) name includes this string will be shown. Most simply, set this to a filter name to only plot images with that filter. - `dict`: Filters tables based on database column values, where keys are column names and values are filter criteria. save_filename : str If provided, the plots will be saved to a PDF file with this name. vmin, vmax : float, optional The minimum/maximum data value to use for scaling the image. If None, determined automatically. stretch : float, optional The stretch factor for the asinh normalization. If None, defaults to 0.0001. zoom_center : int, optional The zoom factor for the inset axis centered on the image's center. Set to None to disable. dq_only : bool, optional If True, only the DO_NOT_USE DQ flags are displayed, not the image data itself. subtract_first : bool Whether to subtract the first SCI frame from subsequent frames. interactive : bool, optional If `True`, the plots will be displayed interactively. Returns ------- None. """ # Intialize the matplotlib style. load_plt_style(plot_style) # Initialize PDF saving if a filename is provided. pdf = PdfPages(save_filename) if save_filename else None # Initialize a dictionary to store image details for each base directory. image_files = {base_dir: {'bp_counts': [], 'first_sci_file': None} for base_dir in base_dirs} # Iterate over each key and corresponding table in the database. # Filter files based on the 'restrict_to' criteria provided. filtered_files = [] for key, table in database.obs.items(): if isinstance(restrict_to, dict): for col, val in restrict_to.items(): if col not in table.colnames: print(f"Warning: Column '{col}' not found in the observation table for key '{key}'. Skipping over this filtering criteria.") continue # Skip the current filter if the column doesn't exist. vals = val if isinstance(val, list) else [val] table = table[[str(cell) in map(str, vals) for cell in table[col]]] elif isinstance(restrict_to, str) and restrict_to not in key: continue # Skip this key if it doesn't match the 'restrict_to' string. # Filter for SCI and REF types. filtered_table = [row for row in table if row['TYPE'] in ['SCI', 'REF']] filtered_files.extend(row['FITSFILE'] for row in filtered_table) # Check if any SCI data remains after filtering. if not any(row['TYPE'] == 'SCI' for row in filtered_table): print(f"No SCI type files found in key: {key}.") continue # Identify the first SCI frame for subtraction, store it for later use. first_sci_file = next((row['FITSFILE'] for row in filtered_table if row['TYPE'] == 'SCI'), None) root_dir = first_sci_file.split(os.sep)[0] for base_dir in base_dirs: image_files[base_dir]['first_sci_file'] = os.path.join(root_dir, base_dir, os.path.basename(first_sci_file)) if len(filtered_files)==0: print("No files found. Check 'restrict_to' criteria. Exiting.") return # Create figure of appropriate size. num_dirs = len(base_dirs) num_rows, num_cols = (1, num_dirs) if num_dirs <= 3 else (2, (num_dirs + 1) // 2) # Iterate over the filtered files to process and display images. def update_image(index): fn = filtered_files[index] root_dir = fn.split(os.sep)[0] # Extract the root directory from the file path. base_fn = os.path.basename(fn) # Extract the base filename from the file path. fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 10, num_rows * 10)) axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes] # Iterate over each base directory and its image information. for ax, (base_dir, image_info) in zip(axes, image_files.items()): fn_path = os.path.join(root_dir, base_dir, base_fn) # Full file path. # Count the number of bad (DO_NOT_USE) pixels in the DQ data. dq = fits.getdata(fn_path, extname='DQ') num_bad_pixels = np.sum((dq & 1) == 1) #image_info['bp_counts'].append(num_bad_pixels) # Handle subtraction only if enabled. if subtract_first: # Load and subtract the first SCI frame from the current frame. with tempfile.NamedTemporaryFile(suffix='_' + fn.split('_')[-1], delete=False) as tmp: shutil.copy2(fn_path, tmp.name) with fits.open(tmp.name, mode='update') as hdul: first_sci_frame = fits.getdata(image_info['first_sci_file'], extname='SCI') hdul['SCI'].data -= first_sci_frame.astype(np.float32) # Determine the center of the image. ny, nx = hdul['SCI'].data.shape[-2:] # Handle both 2D and 3D arrays. ax.axhline(y=ny // 2, color='white', linestyle='--', linewidth=1) ax.axvline(x=nx // 2, color='white', linestyle='--', linewidth=1) hdul.flush() fn_path = tmp.name ax = display_coron_image(fn_path, ax=ax, vmin=vmin, vmax=vmax, stretch=stretch, zoom_center=zoom_center, dq_only=dq_only) ax.images[0].set_cmap('RdBu_r' if subtract_first else 'viridis') plt.draw() ax.set_title(base_dir) ax.legend(handles=[patches.Patch(color='orange', label=f"DO_NOT_USE = {num_bad_pixels} px")], loc='lower center', bbox_to_anchor=(0.5, -0.18)) fig.suptitle( f"{os.path.basename(fn)} - {os.path.basename(image_info['first_sci_file'])}" if subtract_first else os.path.basename(fn), fontsize=16) if interactive: plt.show() # Optional: interactively slide through the files in the database. if interactive: slider = widgets.IntSlider(value=0, min=1 if subtract_first else 0, max=len(filtered_files) - 1, step=1, description='Image Index:') out = widgets.interactive_output(update_image, {'index': slider}) display(slider, out) # Static mode. else: for i in range(len(filtered_files)): update_image(i+1 if subtract_first else i) if pdf: pdf.savefig(plt.gcf()) plt.show() plt.close() if pdf: pdf.close()
[docs] def plot_contrast_images(meta, data, data_masked, pxsc=None, savefile='./maskimage.pdf', plot_style=None): """ Plot subtracted images to be used for contrast estimation, one with companions marked, one with the masking adopted. """ # Intialize the matplotlib style. load_plt_style(plot_style) # Set some quick information depending on whether a pixel scale was passed if pxsc == None: extent=(-0.5, data.shape[1]-0.5, data.shape[1]-0.5, -0.5) pxsc = 1 xlabel, ylabel = 'Pixels', 'Pixels' else: extl = (data.shape[1]+1.)/2.*pxsc/1000. # arcsec extr = (data.shape[1]-1.)/2.*pxsc/1000. # arcsec extent = (extl, -extr, -extl, extr) xlabel, ylabel = '$\\Delta$RA [arcsec]', '$\\Delta$Dec [arcsec]' # Initialize plots f, ax = plt.subplots(1, 2, figsize=(2*6.4, 1*4.8)) # Plot subtracted image, circle input companion locations ax[0].imshow(np.log10(np.abs(data[-1])), origin='lower', cmap='inferno', extent=extent) for j in range(len(meta.ra_off)): cc = plt.Circle((meta.ra_off[j]/1000., meta.de_off[j]/1000.), 10.*pxsc/1000., fill=False, edgecolor='green', linewidth=3) ax[0].add_artist(cc) if hasattr(meta, 'ra_off_mask') and hasattr(meta, 'de_off_mask'): for j in range(len(meta.ra_off_mask)): cc = plt.Circle((meta.ra_off_mask[j]/1000., meta.de_off_mask[j]/1000.), 10.*pxsc/1000., fill=False, edgecolor='red', linewidth=3) ax[0].add_artist(cc) ax[0].set_xlabel(xlabel) ax[0].set_ylabel(ylabel) ax[0].set_title('KLIP-subtracted') # Plot subtracted image, with adopted masking ax[1].imshow(np.log10(np.abs(data_masked[-1])), origin='lower', cmap='inferno', extent=extent) ax[1].set_xlabel(xlabel) ax[1].set_ylabel(ylabel) if 'SWB' in savefile or 'LWB' in savefile: ax[1].set_title('Companions & bar masked') else: ax[1].set_title('Companions masked') plt.tight_layout() # Save and close plot plt.savefig(savefile) plt.close() return
[docs] def plot_contrast_raw(meta, seps, cons, labels='default', savefile='./rawcontrast.pdf', plot_style=None): """ Plot raw contrast curves for different KL modes. """ # Intialize the matplotlib style. load_plt_style(plot_style) # Initialize figure plt.figure(figsize=(6.4, 4.8)) ax = plt.gca() # Figure out if we're plotting one contrast curve, or multiple if len(cons) == 1: if labels == 'default': labels == 'contrast' ax.plot(seps, np.squeeze(cons), label=labels) elif len(cons) > 1: if labels == 'default': labels = ['contrast_{}'.format(i+1) for i in range(len(seps))] # Loop over contrast curves to plot for i in range(len(cons)): ax.plot(seps, cons[i], label=labels[i]) # Plot settings ax.set_yscale('log') ax.grid(axis='y') ax.set_xlabel('Separation [arcsec]') ax.set_ylabel('Contrast [5$\sigma$]') ax.set_title('Raw contrast curve') ax.legend(loc='upper right', prop=dict(size=8)) plt.tight_layout() # Save and close plot plt.savefig(savefile) ax.set_xlim([0., 6.]) # arcsec plt.savefig(savefile.replace('raw', 'raw_trim')) plt.close() return
[docs] def plot_injected_locs(meta, data, transmission, seps, pas, pxsc=None, savefile='./injected.pdf', plot_style=None): ''' Plot subtracted image and 2D transmission alongside locations of injected planets. ''' # Intialize the matplotlib style. load_plt_style(plot_style) #Set some quick information depending on whether a pixel scale was passed if pxsc == None: extent=(-0.5, data.shape[1]-0.5, data.shape[1]-0.5, -0.5) extent_tr=(-0.5, transmission.shape[1]-0.5, transmission.shape[1]-0.5, -0.5) pxsc = 1 xlabel, ylabel = 'Pixels', 'Pixels' else: extl = (data.shape[1]+1.)/2.*pxsc/1000. # arcsec extr = (data.shape[1]-1.)/2.*pxsc/1000. # arcsec extent = (extl, -extr, -extl, extr) extl = (transmission.shape[1]/2.)*pxsc/1000 # arcsec extr = (transmission.shape[1]/2.)*pxsc/1000 # arcsec extent_tr = (extl, -extr, -extl, extr) xlabel, ylabel = '$\Delta$RA [arcsec]', '$\Delta$DEC [arcsec]' f, ax = plt.subplots(1, 2, figsize=(2*6.4, 1*4.8)) ax[0].imshow(np.log10(np.abs(data)), origin='lower', cmap='inferno', extent=extent) for i in range(len(meta.ra_off)): cc = plt.Circle((meta.ra_off[i]/1000., meta.de_off[i]/1000.), 10.*pxsc/1000., fill=False, edgecolor='green', linewidth=3) ax[0].add_artist(cc) for i in range(len(seps)): ra = seps[i]*pxsc*np.sin(np.deg2rad(pas[i])) # mas de = seps[i]*pxsc*np.cos(np.deg2rad(pas[i])) # mas cc = plt.Circle((ra/1000., de/1000.), 10.*pxsc/1000., fill=False, edgecolor='red', linewidth=3) ax[0].add_artist(cc) # ax[0].set_xlim([5., -5.]) # ax[0].set_ylim([-5., 5.]) ax[0].set_xlabel(xlabel) ax[0].set_ylabel(ylabel) ax[0].set_title('KLIP-subtracted') p1 = ax[1].imshow(transmission, origin='lower', cmap='viridis', vmin=0., vmax=1., extent=extent_tr) ax[1].text(0.5, 0.9, 'THIS PLOT IS WRONG', fontsize=16, transform=plt.gca().transAxes, c='k', fontweight='bold', ha='center', va='center') c1 = plt.colorbar(p1, ax=ax[1]) c1.set_label('Transmission', rotation=270, labelpad=20) for i in range(len(meta.ra_off)): cc = plt.Circle((meta.ra_off[i]/1000., meta.de_off[i]/1000.), 10.*pxsc/1000., fill=False, edgecolor='green', linewidth=3) ax[1].add_artist(cc) for i in range(len(seps)): ra = seps[i]*pxsc*np.sin(np.deg2rad(pas[i])) # mas de = seps[i]*pxsc*np.cos(np.deg2rad(pas[i])) # mas cc = plt.Circle((ra/1000., de/1000.), 10.*pxsc/1000., fill=False, edgecolor='red', linewidth=3) ax[1].add_artist(cc) # ax[1].set_xlim([5., -5.]) # ax[1].set_ylim([-5., 5.]) ax[1].set_xlabel(xlabel) ax[1].set_ylabel(ylabel) ax[1].set_title('Transmission') plt.tight_layout() plt.savefig(savefile) plt.close() return
[docs] def plot_contrast_calibrated(thrput, med_thrput, fit_thrput, con_seps, cons, corr_cons, savefile='./calcontrast.pdf', plot_style=None): ''' Plot calibrated throughput alongside calibrated contrast curves. ''' # Intialize the matplotlib style. load_plt_style(plot_style) f, ax = plt.subplots(1, 2, figsize=(2*6.4, 1*4.8)) ax[0].plot(med_thrput['seps'], med_thrput['tps'], color='mediumaquamarine', label='Median throughput') ax[0].scatter(thrput['seps'], thrput['tps'], s=75, color='mediumaquamarine', alpha=0.5) ax[0].plot(fit_thrput['seps'], fit_thrput['tps'], color='teal', label='Best fit model') ax[0].set_xlim([fit_thrput['seps'][0], fit_thrput['seps'][-1]]) ax[0].set_ylim([0.0, 1.2]) ax[0].grid(axis='y') ax[0].set_xlabel('Separation [pix]') ax[0].set_ylabel('Throughput') ax[0].set_title('Algo & coronmsk throughput') ax[0].legend(loc='lower right') ax[1].plot(con_seps, cons, color='mediumaquamarine', label='Raw contrast') ax[1].plot(con_seps, corr_cons, color='teal', label='Calibrated contrast') ax[1].set_yscale('log') ax[1].set_xlim([0., np.max(con_seps)]) # arcsec ax[1].set_ylim(top=3e-3) ax[1].grid(axis='y') ax[1].set_xlabel('Separation [arcsec]') ax[1].set_ylabel('Contrast [5$\sigma$]') ax[1].set_title('Calibrated contrast curve') ax[1].legend(loc='upper right') plt.tight_layout() plt.savefig(savefile) plt.close() return
[docs] def plot_fm_psf(meta, fm_frame, data_frame, guess_flux, pxsc=None, j=0, savefile='./fmpsf.pdf', plot_style=None): ''' Plot forward model psf ''' # Intialize the matplotlib style. load_plt_style(plot_style) #Set some quick information depending on whether a pixel scale was passed if pxsc == None: extent=(-0.5, fm_frame.shape[1]-0.5, fm_frame.shape[1]-0.5, -0.5) pxsc = 1 xlabel, ylabel = 'Pixels', 'Pixels' else: extl = (fm_frame.shape[1]+1.)/2.*pxsc/1000. # arcsec extr = (fm_frame.shape[1]-1.)/2.*pxsc/1000. # arcsec extent = (extl, -extr, -extl, extr) xlabel, ylabel = '$\Delta$RA [arcsec]', '$\Delta$DEC [arcsec]' f, ax = plt.subplots(1, 2, figsize=(2*6.4, 1*4.8)) p0 = ax[0].imshow(fm_frame*guess_flux, origin='lower', cmap='inferno', extent=extent) c0 = plt.colorbar(p0, ax=ax[0]) c0.set_label('DN', rotation=270, labelpad=20) cc = plt.Circle((meta.ra_off[j]/1000., meta.de_off[j]/1000.), 10.*pxsc/1000., fill=False, edgecolor='green', linewidth=3) ax[0].add_artist(cc) ax[0].set_xlim([5., -5.]) ax[0].set_ylim([-5., 5.]) ax[0].set_xlabel(xlabel) ax[0].set_ylabel(ylabel) ax[0].set_title(r'FM PSF ($\alpha$ = %.0e)' % guess_flux) p1 = ax[1].imshow(data_frame, origin='lower', cmap='inferno', extent=extent) c1 = plt.colorbar(p1, ax=ax[1]) c1.set_label('DN', rotation=270, labelpad=20) cc = plt.Circle((meta.ra_off[j]/1000., meta.de_off[j]/1000.), 10.*pxsc/1000., fill=False, edgecolor='green', linewidth=3) ax[1].add_artist(cc) ax[1].set_xlim([5., -5.]) ax[1].set_ylim([-5., 5.]) ax[1].set_xlabel(xlabel) ax[1].set_ylabel(ylabel) ax[1].set_title('KLIP-subtracted') plt.tight_layout() plt.savefig(savefile) plt.close() return
[docs] def plot_chains(chain, savefile): ''' Plot MCMC chains from companion fitting ''' f, ax = plt.subplots(4, 1, figsize=(1*6.4, 2*4.8)) ax[0].plot(chain[:, :, 0].T, color='black', alpha=1./3.) ax[0].set_xlabel('Steps') ax[0].set_ylabel(r'$\Delta$RA [mas]') ax[1].plot(chain[:, :, 1].T, color='black', alpha=1./3.) ax[1].set_xlabel('Steps') ax[1].set_ylabel(r'$\Delta$Dec [mas]') ax[2].plot(chain[:, :, 2].T, color='black', alpha=1./3.) ax[2].set_xlabel('Steps') ax[2].set_ylabel(r'$\alpha$ [sec/pri]') ax[3].plot(chain[:, :, 3].T, color='black', alpha=1./3.) ax[3].set_xlabel('Steps') ax[3].set_ylabel(r'$l$ [pix]') plt.suptitle('MCMC chains') plt.tight_layout() plt.savefig(savefile) plt.close()
[docs] def plot_subimages(imgdirs, subdirs, filts, submodes, numKL, window_size=2.5, cmaps_list=['viridis'], imgVmin=[-40], imgVmax=[40], subVmin=[-10], subVmax=[10], labelpos=[0.04, 0.05], imtext_col='w', showKL=True, useticklabels=True, cbar_textoff=1, hspace=0.05, wspace=0.05, plot_style=None): ''' Create a "publication ready" plot of the coronagraphic images, alongside the PSF subtracted images. A grid of images will be made. Rows will correspond to filters, first column the unsubtracted PSF, following columns different submodes and numKLs Parameters ---------- imgdirs : list of strings Parent directories of the unsubtracted images, filters won't be repeated subdirs : list of strings Parent directories of the subtracted images filts : list of strings List of filter strings to include in the plot submodes : list of strings 'ADI', 'RDI', 'RDI+ADI' (or 'ADI+RDI') numKL : list of ints output images are 3-D with the third axis corresponding to different KL modes used in subtraction. Indicate the number of KL modes you want to display, if more than one, display each for each submode on the same row . window_size : float the number of arcseconds on a side cmaps_list : list a list of strings naming the colormaps for each filter. If len(cmaps_list)==1 the same cmap will be used for each filter. imgVmin : list a list of the min values for the the cmaps for each filter. if len(imgVmin)==1 the same min value will be used for each filter. imgVmax : list a list of the max values for the the cmaps for each filter. if len(imgVmax)==1 the same max value will be used for each filter. subVmin : list a list of the min values for the the subtracted image cmaps for each filter. if len(imgVmin)==1 the same min value will be used for each filter. subVmax : list a list of the max values for the the subtracted image cmaps for each filter. if len(imgVmin)==1 the same min value will be used for each filter. labelpos: list Position of the [RA, Dec] axis labels (figure coords) imtext_col: str Color of any text / arrows that go on top of the images ''' from matplotlib.ticker import MultipleLocator, MaxNLocator # Intialize the matplotlib style. load_plt_style(plot_style) # Get the files we care about imgfiles = sorted(list(chain.from_iterable([glob.glob(imgdir+'*.fits') for imgdir in imgdirs]))) subfiles = list(chain.from_iterable([glob.glob(subdir+'*.fits') for subdir in subdirs])) filts = [filt.upper() for filt in filts] # Filter the imgfiles used_filts = [] true_imgfiles = {} for imgfile in imgfiles: hdr = fits.getheader(imgfile) if hdr.get('SUBPXPTS') is not None: # This keyword only exists when larger than 1. If so, # this is a dithered reference observation, which we don't want. continue elif (hdr['EXP_TYPE'] != 'NRC_CORON') and (hdr['EXP_TYPE'] != 'MIR_4QPM'): # Don't want TACONFIRM or TACQ continue elif hdr['FILTER'] not in filts: # We don't want this filter continue elif hdr['FILTER'] in used_filts: # We've already got a file for this filter continue else: # We want this file print(imgfile) true_imgfiles[hdr['FILTER']] = imgfile used_filts.append(hdr['FILTER']) # Filter the subfiles # Note that we allow repeat filters for different reductions. true_subfiles = [] for subfile in subfiles: print(subfile) if any(filt in subfile for filt in filts): true_subfiles.append(subfile) #read the subtracted images, store relevant info sub_dict = {} for flt in filts: for fn in true_subfiles: if flt in fn: print(flt) with fits.open(fn) as hdul: imgcube = hdul[0].data psfparam = hdul[0].header['PSFPARAM'] center_pix = (int(np.rint(hdul[0].header['PSFCENTY'])), int(np.rint(hdul[0].header['PSFCENTX']))) mode = psfparam.split('mode=')[-1].split(',')[0].replace("'", '')#ADI/RDI/ADI+RDI if '+' in mode: reverse_mode = mode.split('+')[-1] + '+' + mode.split('+')[0] else: reverse_mode = None numbasis = [int(KLnum) for KLnum in psfparam.split('numbasis=')[-1].split(',')[0].split(']')[0].split('[')[-1].split()] if flt in sub_dict.keys(): sub_dict[flt][mode] = {'image':imgcube, 'numbasis':numbasis, 'filename':fn, 'center_pix':center_pix} if reverse_mode is not None: sub_dict[flt][reverse_mode] ={'image':imgcube, 'numbasis':numbasis, 'filename':fn, 'center_pix':center_pix} else: sub_dict[flt] = {mode:{'image':imgcube, 'numbasis':numbasis, 'filename':fn, 'center_pix':center_pix}} if reverse_mode is not None: sub_dict[flt][reverse_mode] = {'image':imgcube, 'numbasis':numbasis, 'filename':fn, 'center_pix':center_pix} #miri centers from pyklip v2.6 (y, x) miri_img_centers = {'F1065C': (int(np.rint(111.89-1)), int(np.rint(120.81-1))), 'F1140C': (int(np.rint(112.2-1)) , int(np.rint(119.99-1))) , 'F1550C': (int(np.rint(113.33-1)), int(np.rint(119.84-1))), } pltscale = {'NIRCAM': 0.063, 'MIRI': 0.11} plot_extent = (-1*window_size/2, window_size/2, -1*window_size/2, window_size/2) ydim = len(filts) xdim = len(submodes)*len(numKL) + 1 wratios = [1]*xdim wratios.append(0.015*xdim) # Start making the figure fig = plt.figure(figsize=[xdim*3, ydim*3]) grid = gs.GridSpec(ydim, xdim+1, figure=fig, wspace=wspace, hspace=hspace, width_ratios=wratios, left=0.06, right=0.94, bottom=0.08, top=0.93) if len(cmaps_list) == 1: cmaps_list *= len(filts) if len(imgVmin) == 1: imgVmin *= len(filts) if len(imgVmax) == 1: imgVmax *= len(filts) if len(subVmin) == 1: subVmin *= len(filts) if len(subVmax) == 1: subVmax *= len(filts) for row, flt in enumerate(filts): prihdr = fits.getheader(true_imgfiles[flt]) instrument = prihdr['INSTRUME'] ax = plt.subplot(grid[row, 0]) with fits.open(true_imgfiles[flt]) as hdul: img = np.nanmedian(hdul['SCI'].data, axis=0) hdr = hdul['SCI'].header roll = hdr['PA_V3'] + hdr['V3I_YANG'] # get center information if instrument == 'NIRCAM': center_pix = (int(np.rint(hdr.get('STARCENY', hdr.get('MASKCENY', hdr['CRPIX2']))-1)), int(np.rint(hdr.get('STARCENX', hdr.get('MASKCENX', hdr['CRPIX1']))-1)), ) window_pix = int(np.rint(window_size / pltscale['NIRCAM'] / 2)) offset = (window_pix - window_size / 0.063 / 2) *0.063 else: # Potentially use same center_pix code as for NIRCAM case above center_pix = miri_img_centers[flt] window_pix = int(np.rint(window_size / pltscale['MIRI'] / 2)) offset = (window_pix - window_size / 0.11 / 2) *0.11 # Rotate unsubtracted image so North is up x = np.linspace(0, img.shape[0], img.shape[0]) y = np.linspace(0, img.shape[1], img.shape[1]) xmesh, ymesh = np.meshgrid(x, y) old_center = [img.shape[1]/2, img.shape[0]/2] #y, x xmesh += center_pix[1] ymesh += center_pix[0] xmesh -= old_center[1] ymesh -= old_center[0] new_data = ndi.map_coordinates(img, [ymesh, xmesh]) rot_img = ndi.rotate(new_data, -roll) rotshape = rot_img.shape focus_slices_img = (slice(int(rotshape[0]/2) - window_pix, int(rotshape[0]/2) + window_pix), slice(int(rotshape[1]/2) - window_pix, int(rotshape[1]/2) + window_pix)) ax.imshow(rot_img[focus_slices_img], extent=plot_extent, cmap=cmaps_list[row], vmin=imgVmin[row], vmax=imgVmax[row]) if row == 0: unsub = ax.text(0.95, 0.88, 'Unsub', fontsize=16, transform=plt.gca().transAxes, c='w', fontweight='bold', ha='right') # Plot N-E arrows ar_n = ax.arrow(0.9, 0.096, 0.0, 0.1, transform=plt.gca().transAxes, color=imtext_col, width=0.01, \ head_width=0.04, head_length=0.04, path_effects=[patheffects.Stroke(linewidth=3, foreground='k'), patheffects.Normal()]) n = ax.text(0.9, 0.29, 'N', fontsize=16, transform=plt.gca().transAxes, c=imtext_col, \ va='center', ha='center', fontweight='bold') ar_e = ax.arrow(0.905, 0.1, -0.1, 0.0, transform=plt.gca().transAxes, color=imtext_col, width=0.01, \ head_width=0.04, head_length=0.04, path_effects=[patheffects.Stroke(linewidth=3, foreground='k'), patheffects.Normal()]) e = ax.text(0.71, 0.1, 'E', fontsize=16, transform=plt.gca().transAxes, c=imtext_col, \ va='center', ha='center', fontweight='bold') # Plot 1" Line arc1text = ax.text(-int(window_size/2)+0.5, -int(window_size/2)+(window_size/75), '1"', fontsize=16, c=imtext_col, \ va='bottom', ha='center', fontweight='bold') ax.add_patch(Rectangle((-int(window_size/2), -int(window_size/2)-0.075), 1, 0.15, alpha=1, facecolor='w', edgecolor='k', linewidth=1)) #ax.set_ylabel(prihdr['FILTER']+'\n\n', fontsize=18)#, fontproperties=compModern(14)) #Draw lines around image text etc ftext = ax.text(0.05, 0.88, prihdr['FILTER'], fontsize=16, transform=plt.gca().transAxes, c='w', fontweight='bold') for ti, temp in enumerate([unsub, n, e, ftext, arc1text]): temp = temp.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')]) # Draw another arrow on top to mask the line ax.arrow(0.9, 0.096, 0.0, 0.1, transform=plt.gca().transAxes, color=imtext_col, width=0.01, \ head_width=0.04, head_length=0.04) #ax.xaxis.set_major_locator(MaxNLocator(nbins=7, min_n_ticks=3, prune='both')) ax.yaxis.set_major_locator(MultipleLocator(1)) ax.xaxis.set_major_locator(MultipleLocator(1)) if row != (len(filts)-1): plt.setp(ax.get_xticklabels(), visible=False) if useticklabels == False: plt.setp(ax.get_xticklabels(), visible=False) plt.setp(ax.get_yticklabels(), visible=False) # else: # ax.set_xlabel('RA Offset (")')#, fontproperties=compModern(14)) # tickFont(ax, 'x', fontproperties=compModern(10)) # tickFont(ax, 'y', fontproperties=compModern(10)) ax.tick_params(which='major', length=4, color=imtext_col) ax.tick_params(which='minor', length=0, color=imtext_col) for axis in ['top','bottom','left','right']: ax.spines[axis].set_linewidth(2) for mm, mde in enumerate(submodes): focus_slices_sub = (slice(sub_dict[flt][mde]['center_pix'][0] - window_pix, sub_dict[flt][mde]['center_pix'][0] + window_pix), slice(sub_dict[flt][mde]['center_pix'][1] - window_pix, sub_dict[flt][mde]['center_pix'][1] + window_pix)) for kk, nkl in enumerate(numKL): if nkl == 'max': nkl = sub_dict[flt][mde]['numbasis'][-1] column = 1 + (len(numKL)*mm + kk) ax = plt.subplot(grid[row, column]) plotimg = sub_dict[flt][mde]['image'][sub_dict[flt][mde]['numbasis'].index(nkl),:,:][focus_slices_sub] plotimg[np.where(np.isnan(plotimg))] = 0 subimg = ax.imshow(plotimg, extent=plot_extent, cmap=cmaps_list[row], vmin=subVmin[row], vmax=subVmax[row]) star = ax.scatter(0.+offset, 0.+offset, marker='*', color='w', s=100) star.set_path_effects([patheffects.withStroke(linewidth=3, foreground='k')]) ax.xaxis.set_major_locator(MultipleLocator(1)) ax.yaxis.set_major_locator(MultipleLocator(1)) plt.setp(ax.get_yticklabels(), visible=False) ax.tick_params(which='major', length=4, color=imtext_col) ax.tick_params(which='minor', length=0, color=imtext_col) for axis in ['top','bottom','left','right']: ax.spines[axis].set_linewidth(2) if row != (len(filts)-1): plt.setp(ax.get_xticklabels(), visible=False) if useticklabels == False: plt.setp(ax.get_xticklabels(), visible=False) plt.setp(ax.get_yticklabels(), visible=False) # else: #ax.xaxis.set_major_locator(MaxNLocator(nbins=7, min_n_ticks=3, prune='both')) #tickFont(ax, 'x', fontproperties=compModern(10)) # ax.set_x label('RA Offset (")')#, fontproperties=compModern(14)) if row == 0: if showKL: ax.set_title('{:s}, KL={:d}'.format(mde, nkl), fontweight='bold')#, fontproperties=compModern(16)) else: temp = ax.text(0.95, 0.88, mde, fontsize=16, transform=plt.gca().transAxes, c='w', fontweight='bold', ha='right') temp.set_path_effects([patheffects.withStroke(linewidth=2, foreground='k')]) # Plot the scale bar cax = plt.subplot(grid[row,-1]) cb = plt.colorbar(mappable=subimg, cax=cax) cb.outline.set_linewidth(2) cax.tick_params(which='both', color='k', labelsize=14, width=2, direction='out') for axis in ['top','bottom','left','right']: cax.spines[axis].set_linewidth(2) if flt != 'F1550C': cax.yaxis.set_major_locator(MaxNLocator(integer=True)) # Add label for color bars fig.text(cbar_textoff, 0.5, "Counts (MJy/sr)", rotation=270, va='center', fontsize=18) if useticklabels != False: fig.text(0.5, labelpos[0], 'RA Offset (")', ha='center', fontsize=18) fig.text(labelpos[1], 0.5, 'Dec Offset (")', va='center', rotation='vertical', fontsize=18) fig.tight_layout() fig.tight_layout() fig.tight_layout() #plt.show() return fig
[docs] def tickFont(ax, xy, fontproperties): if xy == 'x': plt.setp(ax.get_xticklabels(), fontproperties=fontproperties) if xy == 'y': plt.setp(ax.get_yticklabels(), fontproperties=fontproperties)
[docs] def compModern(size=20): computerModern = font_manager.FontProperties(size=size, fname=os.path.join(os.path.dirname(__file__), 'cmunrm.ttf')) return computerModern