# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Xija - framework to model complex time-series data using a network of
coupled nodes with pluggable model components that define the node
interactions.
"""
from __future__ import print_function
import ctypes
import json
import os
from collections import OrderedDict
from io import StringIO
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import pyyaks.context as pyc
import Ska.DBI
import Ska.Numpy
from astropy.io import ascii
# Optional packages for model fitting or use on HEAD LAN
from Chandra.Time import DateTime
from cxotime import date2secs
from . import clogging, component, tmal
from .files import files as xija_files
# HDF5 version of commanded states table
H5FILE = "/proj/sot/ska/data/cmd_states/cmd_states.h5"

# Reuse the pyyaks "src" context dict if one is already registered,
# otherwise create it.
if "src" in pyc.CONTEXT:
    src = pyc.CONTEXT["src"]
else:
    src = pyc.ContextDict("src")

# Same pattern for the file context, rooted at the current working directory.
# NOTE(review): the lookup key is "file" but the dict is created under the
# name "files" -- confirm against pyyaks.context whether this is intended.
if "file" in pyc.CONTEXT:
    files = pyc.CONTEXT["file"]
else:
    files = pyc.ContextDict("files", basedir=os.getcwd())
files.update(xija_files)

# Debugging hook: only active when a global named ``debug`` already exists
# at the time this module body executes.
if "debug" in globals():
    from IPython.Debugger import Tracer

    pdb_settrace = Tracer()

logger = clogging.config_logger("xija", level=clogging.INFO)

# Default (and maximum allowed) model time step in seconds
DEFAULT_DT = 328.0

# Allowed time steps expressed as fractions of DEFAULT_DT
dt_factors = np.array([1.0, 0.5, 0.25, 0.2, 0.125, 0.1, 0.05, 0.025])
def convert_type_star_star(array, ctype_type):
    """Return a ctypes array holding one pointer per row of ``array``.

    Each element of the result is a ``POINTER(ctype_type)`` obtained from
    the corresponding row's ``.ctypes.data_as``, so the result can be passed
    to a C function expecting a ``ctype_type **`` argument.

    Parameters
    ----------
    array :
        sequence of rows, each exposing a numpy-style ``.ctypes`` attribute
    ctype_type :
        ctypes scalar type of the row elements (e.g. ``ctypes.c_double``)

    Returns
    -------
    ctypes array of ``len(array)`` row pointers
    """
    row_ptr_type = ctypes.POINTER(ctype_type)
    row_ptrs = [row.ctypes.data_as(row_ptr_type) for row in array]
    ptr_array_type = row_ptr_type * len(array)
    return ptr_array_type(*row_ptrs)
def _get_bad_times_indices(
times: np.ndarray,
datestart: str,
datestop: str,
bad_times_in: List[Tuple[str, str]],
) -> Tuple[List[Tuple[str, str]], List[Tuple[int, int]]]:
"""Return bad_times_indices into ``times`` for elements in the
``bad_times_in`` list that overlap with the ``datestart`` to ``datestop``.
NOTE: bad_times_in is a list of [datestart, datestop] lists. The "time"
name is unfortunate since it has string dates, not float CXCsec times.
:returns: bad_times_indices: List[List[int, int]]
"""
if bad_times_in is None or len(bad_times_in) == 0:
return []
bad_times: np.ndarray = np.array(bad_times_in)
# Get inclusive overlap of bad_times with datestart to datestop
ok = (bad_times[:, 1] > datestart) & (bad_times[:, 0] < datestop)
if np.any(ok):
bad_times = bad_times[ok]
bad_times_secs = date2secs(bad_times)
idxs = np.searchsorted(times, bad_times_secs)
ok = idxs[:, 0] < idxs[:, 1]
bad_times_indices = idxs[ok].tolist()
else:
bad_times_indices = []
return bad_times_indices
class FetchError(Exception):
    """Xija-specific exception type.

    NOTE(review): presumably intended for data-fetch failures; it is not
    raised anywhere in the code visible in this module.
    """
class XijaModel(object):
    """Xija model class to encapsulate all ModelComponents and provide the
    infrastructure to define and evaluate models.

    The parameters ``name``, ``start``, and ``stop`` are determined as follows:

    - If a model specification is provided then that sets the default values
      for keywords that are not supplied to the class init call.
    - ``evolve_method = 1`` uses the original ODE solver which treats every
      two steps as a full RK2 step.
    - ``evolve_method = 2`` uses the new ODE solver which treats every step
      as a full RK2 step, and optionally allows for RK4 if ``rk4 = 1``.
    - Otherwise defaults are: ``name='xijamodel'``, ``start = stop - 45 days``,
      ``stop = NOW - 30 days``, ``dt = 328 secs``, ``evolve_method = 1``,
      ``rk4 = 0``

    Parameters
    ----------
    name :
        model name
    start :
        model start time (any DateTime format)
    stop :
        model stop time (any DateTime format)
    dt :
        delta time step (default=328 sec)
    model_spec :
        model specification (None | filename | dict)
    cmd_states :
        commanded states input (None | structured array)
    evolve_method :
        choose method to evolve ODE (None | 1 or 2, default 1)
    rk4 :
        use 4th-order Runge-Kutta to evolve ODE, only works with
        evolve_method == 2 (None | 0 or 1, default 0)
    limits :
        dict of limit values (None | dict)
    """

    def __init__(
        self,
        name=None,
        start=None,
        stop=None,
        dt=None,
        model_spec=None,
        cmd_states=None,
        evolve_method=None,
        rk4=None,
        limits=None,
    ):
        # If model_spec is a str or Path then read that file
        if isinstance(model_spec, (str, Path)):
            with open(model_spec, "r") as fh:
                model_spec = json.load(fh)

        # If a model_spec is now available (dict) then use as kwarg defaults
        if model_spec:
            stop = stop or model_spec["datestop"]
            start = start or model_spec["datestart"]
            name = name or model_spec["name"]
            dt = dt or model_spec["dt"]
            evolve_method = evolve_method or model_spec.get("evolve_method", None)
            # ``0`` is a valid explicit rk4 value and an explicitly supplied
            # ``limits`` must win over the spec, so test against None here
            # instead of relying on truthiness.
            if rk4 is None:
                rk4 = model_spec.get("rk4", None)
            if limits is None:
                limits = model_spec.get("limits", {})

        if stop is None:
            stop = DateTime() - 30
        if start is None:
            start = DateTime(stop) - 45
        if name is None:
            name = "xijamodel"
        if dt is None:
            dt = DEFAULT_DT
        if evolve_method is None:
            evolve_method = 1
        if rk4 is None:
            rk4 = 0
        if limits is None:
            limits = {}

        self.name = name
        self.comp = OrderedDict()
        self.dt = self._get_allowed_timestep(dt)
        self.dt_ksec = self.dt / 1000.0
        self.times = self._eng_match_times(start, stop)
        self.tstart = self.times[0]
        self.tstop = self.times[-1]
        self.ksecs = (self.times - self.tstart) / 1000.0
        self.datestart = DateTime(self.tstart).date
        self.datestop = DateTime(self.tstop).date
        self.n_times = len(self.times)
        self.evolve_method = evolve_method
        self.rk4 = rk4
        self.limits = limits

        self.bad_times = [] if (model_spec is None) else model_spec.get("bad_times", [])
        self.bad_times_indices = _get_bad_times_indices(
            self.times, self.datestart, self.datestop, self.bad_times
        )
        # This is really setting the mask times for the first
        # time in this case
        self.reset_mask_times()

        self.pars = []
        if model_spec:
            self._set_from_model_spec(model_spec)

        self.cmd_states = cmd_states

    def _get_allowed_timestep(self, dt):
        """Snap ``dt`` to the nearest allowed time step.

        This ensures that only certain timesteps are chosen, which are
        integer multiples of 8.2 and where 328.0/dt is an integer.

        Parameters
        ----------
        dt :
            requested time step (sec)

        Returns
        -------
        ``DEFAULT_DT`` if ``dt`` exceeds it (a warning is logged), otherwise
        ``DEFAULT_DT`` times the entry of ``dt_factors`` closest to
        ``dt / DEFAULT_DT``
        """
        if dt > DEFAULT_DT:
            logger.warning(
                "dt = %g s greater than upper limit of %g s! " % (dt, DEFAULT_DT)
                + "Setting dt = %g s." % DEFAULT_DT
            )
            return DEFAULT_DT

        dt_factor = dt / DEFAULT_DT
        idx = np.argmin(np.abs(dt_factor - dt_factors))
        dt = DEFAULT_DT * dt_factors[idx]
        logger.debug("Using dt = %g s." % dt)
        return dt

    def _set_from_model_spec(self, model_spec):
        """Add the components listed in ``model_spec`` and then copy every
        attribute of each spec parameter onto the matching model parameter.

        Raises
        ------
        ValueError
            if the number of spec parameters differs from the number of
            parameters created by the components
        """
        for comp in model_spec["comps"]:
            ComponentClass = getattr(component, comp["class_name"])
            args = comp["init_args"]
            kwargs = {str(k): v for k, v in comp["init_kwargs"].items()}
            self.add(ComponentClass, *args, **kwargs)

        pars = model_spec["pars"]
        if len(pars) != len(self.pars):
            raise ValueError(
                "Number of spec pars does not match model: \n{0}\n{1}".format(
                    len(pars), len(self.pars)
                )
            )
        for par, specpar in zip(self.pars, pars):
            for attr in specpar:
                setattr(par, attr, specpar[attr])

    def inherit_from_model_spec(self, inherit_spec):
        """Inherit parameter values from any like-named parameters within the
        inherit_spec model specification. This is useful for making a new
        variation of an existing model.

        Parameters
        ----------
        inherit_spec :
            model specification, either a dict or the name of a JSON file
        """
        # Only path-like inputs are read from disk; anything else is used
        # directly as the spec. A missing or malformed file now raises here
        # instead of being silently ignored.
        if isinstance(inherit_spec, (str, bytes, os.PathLike)):
            with open(inherit_spec, "r") as fh:
                inherit_spec = json.load(fh)

        inherit_pars = {par["full_name"]: par for par in inherit_spec["pars"]}
        for par in self.pars:
            if par.full_name in inherit_pars:
                logger.info("Inheriting par {}".format(par.full_name))
                inherit_par = inherit_pars[par.full_name]
                par.val = inherit_par["val"]
                par.min = inherit_par["min"]
                par.max = inherit_par["max"]
                par.frozen = inherit_par["frozen"]
                par.fmt = inherit_par["fmt"]

    def _eng_match_times(self, start, stop):
        """Return an array of times between ``start`` and ``stop`` at
        ``self.dt`` sec intervals. The times are roughly aligned (within
        1 sec) to the timestamps in the '5min' (328 sec) Ska eng archive data.

        Parameters
        ----------
        start :
            start time (any DateTime format)
        stop :
            stop time (any DateTime format)

        Returns
        -------
        ndarray of times, each equal to a fixed reference time plus an
        integer multiple of ``self.dt``
        """
        time0 = 410270764.0
        i0 = int((DateTime(start).secs - time0) / self.dt) + 1
        i1 = int((DateTime(stop).secs - time0) / self.dt)
        return time0 + np.arange(i0, i1) * self.dt

    def _get_cmd_states(self):
        """Return the commanded states at the model times.

        If no states have been set, they are fetched from kadi over the
        model date range on first access and cached.
        """
        if not hasattr(self, "_cmd_states"):
            import kadi.commands.states as kadi_states

            logger.info(
                "Getting kadi commanded states over %s to %s"
                % (self.datestart, self.datestop)
            )
            states = kadi_states.get_states(self.datestart, self.datestop)
            self._cmd_states = kadi_states.interpolate_states(
                states, self.times
            ).as_array()
        return self._cmd_states

    def _set_cmd_states(self, states):
        """Set the states that define component data inputs.

        Setting ``None`` is a no-op, which leaves the lazy kadi lookup in
        the getter in effect.

        Parameters
        ----------
        states :
            numpy structured array with ``tstart`` and ``tstop`` fields

        Raises
        ------
        ValueError
            if the states do not extend strictly beyond both ends of the
            model times
        """
        if states is not None:
            if (
                states[0]["tstart"] >= self.times[0]
                or states[-1]["tstop"] <= self.times[-1]
            ):
                raise ValueError(
                    "cmd_states time range too small:\n{} : {} versus {} : {}".format(
                        states[0]["tstart"],
                        states[-1]["tstop"],
                        self.times[0],
                        self.times[-1],
                    )
                )

            # Select, for each model time, the state whose tstop is the
            # first one at or after that time.
            indexes = np.searchsorted(states["tstop"], self.times)
            self._cmd_states = states[indexes]

    cmd_states = property(
        _get_cmd_states,
        _set_cmd_states,
        doc="Commanded states sampled at the model times",
    )

    def fetch(self, msid, attr="vals", method="linear"):
        """Get data from the Chandra engineering archive.

        Parameters
        ----------
        msid :
            MSID name to fetch
        attr :
            attribute of the fetched MSID object to interpolate
            (Default value = 'vals')
        method :
            interpolation method passed to ``Ska.Numpy.interpolate``
            (Default value = 'linear')

        Returns
        -------
        values of ``attr`` interpolated onto the model times

        Raises
        ------
        ValueError
            if cheta is not importable or the fetched telemetry does not
            span the model start and stop times
        """
        # Pad the fetch interval by five default time steps on each side
        tpad = DEFAULT_DT * 5.0
        datestart = DateTime(self.tstart - tpad).date
        datestop = DateTime(self.tstop + tpad).date
        logger.info("Fetching msid: %s over %s to %s" % (msid, datestart, datestop))

        try:
            import cheta.fetch_sci as fetch
        except ImportError as err:
            raise ValueError("cheta.fetch not available") from err

        tlm = fetch.MSID(msid, datestart, datestop, stat="5min")
        tlm.filter_bad_times()

        if tlm.times[0] > self.tstart or tlm.times[-1] < self.tstop:
            raise ValueError(
                "Fetched telemetry does not span model start and "
                "stop times for {}".format(msid)
            )
        vals = Ska.Numpy.interpolate(
            getattr(tlm, attr), tlm.times, self.times, method=method
        )
        return vals

    def interpolate_data(self, data, times, comp=None):
        """Interpolate supplied ``data`` values at the model times using
        nearest-neighbor or state value interpolation.

        The ``times`` arg can be either a 1-d or 2-d ndarray. If 1-d,
        then ``data`` is interpreted as a set of values at the specified
        ``times``. If 2-d then ``data`` is interpreted as a set of binned
        state values with ``tstarts = times[0, :]`` and
        ``tstops = times[1, :]``. If ``times`` is None then ``data`` must
        already have one value per model time and is returned unchanged.

        Parameters
        ----------
        data :
            data values
        times :
            None, 1-d ndarray of times, or 2-d ndarray of (tstarts, tstops)
        comp :
            component, used only in error messages (Default value = None)

        Returns
        -------
        data values at the model times

        Raises
        ------
        ValueError
            on a length mismatch, model times outside the state times, or
            ``times`` having other than 1 or 2 dimensions
        """
        if times is None:
            if len(data) != self.n_times:
                raise ValueError(
                    "Data length not equal to model times for {} component".format(comp)
                )
            return data

        if len(data) != times.shape[-1]:
            raise ValueError(
                "Data length not equal to data times for {} component".format(comp)
            )

        if times.ndim == 1:  # Data value specification
            vals = Ska.Numpy.interpolate(data, times, self.times, method="nearest")
        elif times.ndim == 2:  # State-value specification
            tstarts = times[0]
            tstops = times[1]
            if self.times[0] < tstarts[0] or self.times[-1] > tstops[-1]:
                raise ValueError(
                    "Model times extend outside the state value"
                    " data_times for component {}".format(comp)
                )
            indexes = np.searchsorted(tstops, self.times)
            vals = data[indexes]
        else:
            raise ValueError(
                "data_times for {} has {} dimensions, must be either 1 or 2".format(
                    comp, times.ndim
                )
            )
        return vals

    def add(self, ComponentClass, *args, **kwargs):
        """Add a new component to the model

        Parameters
        ----------
        ComponentClass :
            component class, instantiated as
            ``ComponentClass(self, *args, **kwargs)``
        *args :
            positional init arguments for the component
        **kwargs :
            keyword init arguments for the component

        Returns
        -------
        the new component instance
        """
        comp = ComponentClass(self, *args, **kwargs)
        # Store args and kwargs used to initialize object for later object
        # storage and re-creation
        comp.init_args = args
        comp.init_kwargs = kwargs
        self.comp[comp.name] = comp
        self.pars.extend(comp.pars)
        return comp

    @property
    def comps(self):
        """List of model components"""
        return list(self.comp.values())

    def get_comp(self, name):
        """Get a model component. Works with either a string or a component
        object

        Parameters
        ----------
        name :
            component name, an object whose ``str()`` is the name, or None

        Returns
        -------
        the component, or None if ``name`` is None
        """
        return None if name is None else self.comp[str(name)]

    @property
    def model_spec(self):
        """Generate a full model specification data structure for this
        model

        Returns
        -------
        dict with the model attributes, parameters (each converted with
        ``dict(par)``), limits, bad times and the init arguments of every
        component
        """
        model_spec = dict(
            name=self.name,
            comps=[],
            dt=self.dt,
            datestart=self.datestart,
            datestop=self.datestop,
            tlm_code=None,
            mval_names=[],
            evolve_method=self.evolve_method,
            rk4=self.rk4,
        )
        model_spec["bad_times"] = self.bad_times
        model_spec["pars"] = [dict(par) for par in self.pars]
        model_spec["limits"] = self.limits

        def stringfy(x):
            # Components used as init arguments are stored by their str() form
            return str(x) if isinstance(x, component.ModelComponent) else x

        for comp in self.comps:
            init_args = [stringfy(x) for x in comp.init_args]
            init_kwargs = {k: stringfy(v) for k, v in comp.init_kwargs.items()}
            model_spec["comps"].append(
                dict(
                    class_name=comp.__class__.__name__,
                    name=comp.name,
                    init_args=init_args,
                    init_kwargs=init_kwargs,
                )
            )
        return model_spec

    def write_vals(self, filename):
        """Write dvals and mvals for each model component (as applicable) to an
        ascii table file. Some component have neither (couplings), some have
        just dvals (TelemData), others have both (Node, AcisDpaPower).
        Everything is guaranteed to be time synced, so write a single time
        column.

        Parameters
        ----------
        filename :
            output file name
        """
        colvals = OrderedDict(time=self.times)
        for comp in self.comps:
            if hasattr(comp, "dvals"):
                colvals[comp.name + "_data"] = comp.dvals
            if hasattr(comp, "mvals") and comp.predict:
                colvals[comp.name + "_model"] = comp.mvals
        ascii.write(colvals, filename, names=list(colvals.keys()))

    def write(self, filename, model_spec=None):
        """Write the model specification as JSON or Python to a file.

        If the file name ends with ".py" then the output will the Python
        code to create the model (using ``get_model_code()``), otherwise
        the JSON model specification will be written.

        Parameters
        ----------
        filename :
            output filename (str or Path)
        model_spec :
            model spec structure (optional) (Default value = None). Only
            used for JSON output.
        """
        if model_spec is None:
            model_spec = self.model_spec

        # str() so that Path objects are accepted as well as plain strings
        is_python = str(filename).endswith(".py")
        with open(filename, "w") as f:
            if is_python:
                f.write(self.get_model_code())
            else:
                json.dump(model_spec, f, sort_keys=True, indent=4)

    def get_model_code(self):
        """Return Python code that will create the current model.

        This is useful during model development as a way to derive from and
        modify existing models while retaining good parameter values.

        Returns
        -------
        string of Python code
        """
        out = StringIO()
        ms = self.model_spec

        # The comma after evolve_method and the closing paren are required
        # for the generated call to be valid Python.
        model_call = "model = xija.XijaModel({}, start={}, stop={}, dt={},\n"
        model_call += "evolve_method={}, rk4={})\n"

        print("import sys", file=out)
        print("import xija\n", file=out)
        print(
            model_call.format(
                repr(ms["name"]),
                repr(ms["datestart"]),
                repr(ms["datestop"]),
                repr(ms["dt"]),
                repr(ms["evolve_method"]),
                repr(ms["rk4"]),
            ),
            file=out,
        )

        for comp in ms["comps"]:
            args = [repr(x) for x in comp["init_args"]]
            kwargs = [
                "{}={}".format(k, repr(v)) for k, v in comp["init_kwargs"].items()
            ]
            print("model.add(xija.{},".format(comp["class_name"]), file=out)
            for arg in args:
                print(" {},".format(arg), file=out)
            for kwarg in kwargs:
                print(" {},".format(kwarg), file=out)
            print(" )", file=out)

        parattrs = ("val", "min", "max", "fmt", "frozen")
        last_comp_name = None
        for par in ms["pars"]:
            comp_name = par["comp_name"]
            # Emit the component lookup once per run of same-component pars
            if comp_name != last_comp_name:
                print("# Set {} component parameters".format(comp_name), file=out)
                print("comp = model.get_comp({})\n".format(repr(comp_name)), file=out)
            print("par = comp.get_par({})".format(repr(par["name"])), file=out)
            par_upds = ["{}={}".format(attr, repr(par[attr])) for attr in parattrs]
            print("par.update(dict({}))\n".format(", ".join(par_upds)), file=out)
            last_comp_name = comp_name

        print("model.bad_times = {}".format(repr(self.bad_times)), file=out)
        print("if len(sys.argv) > 1:", file=out)
        print(" model.write(sys.argv[1])", file=out)
        return out.getvalue()

    def _get_parvals(self):
        """Return the values of all model parameters as a tuple."""
        return tuple(par.val for par in self.pars)

    def _set_parvals(self, vals):
        """Set the full list of parameter values. No provision is made for
        setting individual elements or slicing (use self.pars directly in this
        case).

        Parameters
        ----------
        vals :
            sequence of values, one per model parameter

        Raises
        ------
        ValueError
            if ``vals`` does not have one value per parameter
        """
        if len(vals) != len(self.pars):
            raise ValueError(
                "Length mismatch setting parvals {} vs {}".format(
                    len(self.pars), len(vals)
                )
            )
        for par, val in zip(self.pars, vals):
            par.val = val

    parvals = property(_get_parvals, _set_parvals, doc="Tuple of parameter values")

    @property
    def parnames(self):
        """Tuple of the ``full_name`` of every model parameter."""
        return tuple(par.full_name for par in self.pars)

    def make(self):
        """Call self.make_mvals and self.make_tmal to prepare for model evaluation
        once all model components have been added.
        """
        self.make_mvals()
        self.make_tmal()

    def make_mvals(self):
        """Initialize the global mvals (model values) array. This is an
        N (rows) x n_times (cols) array that contains all data needed
        to compute the model prediction. All rows are initialized to
        relevant data values (e.g. node temps, time-dependent power,
        external temperatures, etc). In the model calculation some
        rows will be overwritten with predictions.
        """
        # Select components with data values, and from those select ones that
        # get predicted and those that do not get predicted
        comps = [x for x in self.comps if x.n_mvals]
        preds = [x for x in comps if x.predict]
        unpreds = [x for x in comps if not x.predict]

        # Register the location of component mvals in global mvals
        # (predicted components first)
        i = 0
        for comp in preds + unpreds:
            comp.mvals_i = i
            i += comp.n_mvals

        # Stack the input dvals. This *copies* the data values.
        self.n_preds = len(preds)
        self.mvals = np.hstack([comp.dvals for comp in preds + unpreds])
        self.mvals.shape = (len(comps), -1)  # why doesn't this use vstack?
        # View of every second time sample
        self.cvals = self.mvals[:, 0::2]

    def make_tmal(self):
        """Make the TMAL "code" using components that generate TMAL
        statements
        """
        for comp in self.comps:
            comp.update()

        tmal_comps = [x for x in self.comps if hasattr(x, "tmal_ints")]
        self.tmal_ints = np.zeros((len(tmal_comps), tmal.N_INTS), dtype=np.int32)
        self.tmal_floats = np.zeros((len(tmal_comps), tmal.N_FLOATS), dtype=np.float64)
        # One zero-padded row per TMAL component
        for i, comp in enumerate(tmal_comps):
            self.tmal_ints[i, 0 : len(comp.tmal_ints)] = comp.tmal_ints
            self.tmal_floats[i, 0 : len(comp.tmal_floats)] = comp.tmal_floats

    def calc(self):
        """Calculate the model. The results appear in the self.mvals array."""
        self.make_tmal()

        # int calc_model(int n_times, int n_preds, int n_tmals, float dt,
        #                float **mvals, int **tmal_ints, float **tmal_floats)
        mvals = convert_type_star_star(self.mvals, ctypes.c_double)
        tmal_ints = convert_type_star_star(self.tmal_ints, ctypes.c_int)
        tmal_floats = convert_type_star_star(self.tmal_floats, ctypes.c_double)

        if self.evolve_method == 1:
            # Method 1 is called with twice the model time step
            dt = self.dt_ksec * 2
            self.core_1.calc_model_1(
                self.n_times,
                self.n_preds,
                len(self.tmal_ints),
                dt,
                mvals,
                tmal_ints,
                tmal_floats,
            )
        elif self.evolve_method == 2:
            dt = self.dt_ksec
            self.core_2.calc_model_2(
                self.rk4,
                self.n_times,
                self.n_preds,
                len(self.tmal_ints),
                dt,
                mvals,
                tmal_ints,
                tmal_floats,
            )

        # hackish fix to ensure last value is a computed value for predicted components
        self.mvals[: self.n_preds, -1] = self.mvals[: self.n_preds, -2]

        # Apply Delay components after the model calculation
        for comp in self.comps:
            if isinstance(comp, component.Delay) and comp.delay != 0.0:
                # Note: starting from index 0 creates an instability in xija_gui_fit,
                # so just copy from index 1.
                comp.node.mvals[1:] = np.interp(
                    x=self.times - comp.delay * 1000, xp=self.times, fp=comp.node.mvals
                )[1:]

    def calc_stat(self):
        """Calculate model fit statistic as the sum of component fit stats"""
        self.calc()  # parvals already set with dummy_calc
        fit_stat = sum(comp.calc_stat() for comp in self.comps if comp.predict)
        return fit_stat

    def calc_staterror(self, data):
        """Calculate model fit statistic error (dummy array for Sherpa use)

        Parameters
        ----------
        data :
            array-like whose shape sets the shape of the result

        Returns
        -------
        array of ones shaped like ``data``
        """
        return np.ones_like(data)

    @property
    def date_range(self):
        """String of the form ``<start>_<stop>`` built from the first seven
        characters of the GRETA-format model start and stop dates."""
        return "%s_%s" % (
            DateTime(self.tstart).greta[:7],
            DateTime(self.tstop).greta[:7],
        )

    @property
    def core_1(self):
        """Lazy-load the "core_1" ctypes shared object libary that does the
        low-level model calculation via the C "calc_model_1" routine. Only
        load once by setting/returning a class attribute.
        """
        if not hasattr(XijaModel, "_core_1"):
            loader_path = os.path.abspath(os.path.dirname(__file__))
            _core_1 = np.ctypeslib.load_library("core_1", loader_path)
            _core_1.calc_model_1.restype = ctypes.c_int
            _core_1.calc_model_1.argtypes = [
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_double,
                ctypes.POINTER(ctypes.POINTER(ctypes.c_double)),
                ctypes.POINTER(ctypes.POINTER(ctypes.c_int)),
                ctypes.POINTER(ctypes.POINTER(ctypes.c_double)),
            ]
            XijaModel._core_1 = _core_1
        return XijaModel._core_1

    @property
    def core_2(self):
        """Lazy-load the "core_2" ctypes shared object libary that does the
        low-level model calculation via the C "calc_model_2" routine. Only
        load once by setting/returning a class attribute.
        """
        if not hasattr(XijaModel, "_core_2"):
            loader_path = os.path.abspath(os.path.dirname(__file__))
            _core_2 = np.ctypeslib.load_library("core_2", loader_path)
            _core_2.calc_model_2.restype = ctypes.c_int
            _core_2.calc_model_2.argtypes = [
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_double,
                ctypes.POINTER(ctypes.POINTER(ctypes.c_double)),
                ctypes.POINTER(ctypes.POINTER(ctypes.c_int)),
                ctypes.POINTER(ctypes.POINTER(ctypes.c_double)),
            ]
            XijaModel._core_2 = _core_2
        return XijaModel._core_2

    def append_mask_time(self, new_times, bad=False):
        """Add a masked interval.

        ``new_times`` is a (start, stop) pair in any DateTime format. The
        interval is recorded only if it spans at least one model time.
        ``bad`` is stored alongside it in ``mask_times_bad``.
        """
        t0, t1 = DateTime(new_times).secs
        i0, i1 = np.searchsorted(self.times, [t0, t1])
        if i1 > i0:
            self.mask_times_indices.append((i0, i1))
            self.mask_times_bad = np.append(self.mask_times_bad, bad)

    def append_bad_time(self, new_times):
        """Add a bad-time interval.

        The interval is also added to the mask times (flagged as bad) and
        appended to ``bad_times``; its index range is appended to
        ``bad_times_indices`` if it spans at least one model time.
        """
        self.append_mask_time(new_times, bad=True)
        self.bad_times.append(new_times)
        t0, t1 = DateTime(new_times).secs
        i0, i1 = np.searchsorted(self.times, [t0, t1])
        if i1 > i0:
            self.bad_times_indices.append([i0, i1])

    def reset_mask_times(self):
        """Reset the mask times to a copy of the bad-time index ranges, all
        flagged as bad."""
        self.mask_times_indices = self.bad_times_indices.copy()
        self.mask_times_bad = np.ones(len(self.mask_times_indices), dtype="bool")
# Alias: ``ThermalModel`` refers to the same class object as ``XijaModel``.
ThermalModel = XijaModel