Module riid.models.base

This module contains functionality shared across all PyRIID models.

Expand source code Browse git
# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS,
# the U.S. Government retains certain rights in this software.
"""This module contains functionality shared across all PyRIID models."""
import json
import os
from pathlib import Path
import uuid
from abc import abstractmethod
from enum import Enum

import numpy as np
import tensorflow as tf
import tf2onnx
from keras.api.models import Model
from keras.api.utils import get_custom_objects

import riid
from riid import SampleSet, SpectraState
from riid.data.labeling import label_to_index_element
from riid.losses import mish
from riid.metrics import multi_f1, single_f1

get_custom_objects().update({
    "multi_f1": multi_f1,
    "single_f1": single_f1,
    "mish": mish,
})


class ModelInput(int, Enum):
    """Enumerates the potential input sources for a model."""
    GrossSpectrum = 0
    BackgroundSpectrum = 1
    ForegroundSpectrum = 2


class PyRIIDModel:
    """Base class for PyRIID models."""

    def __init__(self, *args, **kwargs):
        self._info = {}
        self._temp_file_path = "temp_model.json"
        self._custom_objects = {}
        self._initialize_info()

    @property
    def seeds(self):
        return self._info["seeds"]

    @seeds.setter
    def seeds(self, value):
        self._info["seeds"] = value

    @property
    def info(self):
        return self._info

    @info.setter
    def info(self, value):
        self._info = value

    @property
    def target_level(self):
        return self._info["target_level"]

    @target_level.setter
    def target_level(self, value):
        if value in SampleSet.SOURCES_MULTI_INDEX_NAMES:
            self._info["target_level"] = value
        else:
            msg = (
                f"Target level '{value}' is invalid.  "
                f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}"
            )
            raise ValueError(msg)

    @property
    def model(self) -> Model:
        return self._model

    @model.setter
    def model(self, value: Model):
        self._model = value

    @property
    def model_id(self):
        return self._info["model_id"]

    @model_id.setter
    def model_id(self, value):
        self._info["model_id"] = value

    @property
    def model_inputs(self):
        return self._info["model_inputs"]

    @model_inputs.setter
    def model_inputs(self, value):
        self._info["model_inputs"] = value

    @property
    def model_outputs(self):
        return self._info["model_outputs"]

    @model_outputs.setter
    def model_outputs(self, value):
        self._info["model_outputs"] = value

    def get_model_outputs_as_label_tuples(self):
        return [
            label_to_index_element(v, self.target_level) for v in self.model_outputs
        ]

    def _get_model_dict(self) -> dict:
        model_json = self.model.to_json()
        model_dict = json.loads(model_json)
        model_weights = self.model.get_weights()
        model_dict = {
            "info": self._info,
            "model": model_dict,
            "weights": model_weights,
        }
        return model_dict

    def _get_model_str(self) -> str:
        model_dict = self._get_model_dict()
        model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder)
        return model_str

    def _initialize_info(self):
        init_info = {
            "model_id": str(uuid.uuid4()),
            "model_type": self.__class__.__name__,
            "normalization": SpectraState.Unknown,
            "pyriid_version": riid.__version__,
        }
        self._update_info(**init_info)

    def _update_info(self, **kwargs):
        self._info.update(kwargs)

    def _update_custom_objects(self, key, value):
        self._custom_objects.update({key: value})

    def load(self, model_path: str):
        """Load the model from a path.

        Args:
            model_path: path from which to load the model.
        """
        if not os.path.exists(model_path):
            raise ValueError("Model file does not exist.")

        with open(model_path) as fin:
            model = json.load(fin)

        model_str = json.dumps(model["model"])
        self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects)
        self.model.set_weights([np.array(x) for x in model["weights"]])
        self.info = model["info"]

    def save(self, model_path: str, overwrite=False):
        """Save the model to a path.

        Args:
            model_path: path at which to save the model.
            overwrite: whether to overwrite an existing file if it already exists.

        Raises:
            `ValueError` when the given path already exists
        """
        if os.path.exists(model_path) and not overwrite:
            raise ValueError("Model file already exists.")

        model_str = self._get_model_str()
        with open(model_path, "w") as fout:
            fout.write(model_str)

    def to_onnx(self, model_path, **tf2onnx_kwargs: dict):
        """Convert the model to an ONNX model.

        Args:
            model_path: path at which to save the model
            tf2onnx_kwargs: additional kwargs to pass to the conversion
        """
        model_path = Path(model_path)
        if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION):
            raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}")
        if model_path.exists():
            raise ValueError("Model file already exists.")

        tf2onnx.convert.from_keras(
            self.model,
            input_signature=[
                tf.TensorSpec(
                    shape=input_tensor.shape,
                    dtype=input_tensor.dtype,
                    name=input_tensor.name
                )
                for input_tensor in self.model.inputs
            ],
            output_path=str(model_path),
            **tf2onnx_kwargs
        )

    def to_tflite(self, model_path, quantize: bool = False, prune: bool = False):
        """Convert the model to a TFLite model and optionally applying quantization or pruning.

        Note: requires export to SavedModel format first, then conversion to TFLite occurs.

        Args:
            model_path: file path at which to save the model
            quantize: whether to apply quantization
            prune: whether to apply pruning
        """
        model_path = Path(model_path)
        if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION):
            raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}")
        if model_path.exists():
            raise ValueError("Model file already exists.")

        optimizations = []
        if quantize:
            optimizations.append(tf.lite.Optimize.DEFAULT)
        if prune:
            optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY)

        saved_model_dir = model_path.stem
        self.model.export(saved_model_dir)
        converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
        converter.optimizations = optimizations
        tflite_model = converter.convert()

        with open(model_path, "wb") as fout:
            fout.write(tflite_model)

    @abstractmethod
    def fit(self):
        pass

    @abstractmethod
    def predict(self):
        pass


class PyRIIDModelJsonEncoder(json.JSONEncoder):
    """Custom JSON encoder for saving models.
    """
    def default(self, o):
        """Converts certain types to JSON-compatible types.
        """
        if isinstance(o, np.ndarray):
            return o.tolist()
        elif isinstance(o, np.float32):
            return o.astype(float)

        return super().default(o)

Classes

class ModelInput (value, names=None, *, module=None, qualname=None, type=None, start=1)

Enumerates the potential input sources for a model.

Expand source code Browse git
class ModelInput(int, Enum):
    """Enumerates the potential input sources for a model."""
    GrossSpectrum = 0
    BackgroundSpectrum = 1
    ForegroundSpectrum = 2

Ancestors

  • builtins.int
  • enum.Enum

Class variables

var BackgroundSpectrum
var ForegroundSpectrum
var GrossSpectrum
class PyRIIDModel (*args, **kwargs)

Base class for PyRIID models.

Expand source code Browse git
class PyRIIDModel:
    """Base class for PyRIID models."""

    def __init__(self, *args, **kwargs):
        self._info = {}
        self._temp_file_path = "temp_model.json"
        self._custom_objects = {}
        self._initialize_info()

    @property
    def seeds(self):
        return self._info["seeds"]

    @seeds.setter
    def seeds(self, value):
        self._info["seeds"] = value

    @property
    def info(self):
        return self._info

    @info.setter
    def info(self, value):
        self._info = value

    @property
    def target_level(self):
        return self._info["target_level"]

    @target_level.setter
    def target_level(self, value):
        if value in SampleSet.SOURCES_MULTI_INDEX_NAMES:
            self._info["target_level"] = value
        else:
            msg = (
                f"Target level '{value}' is invalid.  "
                f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}"
            )
            raise ValueError(msg)

    @property
    def model(self) -> Model:
        return self._model

    @model.setter
    def model(self, value: Model):
        self._model = value

    @property
    def model_id(self):
        return self._info["model_id"]

    @model_id.setter
    def model_id(self, value):
        self._info["model_id"] = value

    @property
    def model_inputs(self):
        return self._info["model_inputs"]

    @model_inputs.setter
    def model_inputs(self, value):
        self._info["model_inputs"] = value

    @property
    def model_outputs(self):
        return self._info["model_outputs"]

    @model_outputs.setter
    def model_outputs(self, value):
        self._info["model_outputs"] = value

    def get_model_outputs_as_label_tuples(self):
        return [
            label_to_index_element(v, self.target_level) for v in self.model_outputs
        ]

    def _get_model_dict(self) -> dict:
        model_json = self.model.to_json()
        model_dict = json.loads(model_json)
        model_weights = self.model.get_weights()
        model_dict = {
            "info": self._info,
            "model": model_dict,
            "weights": model_weights,
        }
        return model_dict

    def _get_model_str(self) -> str:
        model_dict = self._get_model_dict()
        model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder)
        return model_str

    def _initialize_info(self):
        init_info = {
            "model_id": str(uuid.uuid4()),
            "model_type": self.__class__.__name__,
            "normalization": SpectraState.Unknown,
            "pyriid_version": riid.__version__,
        }
        self._update_info(**init_info)

    def _update_info(self, **kwargs):
        self._info.update(kwargs)

    def _update_custom_objects(self, key, value):
        self._custom_objects.update({key: value})

    def load(self, model_path: str):
        """Load the model from a path.

        Args:
            model_path: path from which to load the model.
        """
        if not os.path.exists(model_path):
            raise ValueError("Model file does not exist.")

        with open(model_path) as fin:
            model = json.load(fin)

        model_str = json.dumps(model["model"])
        self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects)
        self.model.set_weights([np.array(x) for x in model["weights"]])
        self.info = model["info"]

    def save(self, model_path: str, overwrite=False):
        """Save the model to a path.

        Args:
            model_path: path at which to save the model.
            overwrite: whether to overwrite an existing file if it already exists.

        Raises:
            `ValueError` when the given path already exists
        """
        if os.path.exists(model_path) and not overwrite:
            raise ValueError("Model file already exists.")

        model_str = self._get_model_str()
        with open(model_path, "w") as fout:
            fout.write(model_str)

    def to_onnx(self, model_path, **tf2onnx_kwargs: dict):
        """Convert the model to an ONNX model.

        Args:
            model_path: path at which to save the model
            tf2onnx_kwargs: additional kwargs to pass to the conversion
        """
        model_path = Path(model_path)
        if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION):
            raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}")
        if model_path.exists():
            raise ValueError("Model file already exists.")

        tf2onnx.convert.from_keras(
            self.model,
            input_signature=[
                tf.TensorSpec(
                    shape=input_tensor.shape,
                    dtype=input_tensor.dtype,
                    name=input_tensor.name
                )
                for input_tensor in self.model.inputs
            ],
            output_path=str(model_path),
            **tf2onnx_kwargs
        )

    def to_tflite(self, model_path, quantize: bool = False, prune: bool = False):
        """Convert the model to a TFLite model and optionally applying quantization or pruning.

        Note: requires export to SavedModel format first, then conversion to TFLite occurs.

        Args:
            model_path: file path at which to save the model
            quantize: whether to apply quantization
            prune: whether to apply pruning
        """
        model_path = Path(model_path)
        if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION):
            raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}")
        if model_path.exists():
            raise ValueError("Model file already exists.")

        optimizations = []
        if quantize:
            optimizations.append(tf.lite.Optimize.DEFAULT)
        if prune:
            optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY)

        saved_model_dir = model_path.stem
        self.model.export(saved_model_dir)
        converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
        converter.optimizations = optimizations
        tflite_model = converter.convert()

        with open(model_path, "wb") as fout:
            fout.write(tflite_model)

    @abstractmethod
    def fit(self):
        pass

    @abstractmethod
    def predict(self):
        pass

Subclasses

Instance variables

var info
Expand source code Browse git
@property
def info(self):
    return self._info
var model : keras.src.models.model.Model
Expand source code Browse git
@property
def model(self) -> Model:
    return self._model
var model_id
Expand source code Browse git
@property
def model_id(self):
    return self._info["model_id"]
var model_inputs
Expand source code Browse git
@property
def model_inputs(self):
    return self._info["model_inputs"]
var model_outputs
Expand source code Browse git
@property
def model_outputs(self):
    return self._info["model_outputs"]
var seeds
Expand source code Browse git
@property
def seeds(self):
    return self._info["seeds"]
var target_level
Expand source code Browse git
@property
def target_level(self):
    return self._info["target_level"]

Methods

def fit(self)
Expand source code Browse git
@abstractmethod
def fit(self):
    pass
def get_model_outputs_as_label_tuples(self)
Expand source code Browse git
def get_model_outputs_as_label_tuples(self):
    return [
        label_to_index_element(v, self.target_level) for v in self.model_outputs
    ]
def load(self, model_path: str)

Load the model from a path.

Args

model_path
path from which to load the model.
Expand source code Browse git
def load(self, model_path: str):
    """Load the model from a path.

    Args:
        model_path: path from which to load the model.
    """
    if not os.path.exists(model_path):
        raise ValueError("Model file does not exist.")

    with open(model_path) as fin:
        model = json.load(fin)

    model_str = json.dumps(model["model"])
    self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects)
    self.model.set_weights([np.array(x) for x in model["weights"]])
    self.info = model["info"]
def predict(self)
Expand source code Browse git
@abstractmethod
def predict(self):
    pass
def save(self, model_path: str, overwrite=False)

Save the model to a path.

Args

model_path
path at which to save the model.
overwrite
whether to overwrite an existing file if it already exists.

Raises

ValueError when the given path already exists

Expand source code Browse git
def save(self, model_path: str, overwrite=False):
    """Save the model to a path.

    Args:
        model_path: path at which to save the model.
        overwrite: whether to overwrite an existing file if it already exists.

    Raises:
        `ValueError` when the given path already exists
    """
    if os.path.exists(model_path) and not overwrite:
        raise ValueError("Model file already exists.")

    model_str = self._get_model_str()
    with open(model_path, "w") as fout:
        fout.write(model_str)
def to_onnx(self, model_path, **tf2onnx_kwargs: dict)

Convert the model to an ONNX model.

Args

model_path
path at which to save the model
tf2onnx_kwargs
additional kwargs to pass to the conversion
Expand source code Browse git
def to_onnx(self, model_path, **tf2onnx_kwargs: dict):
    """Convert the model to an ONNX model.

    Args:
        model_path: path at which to save the model
        tf2onnx_kwargs: additional kwargs to pass to the conversion
    """
    model_path = Path(model_path)
    if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION):
        raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}")
    if model_path.exists():
        raise ValueError("Model file already exists.")

    tf2onnx.convert.from_keras(
        self.model,
        input_signature=[
            tf.TensorSpec(
                shape=input_tensor.shape,
                dtype=input_tensor.dtype,
                name=input_tensor.name
            )
            for input_tensor in self.model.inputs
        ],
        output_path=str(model_path),
        **tf2onnx_kwargs
    )
def to_tflite(self, model_path, quantize: bool = False, prune: bool = False)

Convert the model to a TFLite model and optionally applying quantization or pruning.

Note: requires export to SavedModel format first, then conversion to TFLite occurs.

Args

model_path
file path at which to save the model
quantize
whether to apply quantization
prune
whether to apply pruning
Expand source code Browse git
def to_tflite(self, model_path, quantize: bool = False, prune: bool = False):
    """Convert the model to a TFLite model and optionally applying quantization or pruning.

    Note: requires export to SavedModel format first, then conversion to TFLite occurs.

    Args:
        model_path: file path at which to save the model
        quantize: whether to apply quantization
        prune: whether to apply pruning
    """
    model_path = Path(model_path)
    if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION):
        raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}")
    if model_path.exists():
        raise ValueError("Model file already exists.")

    optimizations = []
    if quantize:
        optimizations.append(tf.lite.Optimize.DEFAULT)
    if prune:
        optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY)

    saved_model_dir = model_path.stem
    self.model.export(saved_model_dir)
    converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
    converter.optimizations = optimizations
    tflite_model = converter.convert()

    with open(model_path, "wb") as fout:
        fout.write(tflite_model)
class PyRIIDModelJsonEncoder (*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Custom JSON encoder for saving models.

Constructor for JSONEncoder, with sensible defaults.

If skipkeys is false, then it is a TypeError to attempt encoding of keys that are not str, int, float or None. If skipkeys is True, such items are simply skipped.

If ensure_ascii is true, the output is guaranteed to be str objects with all incoming non-ASCII characters escaped. If ensure_ascii is false, the output can contain non-ASCII characters.

If check_circular is true, then lists, dicts, and custom encoded objects will be checked for circular references during encoding to prevent an infinite recursion (which would cause an RecursionError). Otherwise, no such check takes place.

If allow_nan is true, then NaN, Infinity, and -Infinity will be encoded as such. This behavior is not JSON specification compliant, but is consistent with most JavaScript based encoders and decoders. Otherwise, it will be a ValueError to encode such floats.

If sort_keys is true, then the output of dictionaries will be sorted by key; this is useful for regression tests to ensure that JSON serializations can be compared on a day-to-day basis.

If indent is a non-negative integer, then JSON array elements and object members will be pretty-printed with that indent level. An indent level of 0 will only insert newlines. None is the most compact representation.

If specified, separators should be an (item_separator, key_separator) tuple. The default is (', ', ': ') if indent is None and (',', ': ') otherwise. To get the most compact JSON representation, you should specify (',', ':') to eliminate whitespace.

If specified, default is a function that gets called for objects that can't otherwise be serialized. It should return a JSON encodable version of the object or raise a TypeError.

Expand source code Browse git
class PyRIIDModelJsonEncoder(json.JSONEncoder):
    """Custom JSON encoder for saving models.
    """
    def default(self, o):
        """Converts certain types to JSON-compatible types.
        """
        if isinstance(o, np.ndarray):
            return o.tolist()
        elif isinstance(o, np.float32):
            return o.astype(float)

        return super().default(o)

Ancestors

  • json.encoder.JSONEncoder

Methods

def default(self, o)

Converts certain types to JSON-compatible types.

Expand source code Browse git
def default(self, o):
    """Converts certain types to JSON-compatible types.
    """
    if isinstance(o, np.ndarray):
        return o.tolist()
    elif isinstance(o, np.float32):
        return o.astype(float)

    return super().default(o)