"""
Main tools to drive exposure level processing.
Functions
---------
wcs_from_config
Extracts a WCS from the configuration file.
initializationstep
Builds the DQ-initialized ramp model and resultant timing metadata from the L1 input.
saturation_check
Flags saturated pixels in a 3D cube.
subtract_dark_current
Subtracts dark current in a 3D cube.
repackage_wcs
Packages a WCS so that it can be handed to romanisim.
calibrateimage
L1->L2 driver.
"""
import contextlib
import sys
import types
import warnings  # noqa: F401

import asdf

# not actually doing a simulation but needed to pass around the WCS types
import galsim  # noqa: F401
import numpy as np
import yaml
from astropy import units as u
from astropy.io import fits
from roman_datamodels import datamodels
from roman_datamodels.dqflags import group, pixel
from romancal.datamodels.fileio import open_dataset
from romancal.dq_init import dq_initialization
from romancal.saturation import saturation
from romanisim import image as rimage
from romanisim import persistence as rip
from romanisim import wcs as riwcs

from .. import pars
from ..utils import (
    coordutils,
    fitting,
    flatutils,
    ipc_linearity,
    maskhandling,
    processlog,
    reference_subtraction,
    sky,
    typefix,
)

# local imports
from . import oututils
### function definitions below here
[docs]
def wcs_from_config(config):
    """
    Read the WCS referenced by a configuration dictionary.

    Only the ``FITSWCS`` key is understood at present: its value is the path
    of a text file containing a FITS header (e.g. exported from a simulation).

    Parameters
    ----------
    config : dict
        Configuration dictionary (usually imported from YAML).

    Returns
    -------
    astropy.io.fits.header.Header or None
        The WCS as a FITS header, or None when the configuration has no
        ``FITSWCS`` entry (the caller has to deal with the missing WCS).
    """
    if "FITSWCS" not in config:
        # no WCS supplied; left for later stages to handle
        return None
    with open(config["FITSWCS"]) as header_file:
        header_text = header_file.read()
    return fits.Header.fromstring(header_text)
[docs]
def _resultant_timing(read_pattern, frame_time):
    """
    Compute per-resultant read counts and effective times.

    Uses the definitions of Casertano et al. (2022): for a resultant whose N
    reads occur at times t_1 <= ... <= t_N,

        tbar = (1/N) * sum_k t_k
        tau  = (1/N**2) * sum_k (2N - 2k + 1) * t_k

    with t_k = (read index) * frame_time. For consecutive read indices
    t0, t0+1, ..., t0+N-1 these reduce to the closed forms
    tbar = (t0 + (N-1)/2) * frame_time and
    tau = (t0 + (N-1)(2N-1)/(6N)) * frame_time; the general sums also handle
    resultants with gaps between their reads.

    Parameters
    ----------
    read_pattern : sequence of sequences of int
        Read indices making up each resultant.
    frame_time : float
        Time per frame.

    Returns
    -------
    n_reads : np.ndarray (int16)
        Number of reads N in each resultant.
    tbar : np.ndarray (float32)
        Mean read time of each resultant.
    tau : np.ndarray (float32)
        Variance-weighted effective time of each resultant.

    Raises
    ------
    ValueError
        If a resultant contains no reads.
    """
    ngrp = len(read_pattern)
    n_reads = np.zeros(ngrp, dtype=np.int16)
    tbar = np.zeros(ngrp, dtype=np.float32)
    tau = np.zeros(ngrp, dtype=np.float32)
    for i, reads in enumerate(read_pattern):
        t = np.sort(np.asarray(reads, dtype=np.float64)) * frame_time
        n = t.size
        if n == 0:
            raise ValueError(f"resultant {i} of read_pattern contains no reads")
        # weight (2N - 2k + 1) for k = 1..N
        weights = 2 * n - 2 * np.arange(1, n + 1) + 1
        n_reads[i] = n
        tbar[i] = t.mean()
        tau[i] = np.dot(weights, t) / n**2
    return n_reads, tbar, tau


def initializationstep(config, caldir, mylog, exclude_first=False):
    """
    Initialization step: build the DQ-initialized ramp model and timing metadata.

    Opens the L1 file ``config["IN"]`` and runs romancal's DQ initialization
    on it, using the mask reference file ``caldir["mask"]`` when present.
    Every file opened here is closed before returning, including when DQ
    initialization raises.

    Parameters
    ----------
    config : dict
        Configuration dictionary (usually imported from YAML).
    caldir : dict
        Locations of calibration files.
    mylog : romanimpreprocess.utils.processlog.ProcessLog
        Processing log (not currently written to by this step).
    exclude_first : bool
        If True, mark the first resultant as DO_NOT_USE in groupdq.

    Returns
    -------
    ramp_model : RampModel
        Ramp data model including data, groupdq, pixeldq, metadata.
    meta : dict
        Timing metadata: "frame_time", "read_pattern", "ngrp" (number of
        resultants), and per-resultant arrays "tbar", "tau" and "N".

    Raises
    ------
    ValueError
        If a resultant in the read pattern contains no reads.
    """
    with contextlib.ExitStack() as stack:
        mask = None
        if "mask" in caldir:
            maskfile = stack.enter_context(asdf.open(caldir["mask"]))
            mask = datamodels.MaskRefModel.create_from_model(maskfile["roman"])
        l1model = stack.enter_context(open_dataset(config["IN"], update_version=True))
        # the mask file stays open until DQ initialization has consumed it
        ramp_model = dq_initialization.do_dqinit(l1model, mask, expand_gw_flagging=1)

    frame_time = ramp_model.meta.exposure.frame_time
    read_pattern = ramp_model.meta.exposure["read_pattern"]
    # N_i, tbar_i, and tau_i as defined in Casertano et al. 2022
    n_reads, tbar, tau = _resultant_timing(read_pattern, frame_time)
    meta = {
        "frame_time": frame_time,
        "read_pattern": read_pattern,
        "ngrp": len(read_pattern),
        "tbar": tbar,
        "tau": tau,
        "N": n_reads,
    }

    if exclude_first:
        ramp_model["groupdq"][0, ...] |= group.DO_NOT_USE
    return ramp_model, meta
[docs]
def saturation_check(ramp_model, caldir, mylog, backup=1, skip_firstn=1, n_pix_grow_sat=1):
    """
    Flag saturated resultants in ``ramp_model`` using romancal.

    Wraps ``romancal.saturation.saturation.flag_saturation`` with the
    saturation reference file ``caldir["saturation"]``. The first
    ``skip_firstn`` resultants are hidden from the check by temporarily
    truncating ``ramp_model.data``, ``ramp_model.groupdq`` and the exposure
    read pattern. The flags found for the remaining resultants are copied back
    into the full groupdq array, and the full arrays are restored afterwards,
    even if flagging raises. pixeldq is not restored, so any change
    ``flag_saturation`` makes to it is kept.

    Parameters
    ----------
    ramp_model : roman_datamodels.datamodels.RampModel
        Data model including the resultant cube; its DQ is updated in place.
    caldir : dict
        Locations of calibration files (uses "saturation").
    mylog : romanimpreprocess.utils.processlog.ProcessLog
        Processing log (not currently written to by this function).
    backup : int
        Number of resultants to "back up" when flagging saturation.
    skip_firstn : int
        Do not check the first n resultants in ramp_model.data for saturation.
    n_pix_grow_sat : int
        Forwarded to ``flag_saturation`` as ``n_pix_grow_sat`` (default 1,
        the value previously hard-coded here).

    Returns
    -------
    None
    """
    with asdf.open(caldir["saturation"]) as satreffile:
        satref = datamodels.SaturationRefModel.create_from_model(satreffile["roman"])

        if skip_firstn == 0:
            saturation.flag_saturation(ramp_model, satref, n_pix_grow_sat=n_pix_grow_sat, backup=backup)
            return

        old_data = ramp_model.data
        old_dq = ramp_model.groupdq
        old_read_pattern = ramp_model.meta.exposure.read_pattern
        ramp_model.data = old_data[skip_firstn:, ...]
        ramp_model.groupdq = old_dq[skip_firstn:, ...]
        ramp_model.meta.exposure.read_pattern = old_read_pattern[skip_firstn:]
        try:
            saturation.flag_saturation(ramp_model, satref, n_pix_grow_sat=n_pix_grow_sat, backup=backup)
            # flag_saturation may write through the view or assign a brand-new
            # groupdq array; copying back keeps the flags in either case
            old_dq[skip_firstn:, ...] = ramp_model.groupdq
        finally:
            ramp_model.data = old_data
            ramp_model.groupdq = old_dq
            ramp_model.meta.exposure.read_pattern = old_read_pattern
[docs]
def subtract_dark_current(data, rdq, pdq, caldir, meta, mylog):
    """
    Remove the dark current from a linearized resultant cube, in place.

    Group ``j`` of ``data`` has ``meta["tbar"][j]`` times the ``dark_slope``
    image of the dark calibration file subtracted from it. ``rdq``, ``pdq``
    and ``mylog`` are accepted for interface compatibility but are neither
    read nor modified here.

    Parameters
    ----------
    data : np.array
        3D data cube (in DN_lin, shape ngroup,4096,4096); updated in place.
    rdq : np.array
        3D ramp data quality (unused).
    pdq : np.array
        2D pixel data quality (unused).
    caldir : dict
        Locations of calibration files (uses "dark").
    meta : dict
        Timing metadata from ``initializationstep`` (uses "ngrp" and "tbar").
    mylog : romanimpreprocess.utils.processlog.ProcessLog
        Processing log (unused).

    Returns
    -------
    np.array
        Copy of the 2D ``dark_slope`` image that was removed (documented as DN/s).
    """
    with asdf.open(caldir["dark"]) as darkfile:
        dark_rate = np.copy(darkfile["roman"]["dark_slope"])
    tbar = meta["tbar"]
    for grp in range(meta["ngrp"]):
        # dark signal accumulated by the mean read time of this group
        data[grp, :, :] -= tbar[grp] * dark_rate
    return dark_rate
[docs]
def repackage_wcs(thewcs):
    """
    Package a WCS so it can be handed to romanisim.

    Only FITS-standard headers (e.g. from a simulation) are supported right
    now. The result is used for compatibility in the ramp-fitting stage; the
    WCS in the L2 ASDF tree can be overwritten later with a full-accuracy gwcs.

    Parameters
    ----------
    thewcs : astropy.io.fits.Header
        Input WCS.

    Returns
    -------
    types.SimpleNamespace
        Object exposing the header as ``.header.header`` (two layers deep,
        for compatibility with romanisim).

    Raises
    ------
    TypeError
        If ``thewcs`` is not a FITS header (including None, which
        ``wcs_from_config`` returns when no WCS is configured). Previously
        such input made this function loop forever.
    """
    # GalSim WCS input is intentionally not supported. The conversion below is
    # suspected of a bug related to the 0 vs 1 offset and is not currently
    # used; make sure to test it before re-enabling. -C.H. 02/12/26
    #
    # if not isinstance(thewcs, fits.Header):
    #     header = fits.Header()
    #     thewcs.writeToFitsHeader(header, galsim.BoundsI(0, pars.nside_active, 0, pars.nside_active))
    #     # offset to FITS convention -- this is undone later
    #     header["CRPIX1"] += 1
    #     header["CRPIX2"] += 1
    #     warnings.warn("Use of GalSim WCS in calibrate is not fully working yet!")
    #     thewcs = header
    if not isinstance(thewcs, fits.Header):
        raise TypeError(
            f"Unrecognized WCS (type {type(thewcs).__name__}); "
            "a FITS header (config key FITSWCS) is required"
        )
    return types.SimpleNamespace(header=types.SimpleNamespace(header=thewcs))
[docs]
def _reference_pixel_correction(data, amp33, caldir):
    """
    Apply the row and channel reference-pixel corrections to every group, in place.

    For each group j, dark-cube group j (``caldir["dark"]``) is subtracted.
    When the read file (``caldir["read"]``) has an "amp33" entry, the amp33
    reference channel, minus its "med" image and then its own median, is
    placed in the rightmost ``pars.channelwidth`` columns. Then
    ``reference_subtraction.ref_subtraction_row`` and
    ``ref_subtraction_channel`` are applied, and the dark group is added back.

    This is a placeholder (5-pixel filter of the left & right reference
    pixels plus the top & bottom subtraction functions of Laliotis et al.
    2024) until the improved correction from the GSFC group is available.

    Parameters
    ----------
    data : np.ndarray
        Resultant cube indexed [group, y, x]; modified in place.
    amp33 : np.ndarray
        Reference-channel cube indexed by group like ``data``.
    caldir : dict
        Locations of calibration files (uses "dark" and "read").
    """
    cw = pars.channelwidth
    ref_slope = None  # optimal row-correction slope, computed once from the read file
    with asdf.open(caldir["dark"]) as fd, asdf.open(caldir["read"]) as fr:
        dark_cube = fd["roman"]["data"]
        has_amp33 = "amp33" in fr["roman"]
        for j in range(np.shape(data)[0]):
            image = np.zeros((pars.nside, pars.nside_augmented), dtype=np.float32)
            image[:, : pars.nside] = data[j, :, :] - dark_cube[j, :, :]
            if has_amp33:
                image[:, -cw:] = amp33[j, :, :] - fr["roman"]["amp33"]["med"]
                image[:, -cw:] -= np.median(image[:, -cw:])
                if ref_slope is None:
                    a = fr["roman"]["amp33"]
                    cvar = fr["roman"]["anc"]["C_PINK"] ** 2
                    ref_slope = (
                        a["M_PINK"]
                        * cvar
                        / (
                            a["M_PINK"] ** 2 * cvar
                            + a["RU_PINK"] ** 2
                            + np.median(a["std"]) ** 2 / 128 / np.log(4096)
                        )
                    )
            image = reference_subtraction.ref_subtraction_row(image, use_ref_channel=True, slope=ref_slope)
            image = reference_subtraction.ref_subtraction_channel(image, use_ref_channel=True)
            data[j, :, :] = image[:, : pars.nside] + dark_cube[j, :, :]


def _slice_end_indices(rdq, ngrp, nb):
    """
    Record, per active pixel, the resultant before saturation is newly flagged.

    For iend = 1..ngrp-1, pixels whose SATURATED bit is set in resultant iend
    but not in iend-1 get the value iend-1 (a later transition overwrites an
    earlier one); pixels with no such transition keep -1.

    Parameters
    ----------
    rdq : np.ndarray
        Group DQ cube indexed [group, y, x].
    ngrp : int
        Number of resultants.
    nb : int
        Border width trimmed from each edge.

    Returns
    -------
    np.ndarray (int8)
        Array of shape (pars.nside_active, pars.nside_active).

    Raises
    ------
    ValueError
        If ``ngrp >= 128`` (indices are stored as int8).
    """
    if ngrp >= 128:
        raise ValueError("too many groups")
    endslice = np.zeros((pars.nside_active, pars.nside_active), dtype=np.int8) - 1
    for iend in range(1, ngrp):
        newly_saturated = (rdq[iend, nb:-nb, nb:-nb] & ~rdq[iend - 1, nb:-nb, nb:-nb] & pixel.SATURATED) != 0
        endslice = np.where(newly_saturated, iend - 1, endslice)
    return endslice


def _write_quicklook_fits(im2, out_path):
    """
    Write a quick-look FITS file next to the ASDF output.

    HDUs: the image data, the DQ array, and the data with pixels masked by
    ``maskhandling.PixelMask1`` set to -1000. The file name is ``out_path``
    with a trailing ".asdf" (if any) replaced by "_asdf_to.fits".
    """
    good = ~maskhandling.PixelMask1.build(im2["dq"])  # this is one choice
    # note we accept saturated pixels in this step
    stem = out_path[:-5] if out_path.lower().endswith(".asdf") else out_path
    fits.HDUList(
        [
            fits.PrimaryHDU(im2["data"]),
            fits.ImageHDU(im2["dq"]),
            fits.ImageHDU(np.where(good, im2["data"], -1000)),
        ]
    ).writeto(stem + "_asdf_to.fits", overwrite=True)


def calibrateimage(config, verbose=True):
    """
    Run the L1 -> L2 calibration chain described by a configuration dictionary.

    Steps, in order: DQ initialization, saturation flagging, reference-pixel
    correction, optional bias correction, linearity correction, dark-current
    subtraction, optional IPC correction, ramp fitting, flat fielding
    (including the pixel-area factor from the WCS), optional sky subtraction,
    and output. The L2 image plus a "processinfo" tree is written as ASDF to
    ``config["OUT"]``. If ``config["FITSOUT"]`` is truthy, a quick-look FITS
    file is also written.

    Parameters
    ----------
    config : dict
        Configuration (likely unpacked from a YAML file). Keys used, directly
        or via helpers: "CALDIR" (dict of calibration file locations), "IN",
        "OUT", and optionally "FITSWCS", "SATURATION_BACKUP",
        "RAMP_OPT_PARS", "JUMP_DETECT_PARS", "SKYORDER", "SLICEOUT", "FITSOUT".
    verbose : bool, optional
        If True (default), print the cal_step status and the processing log
        to the terminal.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If "SLICEOUT" is requested with 128 or more resultants.
    """
    # setup
    mylog = processlog.ProcessLog()
    # get an initial WCS (if provided)
    # in some simulations we may need to give this if the input stars themselves are simulated
    thewcs = wcs_from_config(config)
    caldir = config["CALDIR"]
    backup = config.get("SATURATION_BACKUP", 1)

    # initialize a data cube and data quality
    ramp_model, meta = initializationstep(config, caldir, mylog)
    nb = meta["nborder"] = pars.nborder
    mylog.append("Initialized data\n")

    # saturation check
    saturation_check(ramp_model, caldir, mylog, backup=backup)
    mylog.append("Saturation check complete\n")

    data = ramp_model["data"]
    rdq = ramp_model["groupdq"]
    pdq = ramp_model["pixeldq"]
    l1meta = ramp_model.meta
    amp33 = ramp_model["amp33"]
    ngrp = np.shape(data)[0]

    # reference pixel correction (placeholder; see _reference_pixel_correction)
    _reference_pixel_correction(data, amp33, caldir)

    # bias correction
    if "biascorr" in caldir:
        with asdf.open(caldir["biascorr"]) as f:
            data[:, nb:-nb, nb:-nb] -= f["roman"]["data"]
        mylog.append("Included bias correction\n")
    else:
        mylog.append("Skipping bias correction\n")

    # linearity correction
    # ** right now applies the linearity to a group average, which isn't strictly correct **
    # ** will fix this in a future upgrade! **
    data, dq_lin = ipc_linearity.multilin(
        data,
        caldir["linearitylegendre"],  # the linearity cube
        # don't flag the first read for being off scale if it is the reset
        do_not_flag_first=meta["read_pattern"][0] == [0],
        # don't flag saturated pixels as having a bad linearity correction
        attempt_corr=~rdq & pixel.SATURATED,
    )
    # a 2D flag map broadcasts over all groups; a 3D one applies per group
    rdq |= dq_lin
    del dq_lin  # we have everything we need
    mylog.append("Linearity correction complete\n")

    # now data is in linearized DN, floating point
    # subtract out dark current (data is updated in place)
    subtract_dark_current(data, rdq, pdq, caldir, meta, mylog)
    mylog.append("Dark current subtracted\n")

    # IPC correction
    if "ipc4d" in caldir:
        ipc_linearity.correct_cube(data, caldir["ipc4d"], mylog, gain_file=caldir["gain"])
    else:
        mylog.append("skipping IPC correction\n")

    # ramp fitting
    uopt = config.get("RAMP_OPT_PARS", {"slope": 0.4, "gain": 1.8, "sigma_read": 6.5})
    u_ = float(uopt["slope"]) / float(uopt["gain"]) / float(uopt["sigma_read"]) ** 2
    meta["K"] = fitting.construct_weights(u_, meta, exclude_first=True)
    mylog.append(f"\n\nRamp fit optimized for u = {u_:11.5E} s**-1\n")
    mylog.append("weights = {}\n".format(meta["K"]))
    if "JUMP_DETECT_PARS" in config:
        meta["jump_detect_pars"] = config["JUMP_DETECT_PARS"]
    slope, slope_err_read, slope_err_poisson = fitting.ramp_fit(
        data, rdq, pdq, meta, caldir, mylog, exclude_first=True
    )

    # apply flat field
    flat = flatutils.get_flat(caldir, meta, pdq)
    imwcs = repackage_wcs(thewcs)
    # this is the ratio of the true pixel area to the reference area (0.11 arcsec)^2
    area_factor = (
        coordutils.pixelarea(riwcs.convert_wcs_to_gwcs(imwcs), N=np.shape(slope)[-1])
        / pars.Omega_ideal
    )
    flat = (flat / area_factor).astype(np.float32)
    mylog.append("acquired flat field\n")
    for p in [1, 2, 5, 10, 25, 50, 75, 90, 95, 98, 99]:
        mylog.append(f" {p:2d}%ile = {np.percentile(flat, p):6.4f},")
    mylog.append("\n")
    slope /= flat
    slope_err_read /= flat
    slope_err_poisson /= flat

    # need the median gain to send to a file
    with asdf.open(caldir["gain"]) as g_:
        medgain = np.median(g_["roman"]["data"])
    mylog.append(f"median gain = {medgain:8.5f} e/DN\n")

    # blank persistence object right now
    persistence = rip.Persistence()

    # sky information
    slope_withsky = np.copy(slope)  # version before sky subtraction
    m = maskhandling.PixelMask1.build(pdq)
    medsky, _ = sky.smooth_mode(sky.binkxk(np.where(np.logical_not(m), slope, np.nan), 4))
    # if the configuration asks for simple subtraction, do it
    if "SKYORDER" in config:
        skyorder = int(config["SKYORDER"])
        skycoefs, skymodel = sky.medfit(slope[nb:-nb, nb:-nb], order=skyorder)
        slope[nb:-nb, nb:-nb] -= skymodel
        del skymodel
    else:
        skycoefs = np.array([], dtype=np.float32)
        skyorder = -1  # not used

    im2, extras2 = rimage.make_asdf(
        slope[nb:-nb, nb:-nb] * u.DN / u.s,
        (slope_err_read[nb:-nb, nb:-nb] * u.DN / u.s) ** 2,
        (slope_err_poisson[nb:-nb, nb:-nb] * u.DN / u.s) ** 2,
        metadata=l1meta,
        persistence=persistence,
        dq=pdq[nb:-nb, nb:-nb],
        imwcs=imwcs,
        gain=medgain,
    )
    # strip unit from certain fields if not needed
    for x in ["data", "var_poisson", "var_rnoise", "var_flat", "err"]:
        if x in im2 and hasattr(im2[x], "value"):
            im2[x] = im2[x].value
    oututils.add_in_ref_data(im2, config["IN"], rdq, pdq)

    # update the metadata
    # oututils.update_flags(im2, "gen_cal_image") # <-- this doesn't work with updated roman_datamodels,
    # but it isn't essential
    if "cal_step" in im2["meta"]:
        im2["meta"]["cal_step"]["wfi18_transient"] = "INCOMPLETE"
        im2["meta"]["cal_step"]["dark_decay"] = "INCOMPLETE"
    oututils.add_in_provenance(im2, "gen_cal_image")

    # process information specific to this code
    processinfo = {
        "medsky": medsky,
        "medgain": medgain,
        "skyorder": skyorder,
        "skycoefs": skycoefs,
        "ramp_opt_pars": uopt,
        "meta": meta,
        "weights": meta["K"],
        "config": config,
        "log": mylog.output,
        "exclude_first": True,
    }
    # ramp data range used per pixel (max 127 groups, stored as int8)
    if config.get("SLICEOUT"):
        processinfo["endslice"] = _slice_end_indices(rdq, ngrp, nb)

    # Write file
    with asdf.AsdfFile() as af2:
        af2.tree = {"roman": im2, "processinfo": processinfo}
        af2.tree["roman"]["data_withsky"] = slope_withsky[nb:-nb, nb:-nb]
        if hasattr(af2.tree["roman"]["data_withsky"], "value"):
            af2.tree["roman"]["data_withsky"] = af2.tree["roman"]["data_withsky"].value
        if verbose:
            if "cal_step" in af2.tree["roman"]["meta"]:
                print(af2.tree["roman"]["meta"]["cal_step"])
            else:
                print("cal_step not in roman->meta")
        typefix.fix(af2)
        with open(config["OUT"], "wb") as f:
            af2.write_to(f)

    if config.get("FITSOUT"):
        _write_quicklook_fits(im2, config["OUT"])

    if verbose:
        print(mylog.output)
if __name__ == "__main__":
    # usage: python <this module> config.yaml
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} config.yaml")
    with open(sys.argv[1]) as f:
        config = yaml.safe_load(f)
    calibrateimage(config)