# Source code for ska_helpers.chandra_models

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Get data from chandra_models repository.
"""
import contextlib
import functools
import hashlib
import os
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Callable, Optional, Union

import git
import requests

from ska_helpers.git_helpers import make_git_repo_safe
from ska_helpers.paths import chandra_models_repo_path
from ska_helpers.utils import LRUDict

# Public API of this module.
__all__ = [
    "chandra_models_cache",
    "get_data",
    "get_repo_version",
    "get_github_version",
]

# GitHub API endpoint queried by get_github_version() for the latest
# chandra_models release.
CHANDRA_MODELS_LATEST_URL = (
    "https://api.github.com/repos/sot/chandra_models/releases/latest"
)

# Environment variables that influence which chandra_models data is returned.
# Their current values are part of the chandra_models_cache() cache key and are
# recorded in the ``info`` dict returned by get_data().
ENV_VAR_NAMES = [
    "CHANDRA_MODELS_REPO_DIR",
    "CHANDRA_MODELS_DEFAULT_VERSION",
    "THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW",
]


def chandra_models_cache(func):
    """Memoize a function that reads chandra_models data.

    Results are stored in a bounded cache (an ``LRUDict`` created with
    ``capacity=32``). The cache key is built from:

    - the positional arguments,
    - the keyword arguments (sorted by name so call order does not matter),
    - the current values of the environment variables listed in
      ``ENV_VAR_NAMES``:

      - CHANDRA_MODELS_REPO_DIR
      - CHANDRA_MODELS_DEFAULT_VERSION
      - THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW

    Because the environment values are part of the key, changing any of these
    variables causes the wrapped function to be evaluated again instead of
    returning a stale result. All arguments must be hashable.

    Example::

        @chandra_models_cache
        def get_aca_spec_info(version=None):
            _, info = get_data("chandra_models/xija/aca/aca_spec.json", version=version)
            return info
    """
    results = LRUDict(capacity=32)

    @functools.wraps(func)
    def cached_func(*args, **kwargs):
        # Snapshot the environment at call time so the key reflects the
        # configuration that the wrapped function will actually see.
        env_snapshot = tuple((var, os.environ.get(var)) for var in ENV_VAR_NAMES)
        kwargs_items = tuple(sorted(kwargs.items()))
        cache_key = (args, kwargs_items, env_snapshot)

        if cache_key not in results:
            results[cache_key] = func(*args, **kwargs)

        return results[cache_key]

    return cached_func


@contextlib.contextmanager
def get_local_repo(repo_path, version):
    """Yield a local ``git.Repo`` for ``repo_path`` and guarantee clean-up.

    The repository is cloned into a fresh temporary directory when
    ``repo_path`` is a GitHub URL or when a specific ``version`` is requested
    (the clone is then checked out at ``version``). Otherwise the repository at
    ``repo_path`` is opened in place.

    On exit the repo object is closed and, if a clone was made, the temporary
    directory is removed. This clean-up runs even if the body of the ``with``
    statement (or the clone / checkout itself) raises an exception.

    Parameters
    ----------
    repo_path : str, Path
        Local directory or URL of the repository.
    version : str, None
        Tag, branch or commit to check out. ``None`` means use the repository
        as-is.

    Yields
    ------
    repo : git.Repo
        Open repository object.
    repo_path_local : str, Path
        Directory containing the working tree (the temporary clone directory
        or ``repo_path`` itself).
    """

    def onerror(func, path, exc_info):
        # Removal failed: make the path writable and retry once. Any remaining
        # failure is reported but not raised (best-effort clean-up).
        os.chmod(path, 0o0777)
        try:
            func(path)
        except Exception as exc:
            print(f"Warning: temp_dir() could not remove {path} because of {exc}")

    clone = str(repo_path).startswith("https://github.com") or version is not None
    repo = None
    repo_path_local = tempfile.mkdtemp() if clone else repo_path

    # Everything after the temp dir is created is inside try/finally so that a
    # failure while cloning, checking out, or in the caller's ``with`` body
    # does not leak the open repo or the temporary directory.
    try:
        if clone:
            repo = git.Repo.clone_from(repo_path, repo_path_local)
            if version is not None:
                repo.git.checkout(version)
        else:
            repo = git.Repo(repo_path)
            make_git_repo_safe(repo_path)

        yield repo, repo_path_local
    finally:
        if repo is not None:
            repo.close()
        if clone:
            shutil.rmtree(repo_path_local, onerror=onerror)


def get_data(
    file_path: str | Path,
    version: Optional[str] = None,
    repo_path: Optional[str | Path] = None,
    require_latest_version: bool = False,
    timeout: int | float = 5,
    read_func: Optional[Callable] = None,
    read_func_kwargs: Optional[dict] = None,
) -> tuple:
    """
    Get data from chandra_models repository.

    There are three environment variables that impact the behavior:

    - ``CHANDRA_MODELS_REPO_DIR`` or ``THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW``:
      override the default root for the chandra_models repository
    - ``CHANDRA_MODELS_DEFAULT_VERSION``: override the default repo version. You can set
      this to a fixed version in unit tests (e.g. with ``monkeypatch``), or set to a
      development branch to test a model file update with applications like yoshi where
      specifying a version would require a long chain of API updates.

    ``THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW`` is used to define the chandra_models repository
    location when running in the MATLAB tools software environment.  If this environment
    variable is set, then the git is_dirty() check of the chandra_models directory is skipped
    as the chandra_models repository is verified via SVN in the MATLAB tools software environment.
    Users in the FOT Matlab tools should exercise caution if using locally-modified files
    for testing, as the version information reported by this function in that case will not
    be correct.


    Examples
    --------
    First we read the model specification for the ACA model. The ``get_data()`` function
    returns the text of the model spec so we need to use ``json.loads()`` to convert it
    to a dict.
    ::

        >>> import json
        >>> from astropy.io import fits
        >>> from ska_helpers import chandra_models

        >>> txt, info = chandra_models.get_data("chandra_models/xija/aca/aca_spec.json")
        >>> model_spec = json.loads(txt)
        >>> model_spec["name"]
        'aacccdpt'

    Next we read the acquisition probability model image. Since the image is a gzipped
    FITS file we need to use a helper function to read it.
    ::

        >>> def read_fits_image(file_path):
        ...     with fits.open(file_path) as hdus:
        ...         out = hdus[1].data
        ...     return out, file_path
        ...
        >>> acq_model_image, info = chandra_models.get_data(
        ...     "chandra_models/aca_acq_prob/grid-floor-2018-11.fits.gz",
        ...     read_func=read_fits_image
        ... )
        >>> acq_model_image.shape
        (141, 31, 7)

    Now let's get the version of the chandra_models repository::

        >>> chandra_models.get_repo_version()
        '3.47'

    Finally get version 3.30 of the ACA model spec from GitHub. The use of a lambda
    function to read the JSON file is compact but not recommended for production code.
    ::

        >>> model_spec_3_30, info = chandra_models.get_data(
        ...     "chandra_models/xija/aca/aca_spec.json",
        ...     version="3.30",
        ...     repo_path="https://github.com/sot/chandra_models.git",
        ...     read_func=lambda fn: (json.load(open(fn, "rb")), fn),
        ... )
        >>> model_spec_3_30 == model_spec
        False

    Parameters
    ----------
    file_path : str, Path
        Path of the data file relative to the root of the chandra_models repository
    version : str
        Tag, branch or commit of chandra_models to use (default=latest tag from repo).
        If the ``CHANDRA_MODELS_DEFAULT_VERSION`` environment variable is set then this
        is used as the default. This is useful for testing.
    repo_path : str, Path
        Path to directory or URL containing chandra_models repository (default is
        ``$SKA/data/chandra_models`` or either of the ``CHANDRA_MODELS_REPO_DIR`` or
        ``THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW`` environment variables if set).
    require_latest_version : bool
        Require that ``version`` matches the latest release on GitHub
    timeout : int, float
        Timeout (sec) for querying GitHub for the expected chandra_models version.
        Default = 5 sec.
    read_func : callable
        Optional function to read the data file. This function must take the file path
        as its first argument and return a tuple of ``(data, file_path_used)``. If not
        provided then read the file as a text file.
    read_func_kwargs : dict
        Optional dict of kwargs to pass to ``read_func``.

    Returns
    -------
    data : object
        File contents: the text of the file if ``read_func`` is not provided, otherwise
        the first element of the tuple returned by ``read_func``.
    info : dict
        Information about the request with keys ``call_args``, ``version``, ``commit``,
        ``data_file_path``, ``repo_path``, ``md5`` and one key for each of the
        environment variables listed above.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist in the repository.
    ValueError
        If ``require_latest_version`` is True and the version does not match the
        latest GitHub release, or if the default version is requested from a dirty
        local repository.
    """
    # Information about this request.
    info = {
        "call_args": {
            "file_path": str(file_path),
            "version": version,
            "repo_path": str(repo_path),
            "require_latest_version": require_latest_version,
            "timeout": timeout,
            "read_func": str(read_func),
            "read_func_kwargs": read_func_kwargs,
        }
    }

    if repo_path is None:
        repo_path = chandra_models_repo_path()

    if version is None:
        version = os.environ.get("CHANDRA_MODELS_DEFAULT_VERSION")

    # NOTE code in xija.get_model_spec.get_repo_version() which is there to handle the
    # fact that a few files are in the repo with permissions 0755 while on Parallels
    # windows they are 0644, so the tree is always dirty.
    # TODO: just fix the repo permissions.
    #
    # with temp_directory() as repo_path_local:
    #     if platform.system() == 'Windows':
    #         repo = git.Repo.clone_from(repo_path, repo_path_local)
    #     else:
    #         repo = git.Repo(repo_path)

    # Potentially work in a clone of the repo in a temporary directory, but only if
    # necessary. In particular:
    # - If the repo is remote on GitHub then we always clone to a temp dir
    # - If the repo is local and the version is not the default then we clone to a temp
    #   to allow checking out at the specified version.
    # This is all done with a context manager that ensures the repo object is
    # properly closed and that all temporary files are cleaned up. Doing this
    # on Windows was challenging. Search on slack:
    # "The process cannot access the file because it is being used"
    with get_local_repo(repo_path, version) as (repo, repo_path_local):
        repo_file_path = Path(repo_path_local) / file_path
        if not repo_file_path.exists():
            raise FileNotFoundError(f"chandra_models {file_path=} does not exist")

        if version is None:
            # This also ensures that the repo is not dirty.
            version = get_repo_version(repo=repo)

        if require_latest_version:
            assert_latest_version(version, timeout)

        if read_func is None:
            data = repo_file_path.read_text()
        else:
            # read_func() returns the data and the actual file path used. This is useful
            # for file globs where the file path may be a glob pattern (specified in
            # the read_func_kwargs).
            data, used_file_path = read_func(
                repo_file_path, **(read_func_kwargs or {})
            )
            # Accept either a str or a Path from read_func so that the checksum
            # below works in both cases.
            repo_file_path = Path(used_file_path)

        # Compute the MD5 sum of repo_file_path. Carriage returns are stripped
        # first, presumably so the checksum does not depend on CRLF vs LF line
        # endings (e.g. Windows checkouts).
        file_bytes = repo_file_path.read_bytes().replace(b"\r", b"")
        md5 = hashlib.md5(file_bytes).hexdigest()

        # Record details of what was actually read for this request.
        info.update(
            {
                "version": version,
                "commit": repo.head.commit.hexsha,
                "data_file_path": str(repo_file_path),
                "repo_path": str(repo_path),
                "md5": md5,
            }
        )
        for name in ENV_VAR_NAMES:
            info[name] = os.environ.get(name)

    return data, info


def assert_latest_version(version, timeout):
    """Check that ``version`` is the latest chandra_models release on GitHub.

    Parameters
    ----------
    version : str
        Version (tag name) to compare against the latest GitHub release.
    timeout : int, float
        Timeout (sec) passed to ``get_github_version()``.

    Raises
    ------
    ValueError
        If the latest GitHub release tag differs from ``version``.

    Notes
    -----
    If ``get_github_version()`` returns ``None`` the check cannot be made; in
    that case a warning is issued and no exception is raised.
    """
    latest = get_github_version(timeout=timeout)

    if latest is None:
        # Indeterminate answer from GitHub: warn instead of failing.
        message = (
            "Could not verify GitHub chandra_models release tag "
            f"due to timeout ({timeout} sec)"
        )
        warnings.warn(message)
        return

    if version == latest:
        return

    raise ValueError(f"version mismatch: local repo {version} vs github {latest}")


def get_repo_version(
    repo_path: Optional[Path] = None, repo: Optional[git.Repo] = None
) -> str:
    """Return version (most recent tag) of models repository.

    Parameters
    ----------
    repo_path : Path, optional
        Path to the models repository. Only used when ``repo`` is not given;
        defaults to ``chandra_models_repo_path()``.
    repo : git.Repo, optional
        Already-open repository object. If provided it is used as-is and is
        left open for the caller.

    Returns
    -------
    str
        Version (most recent tag) of models repository

    Raises
    ------
    ValueError
        If the repository working tree is dirty (this check is skipped when
        the ``THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW`` environment variable is
        set).
    """
    # Only close the repo on exit if it was opened here; a repo passed in by
    # the caller remains the caller's responsibility.
    owns_repo = repo is None
    if owns_repo:
        if repo_path is None:
            repo_path = chandra_models_repo_path()
        repo = git.Repo(repo_path)

    try:
        # Use the THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW environment variable as a proxy
        # to determine if we are running in the MATLAB tools software environment. If so
        # the repo will be checked via SVN and using is_dirty() would change the .git/index
        # and cause SVN to mark the directory as modified. So skip is_dirty() in this case.
        if os.environ.get("THERMAL_MODELS_DIR_FOR_MATLAB_TOOLS_SW") is None:
            if repo.is_dirty():
                raise ValueError("repo is dirty")

        # "Most recent" is defined by the commit date of the tagged commit.
        tags = sorted(repo.tags, key=lambda tag: tag.commit.committed_datetime)
        tag_repo = tags[-1]

        return tag_repo.name
    finally:
        if owns_repo:
            repo.close()
def get_github_version(
    url: str = CHANDRA_MODELS_LATEST_URL, timeout: Union[int, float] = 5
) -> Optional[str]:
    """Get latest chandra_models GitHub repo release tag (version).

    This queries GitHub for the latest release of chandra_models.

    Parameters
    ----------
    url : str
        URL for latest chandra_models release on GitHub API
    timeout : int, float
        Request timeout (sec, default=5)

    Returns
    -------
    str, None
        Tag name (str) of latest chandra_models release on GitHub. None if the
        request timed out, indicating indeterminate answer.

    Raises
    ------
    requests.HTTPError
        If GitHub responds with an HTTP error status.
    """
    try:
        req = requests.get(url, timeout=timeout)
    except (requests.ConnectTimeout, requests.ReadTimeout):
        # A timeout is reported as "unknown" rather than as an error.
        return None

    if req.status_code != requests.codes.ok:
        req.raise_for_status()

    page_json = req.json()
    return page_json["tag_name"]