# Source code for jwst.rscd.rscd_sub

"""Functions for the RSCD correction for MIRI science data."""

import logging

import numpy as np
from stdatamodels.jwst.datamodels import dqflags

# Module-level logger, named after this module so messages can be filtered per module.
log = logging.getLogger(__name__)

# Public names of this module (what ``from <module> import *`` exposes).
__all__ = [
    "do_correction",
    "correction_skip_groups",
    "get_rscd_parameters",
    "flag_rscd",
    "apply_rscd_flags",
]


def do_correction(output_model, rscd_model):
    """
    Flag the first groups of each MIRI integration as 'DO_NOT_USE'.

    How many leading groups are flagged comes from the RSCD reference file,
    which holds one value for the first integration and a separate value
    for the second and all later integrations.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Input ramp datamodel
    rscd_model : `~stdatamodels.jwst.datamodels.RSCDModel`
        RSCD reference datamodel

    Returns
    -------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE.
        When the reference file has no entry for this exposure the model
        is returned with ``meta.cal_step.rscd`` set to "SKIPPED".
    """
    # Look up the skip counts matching this exposure's readout pattern / subarray.
    ref_params = get_rscd_parameters(output_model, rscd_model)

    # An empty dictionary means no matching row in the reference table.
    if not ref_params:
        log.warning(
            "READPATT, SUBARRAY combination not found in ref file: RSCD correction will be skipped"
        )
        output_model.meta.cal_step.rscd = "SKIPPED"
        return output_model

    n_skip_first = ref_params["skip_int1"]  # integration 1
    n_skip_later = ref_params["skip_int2p"]  # integration 2 and higher

    # A negative first-integration value marks a deprecated reference file;
    # fall back to skipping a single group.
    if n_skip_first < 0:
        for message in (
            "RSCD reference file is of a deprecated model.",
            "There are no values for first integration",
            "Setting number of groups to skip in first integration to 1",
        ):
            log.warning(message)
        n_skip_first = 1

    log.info(f"# groups from RSCD reference file for int 1 to flag: {n_skip_first}")
    log.info(f"# groups from RSCD reference file for int 2 and higher to flag: {n_skip_later}")

    return correction_skip_groups(output_model, n_skip_first, n_skip_later)
def correction_skip_groups(output, group_skip_int1, group_skip_int2p):
    """
    Flag the leading groups of each integration as DO_NOT_USE for the RSCD effect.

    The requested skip counts are first reduced, if necessary, so that enough
    groups remain in the ramp; the flags are then set separately for the first
    integration and for all later integrations.

    Parameters
    ----------
    output : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged
    group_skip_int1 : int
        Number of groups to skip at the beginning of the ramp for integration 1
    group_skip_int2p : int
        Number of groups to skip at the beginning of the ramp for integration 2
        and higher

    Returns
    -------
    output : `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE
    """
    # Exposure-level parameters.
    exposure = output.meta.exposure
    sci_ngroups = exposure.ngroups
    sci_nints = exposure.nints

    # integration_start is only defined for segmented data; None means the
    # file is not segmented and so starts at integration 1.
    first_int_in_file = exposure.integration_start
    if first_int_in_file is None:
        first_int_in_file = 1

    log.debug(f"RSCD correction using: nints={sci_nints}, ngroups={sci_ngroups}")
    log.debug(f"The first integration in the data is integration: {first_int_in_file}")

    # At least 3 groups are required: the last-frame correction rejects one
    # group and a fit needs 2 valid groups. MIRI normally has >= 5 groups, so
    # this only triggers in rare special cases.
    if sci_ngroups < 3:
        log.warning("Too few groups to apply RSCD correction")
        log.warning("RSCD step will be skipped")
        output.meta.cal_step.rscd = "SKIPPED"
        return output

    # Short ramps (<= 5 groups): ignore the reference values, skip one group.
    if sci_ngroups <= 5:
        group_skip_int1 = 1
        group_skip_int2p = 1
        log.info(
            f"Number of groups to skip for integration 1 (for data with <= 5 groups): "
            f"{group_skip_int1}"
        )
        log.info(
            f"Number of groups to skip for integration 2+ (for data with <= 5 groups): "
            f"{group_skip_int2p}"
        )

    # Largest skip count that still leaves 3 groups in the ramp.
    most_skippable = max(0, sci_ngroups - 3)

    # Integration 1: lower the skip count if it would leave fewer than 3 groups.
    if group_skip_int1 + 3 > sci_ngroups:
        if most_skippable != group_skip_int1:
            log.info(f"Changing the # of groups to skip in int 1 to {most_skippable}")
        group_skip_int1 = most_skippable

    # Integration 2 and higher: same limit, only relevant with several integrations.
    if sci_nints > 1 and group_skip_int2p + 3 > sci_ngroups:
        if most_skippable != group_skip_int2p:
            group_skip_int2p = most_skippable
            log.info(f"Changing the # of groups to skip in int 2 and higher to {most_skippable}")

    # For segmented data the first integration in the file may not be the first
    # integration of the exposure; first_int_in_file holds the exposure-level
    # number of the first integration present in this file.

    # ---- Integration 1 (only present when the file starts at integration 1) ----
    if first_int_in_file == 1:
        first_index = first_int_in_file - 1
        skip_map, n_lowered, n_first_group_kept = flag_rscd(
            output, first_index, first_index, group_skip_int1
        )
        output = apply_rscd_flags(output, first_index, first_index, skip_map)
        log.info(
            "Number of usable bright pixels with rscd flag groups "
            f"not set to DO_NOT_USE: {n_lowered}"
        )
        output.meta.rscd.keep_bright_firstgroup_int1 = n_first_group_kept
        output.meta.rscd.keep_groups_saturation_int1 = n_lowered
        output.meta.rscd.ngroups_skip_int1 = group_skip_int1

    # ---- Integration 2 and higher ----
    # The end of the range comes from the data shape rather than nints because,
    # for segmented data, nints can be much larger than the number of
    # integrations held in this file.
    int_end = output.data.shape[0]
    # A later segment holds no "integration 1", so every integration in the
    # file is treated as integration 2+.
    int_start = 2 if first_int_in_file == 1 else 1

    if sci_nints > 1:
        skip_map, n_lowered, n_first_group_kept = flag_rscd(
            output, int_start - 1, int_end - 1, group_skip_int2p
        )
        output = apply_rscd_flags(output, int_start - 1, int_end - 1, skip_map)
        log.info(
            "Number of usable bright pixels with rscd flag groups "
            f"not set to DO_NOT_USE: {n_lowered}"
        )
        output.meta.rscd.keep_bright_firstgroup_int2p = n_first_group_kept
        output.meta.rscd.keep_groups_saturation_int2p = n_lowered
        output.meta.rscd.ngroups_skip_int2p = group_skip_int2p

    output.meta.cal_step.rscd = "COMPLETE"
    return output
def flag_rscd(output_model, int_start, int_end, rscd_skip):
    """
    Find the initial groups to set to DO_NOT_USE based on RSCD rules.

    Every pixel starts out with ``rscd_skip`` groups to skip. Where that would
    leave a saturated group at the position checked (``rscd_skip + 1``,
    zero-based), the per-pixel skip count is lowered one group at a time until
    the checked group is no longer saturated or the first group is reached.
    Pixels already saturated in the first group are left at ``rscd_skip``.

    This function does not touch ``groupdq``; it only computes the skip counts.
    It does, however, modify ``output_model.pixeldq`` in place: pixels whose
    skip count was lowered get the FLUX_ESTIMATED flag.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged.
    int_start : int
        Starting integration (zero-based index into the first axis of groupdq).
    int_end : int
        Ending integration (zero-based, inclusive).
    rscd_skip : int
        Number of groups to skip at the beginning of the ramp for integration range.

    Returns
    -------
    skip_array : ndarray
        Array of shape (n_ints, y, x) containing the number of groups to skip
        based on pixel location and integration.
    num_rscd_lowered : int
        The number of pixels where the number of RSCD groups to flag as DO_NOT_USE
        was changed because of saturation (counted once per pixel, over all
        integrations in the range).
    num_only_one_group : int
        The number of pixels whose skip count ended up at 0 in at least one
        integration of the range, i.e. the first group is kept.
    """
    # Integration range is inclusive on both ends.
    n_ints = int_end - int_start + 1
    # groupdq is indexed as (integration, group, y, x) below.
    x_dim = output_model.groupdq.shape[3]
    y_dim = output_model.groupdq.shape[2]

    # Start with the nominal skip count for every pixel in every integration.
    skip_array = np.full((n_ints, y_dim, x_dim), rscd_skip)

    # --- If we encounter saturation, we might need to back off the rscd correction.
    # Ideally we want at least two valid groups, but we need to allow there to only
    # be 1 valid group. The user can set the ramp_fit parameter suppress_one_group = False
    # to derive a value for this point.
    min_group = rscd_skip + 2  # Note: min_groups starts count at 1

    # 1. Identify pixels saturated at the current threshold
    #    (min_group is 1-based, hence the "- 1" when indexing).
    is_sat_problem = (
        (
            output_model.groupdq[int_start : int_end + 1, min_group - 1, :, :]
            & dqflags.group["SATURATED"]
        )
        > 0
    ).astype(bool)

    # 2. New check specifically for Group 1. If it is also saturated then we can not
    # recover this pixel.
    is_group_1_sat = (
        (output_model.groupdq[int_start : int_end + 1, 0, :, :] & dqflags.group["SATURATED"]) > 0
    ).astype(bool)

    # 3. Remove Group 1 saturation from the original problem mask
    # This keeps saturation flags ONLY if they are NOT saturated in Group 1
    is_sat_problem &= ~is_group_1_sat

    # Defaults returned when no pixel needs its skip count lowered.
    num_rscd_lowered = 0
    num_only_one_group_pixels = 0
    num_sat = np.sum(is_sat_problem)
    log.info(
        f" There are {num_sat} saturated pixels that require the number of "
        "rscd groups flagged to be lowered"
    )

    # Find the first non-saturating group
    # NOTE(review): everything down to the return statement is assumed to sit
    # inside this ``if``; the nesting was inferred from the zero defaults
    # above - confirm against the upstream file.
    if num_sat > 0:  # do dynamic rscd flagging - based on saturation group of every pixel
        while num_sat > 0 and min_group > 1:
            # subtract 1 from skip_array (never below 0) for the problem pixels only
            skip_array[is_sat_problem] = np.maximum(skip_array[is_sat_problem] - 1, 0)
            min_group = min_group - 1

            # re-evaluate the saturation at the lower group level
            is_sat_problem = (
                (
                    output_model.groupdq[int_start : int_end + 1, min_group - 1, :, :]
                    & dqflags.group["SATURATED"]
                )
                > 0
            ).astype(bool)

            # Re-apply the group 1 guard
            # (Otherwise, if we drop to Group 1, we might process pixels
            # we already deemed "unrecoverable")
            is_sat_problem &= ~is_group_1_sat
            num_sat = is_sat_problem.sum()

        # 1. Identify where the skip_array is less than the original rscd_skip
        # This means the logic was forced to "back off" to accommodate saturation.
        was_backed_off = skip_array < rscd_skip

        # 2. Collapse the 3D mask (Integrations, Y, X) to 2D (Y, X)
        # If a pixel was backed off in ANY integration, we flag it.
        is_backed_off_2d = np.any(was_backed_off, axis=0)
        num_rscd_lowered = is_backed_off_2d.sum()

        # 3. Apply the FLUX_ESTIMATED flag (in-place change to the model's pixeldq)
        if np.any(is_backed_off_2d):
            output_model.pixeldq[is_backed_off_2d] |= dqflags.pixel["FLUX_ESTIMATED"]
            log.info(
                f"Flagged {np.sum(is_backed_off_2d)} pixels as FLUX_ESTIMATED due to RSCD back-off."
            )

        # 4. Final Safety: Reset negative values (with this logic, 0 is the floor)
        skip_array = np.maximum(skip_array, 0)

        # now record if we have to back off all the way to group 1
        # NOTE(review): this counts every pixel with skip_array == 0, so when
        # rscd_skip itself is 0 pixels that were never backed off are counted
        # too - confirm this is intended.
        is_only_one_group = skip_array == 0
        num_only_one_group_pixels = np.any(is_only_one_group, axis=0).sum()

    return skip_array, num_rscd_lowered, num_only_one_group_pixels
def apply_rscd_flags(output_model, int_start, int_end, skip_array):
    """
    Apply flags for RSCD correction setting DO_NOT_USE to the dq values.

    For every integration in the range and every pixel, the first
    ``skip_array[integration, y, x]`` groups of ``groupdq`` get the
    DO_NOT_USE bit set. The model is modified in place and also returned.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged
    int_start : int
        Starting integration (zero-based index into the first axis of groupdq)
    int_end : int
        Ending integration (zero-based, inclusive)
    skip_array : ndarray
        Integer array of shape (n_ints, y, x) giving the number of groups to
        skip at the beginning of the ramp for each pixel of each integration
        in the range. A value of 0 flags nothing.

    Returns
    -------
    output_model : `~stdatamodels.jwst.datamodels.RampModel`
        Ramp datamodel with RSCD affected groups flagged as DO_NOT_USE
    """
    int_range = slice(int_start, int_end + 1)

    # 1. Extract the relevant region of the groupdq array
    #    Shape: (N_ints, Groups, Y, X)
    dq = output_model.groupdq[int_range, :, :, :]

    # 2. Zero-based index of every group, shape (Groups,) -> [0, 1, 2, ...]
    group_indices = np.arange(dq.shape[1])

    # 3. Broadcast (1, Groups, 1, 1) against (N_ints, 1, Y, X) into a 4D mask.
    #    A group is flagged when its zero-based index is below the skip count.
    #    Comparing against skip_array directly (instead of "<= skip_array - 1")
    #    avoids a temporary copy and cannot wrap around to a huge value when
    #    skip_array has an unsigned integer dtype and contains 0.
    mask = group_indices[None, :, None, None] < skip_array[:, None, :, :]

    # 4. Set DO_NOT_USE on the selected groups/pixels only.
    dq[mask] |= dqflags.group["DO_NOT_USE"]

    # Put the modified dq back. With a plain ndarray ``dq`` is a view and this
    # is a no-op; it is kept in case the slice above produced a copy.
    output_model.groupdq[int_range, :, :, :] = dq

    return output_model
def get_rscd_parameters(input_model, rscd_model):
    """
    Look up the RSCD group-skip values that apply to the science exposure.

    The reference table is searched for the first row whose subarray and
    readout pattern both match the science data.

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.RampModel`
        Science data to be flagged
    rscd_model : `~stdatamodels.jwst.datamodels.RSCDModel`
        RSCD reference file data

    Returns
    -------
    param : dict
        Dictionary with keys ``"skip_int1"`` (groups to skip in integration 1)
        and ``"skip_int2p"`` (groups to skip in integration 2 and higher).
        Empty when no table row matches.
    """
    # Readout pattern and subarray of the science exposure.
    readpatt = input_model.meta.exposure.readpatt
    subarray = input_model.meta.subarray.name

    # Older science data use the former name of the MIRI LRS slitless
    # subarray; translate it to the current one before searching the table.
    if subarray.upper() == "SUBPRISM":
        subarray = "SLITLESSPRISM"

    # Table 1 of the reference file: number of groups to skip.
    for row in rscd_model.rscd_group_skip_table:
        row_subarray = row["subarray"]
        row_readpatt = row["readpatt"]
        skip_later = row["group_skip"]  # integration 2 and higher (+)
        skip_first = row["group_skip1"]  # integration 1

        if not (row_subarray == subarray and row_readpatt == readpatt):
            continue

        # First matching row wins.
        return {"skip_int1": skip_first, "skip_int2p": skip_later}

    # No matching row: hand back an empty dictionary.
    return {}