Source code for chandra_limits.utils

# Licensed under a 3-clause BSD style license - see LICENSE
import json

import numpy as np
from astropy.table import Table
from cxotime import CxoTime
from numpy import ma
from numpy.lib.recfunctions import append_fields
from xija.get_model_spec import get_xija_model_spec

# Maps MSID names to the corresponding xija model names. get_xija_model uses
# this to choose which model specification to request, falling back to the
# MSID itself when it has no entry here.
model_name_map = {
    "1dpamzt": "dpa",
    "1pdeaat": "psmc",
    "1deamzt": "dea",
    "fptemp": "acisfp",
    "aacccdpt": "aca",
    "2ceahvpt": "cea",
}


def c_to_f(temp):
    """
    Return the Fahrenheit equivalent of a Celsius temperature.

    Lists and tuples are converted to NumPy arrays first so that the
    arithmetic is applied element-wise; any other input is used as-is.

    :param temp: Temperature(s) in Celsius
    :type temp: int or float or tuple or list or np.ndarray
    :return: Temperature(s) in Fahrenheit (an np.ndarray for list, tuple
        or array input)
    :rtype: int or float or np.ndarray
    """
    celsius = np.array(temp) if isinstance(temp, (list, tuple)) else temp
    return celsius * 1.8 + 32.0


def f_to_c(temp):
    """
    Return the Celsius equivalent of a Fahrenheit temperature.

    Lists and tuples are converted to NumPy arrays first so that the
    arithmetic is applied element-wise; any other input is used as-is.

    :param temp: Temperature(s) in Fahrenheit
    :type temp: int or float or tuple or list or np.ndarray
    :return: Temperature(s) in Celsius (an np.ndarray for list, tuple
        or array input)
    :rtype: int or float or np.ndarray
    """
    fahrenheit = np.array(temp) if isinstance(temp, (list, tuple)) else temp
    return (fahrenheit - 32.0) / 1.8


def get_xija_model(model_spec, msid):
    """
    Return the specification of a single Xija model as a dictionary.

    Parameters
    ----------
    model_spec : str, dict, or None
        Either the model specification itself (a dict, returned
        unchanged), the path to a local JSON model specification file,
        or None to have xija look up the specification for ``msid``.
    msid : str
        The MSID whose model is requested. Only used when ``model_spec``
        is None; it is translated through ``model_name_map`` (falling
        back to the MSID itself) to get the xija model name.

    Returns
    -------
    Model spec as a dictionary
    """
    # An in-memory specification needs no further work.
    if isinstance(model_spec, dict):
        return model_spec

    # No specification given: ask xija for the one matching this MSID.
    if model_spec is None:
        xija_name = model_name_map.get(msid, msid)
        return get_xija_model_spec(xija_name)[0]

    # Otherwise treat the argument as a path to a JSON file.
    with open(model_spec) as spec_file:
        return json.load(spec_file)


def process_states(states):
    """
    Convert commanded states in any supported form to a NumPy array.

    Parameters
    ----------
    states : np.ndarray, astropy.table.Table, or dict-like of arrays
        A NumPy array is returned unchanged, a Table is converted with
        ``as_array()``, and anything else is treated as a mapping from
        column name to a NumPy array.

    Returns
    -------
    np.ndarray
        The states as a NumPy array (structured when built from a Table
        or a mapping).
    """
    if isinstance(states, np.ndarray):
        return states
    if isinstance(states, Table):
        return states.as_array()

    # Build a structured array with one field per mapping entry; the row
    # count is taken from the first column.
    columns = [(name, str(col.dtype)) for name, col in states.items()]
    first_name = columns[0][0]
    structured = np.zeros(states[first_name].size, dtype=columns)
    for name, col in states.items():
        structured[name] = col
    return structured


def states_match_times(states, times):
    """
    Ensure that commanded states are matched up to the model times.

    Parameters
    ----------
    states : np.ndarray, astropy.table.Table, or dict-like of arrays
        Commanded states in any form accepted by ``process_states``.
    times : array-like
        The model times the states must line up with.

    Returns
    -------
    np.ndarray
        Structured array of states. If the input has a "tstop" field but
        no "times" field, the states are resampled to one row per model
        time and ``times`` is added as a "times" field. Otherwise the
        converted states are returned as they are.

    Raises
    ------
    ValueError
        If the states are not resampled and their "times" field is
        missing or does not match ``times``.
    """
    states = process_states(states)
    field_names = states.dtype.names
    if "times" not in field_names and "tstop" in field_names:
        # For every model time, select the first state whose tstop is
        # greater than or equal to that time.
        indexes = np.searchsorted(states["tstop"], times)
        states = states[indexes]
        # append_fields returns a NEW array rather than modifying its
        # argument, so the result must be assigned back. usemask=False
        # keeps the result a plain structured ndarray instead of a
        # masked array.
        states = append_fields(states, "times", times, usemask=False)
    else:
        try:
            matched = np.isclose(states["times"], times).all()
        except ValueError:
            # Raised when there is no "times" field or when the shapes
            # cannot be broadcast together; either way it is a mismatch.
            matched = False
        if not matched:
            raise ValueError("Times in states do not match model times!")
    return states


def states_get_times(states):
    """
    Return the start time of every state followed by the last stop time.

    Parameters
    ----------
    states : np.ndarray
        Structured array of states. If it has a "tstart" field, the
        "tstart" and "tstop" fields are used directly; otherwise the
        "datestart" and "datestop" fields are converted to seconds with
        CxoTime.

    Returns
    -------
    np.ndarray
        All state start times with the stop time of the final state
        appended.
    """
    if "tstart" in states.dtype.names:
        starts = states["tstart"]
        final_stop = states["tstop"][-1]
    else:
        starts = CxoTime(states["datestart"]).secs
        final_stop = CxoTime(states["datestop"][-1]).secs
    return np.append(starts, final_stop)


def pointpair(x, y=None):
    """Interleave two arrays ``x`` and ``y`` into one flat array.

    The result is ``x[0], y[0], x[1], y[1], ...``, which is handy for
    histogram-style plots where ``x`` and ``y`` hold the start and stop of
    each bin.  When ``y`` is omitted, ``x`` is paired with itself so that
    every value appears twice in a row.

    Example::

      x = np.arange(1, 100, 5)
      x0 = x[:-1]
      x1 = x[1:]
      y = np.random.uniform(size=len(x0))
      xpp = pointpair(x0, x1)
      ypp = pointpair(y)
      plot(xpp, ypp)

    :x: left edge value of point pairs
    :y: right edge value of point pairs (optional)
    :rtype: masked array of length 2*len(x) == 2*len(y)
    """
    right = x if y is None else y
    stacked = ma.array([x, right])
    # Column-major flattening walks down each (x[i], y[i]) column in turn,
    # which produces the interleaved ordering.
    return stacked.reshape(-1, order="F")


def plot_viols(ax, viols, color="r", alpha=0.25, **kwargs):
    """
    Add bands to a plot for a list of violations.

    One vertical band is drawn per violation, spanning from its
    "datestart" to its "datestop". All additional keyword arguments are
    passed to ax.axvspan.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to add the bands to.
    viols : list of dict
        A list of violations dictionaries. Each one must provide
        "datestart" and "datestop" entries.
    color : str, optional
        The color of the bands. Default: 'r'
    alpha : float, optional
        The transparency of the bands. Default: 0.25
    """
    for viol in viols:
        ax.axvspan(
            CxoTime(viol["datestart"]).plot_date,
            CxoTime(viol["datestop"]).plot_date,
            color=color,
            alpha=alpha,
            **kwargs,
        )