Source code for peat.data.data_utils

from __future__ import annotations

import copy
from collections import ChainMap
from collections.abc import Callable
from operator import attrgetter
from typing import Any

from peat import PeatError, config, log, utils

from .base_model import BaseModel


class DeepChainMap(ChainMap):
    """
    A :class:`collections.ChainMap` whose lookups also layer nested
    :class:`dict` values.

    PEAT uses this to stack sets of device and protocol options
    (configurations). The underlying sources (e.g. a set of global runtime
    defaults) can be modified at any time while the order of precedence is
    preserved: options configured at runtime for a specific device still
    override the global defaults, even if the global defaults were also
    changed at runtime.

    When several layers hold a :class:`dict` under the same key, a lookup
    returns a new chain map over those dicts, so a layer can override
    individual nested keys without replacing the whole parent structure.

    .. code-block:: python

       >>> from peat.data.data_utils import DeepChainMap
       >>> layer1 = {}
       >>> layer2 = {"key": 9999}
       >>> layer3 = {"deep_object": {"deep_key": "The Deep"}}
       >>> deep_map = DeepChainMap(layer1, layer2, layer3)
       >>> deep_map["key"]
       9999
       >>> layer1["key"] = -1111
       >>> deep_map["key"]
       -1111
       >>> layer2["key"]
       9999
       >>> deep_map["deep_object"]["deep_key"]
       'The Deep'
       >>> layer1["deep_object"] = {"another_key": "another_value"}
       >>> deep_map["deep_object"]["deep_key"]
       'The Deep'
       >>> deep_map["deep_object"]["another_key"]
       'another_value'
    """

    def __getitem__(self, key):
        # Gather the value for this key from every layer, highest precedence first
        found = []
        for layer in self.maps:
            try:
                found.append(layer[key])
            except KeyError:
                continue

        if not found:
            return self.__missing__(key)

        head, *rest = found

        # A non-dict winner simply shadows everything below it
        if not isinstance(head, dict):
            return head

        # Only dict values from lower layers can be layered under a dict winner
        nested = [head, *(candidate for candidate in rest if isinstance(candidate, dict))]
        if len(nested) == 1:
            return head

        return self.__class__(*nested)

    def to_dict(self, to_convert: DeepChainMap | None = None) -> dict:
        """
        Flatten the chain map into a plain :class:`dict` copy.

        Args:
            to_convert: chain map to flatten, defaults to this instance

        Returns:
            A new :class:`dict` in which nested :class:`DeepChainMap`
            values have been flattened recursively
        """
        source = self if to_convert is None else to_convert
        return {
            key: self.to_dict(value) if isinstance(value, DeepChainMap) else value
            for key, value in dict(source).items()
        }
def lookup_by_str(container: list[BaseModel], value: BaseModel, lookup: str) -> int | None:
    """
    Find the position of the model in ``container`` whose ``lookup``
    attribute equals the same attribute on ``value``.

    Args:
        container: list of models to search
        value: model providing the value to search for
        lookup: name of the attribute to compare, e.g. ``"ip"`` to
            look up interfaces using the ``Interface.ip`` attribute

    Returns:
        Index of the first matching item, or :obj:`None` if the container
        is empty, ``value`` has no usable (non-default, non-empty) value
        for the attribute, or nothing matched
    """
    if not container:
        return None

    # The attribute must exist on the value and be explicitly set
    if not hasattr(value, lookup) or value.is_default(lookup):
        return None

    target = getattr(value, lookup)
    if target in (None, ""):
        return None

    return find_position(container, lookup, target)
def find_position(obj: list[BaseModel], key: str, value: Any) -> int | None:
    """
    Locate the first item in a :class:`list` whose attribute equals a value.

    Args:
        obj: list of objects to search
        key: name of the attribute to compare on each item
            (items lacking the attribute are treated as :obj:`None`)
        value: value the attribute must be equal to

    Returns:
        Index of the first matching item, or :obj:`None` if there is no match
    """
    matches = (
        position
        for position, candidate in enumerate(obj)
        if getattr(candidate, key, None) == value
    )
    return next(matches, None)
def match_all(obj_list: list[BaseModel], value: dict[str, Any]) -> int | None:
    """
    Search the list for the first object where all values in ``value`` match.

    Each item is converted with ``item.dict(exclude_defaults=True,
    exclude_none=True)`` and compared key by key against ``value``.

    Args:
        obj_list: list of models to search
        value: field names mapped to the values they must equal

    Returns:
        Index of the first item matching every entry in ``value``,
        or :obj:`None` if ``value`` is empty or nothing matched
    """
    if not value:
        return None

    for position, candidate in enumerate(obj_list):
        fields = candidate.dict(exclude_defaults=True, exclude_none=True)
        # Every requested field must equal its corresponding entry in the item
        if all(fields.get(name) == wanted for name, wanted in value.items()):
            return position

    return None
def strip_empty_and_private(
    obj: dict, strip_empty: bool = True, strip_private: bool = True
) -> dict:
    """
    Recursively removes empty values and keys starting with ``_``.

    A new :class:`dict` is built and returned; ``obj`` itself is not
    re-keyed. Note that :class:`set` values are cleaned in place (``""``
    and :obj:`None` are removed from the original set object).

    Args:
        obj: dictionary to filter
        strip_empty: drop values considered empty by ``_is_empty``
        strip_private: drop keys considered private by ``_is_private``

    Returns:
        The filtered copy of ``obj``
    """
    # NOTE(review): the nesting below was reconstructed from a rendering of the
    # source that lost its indentation; confirm it against the original file.
    new = {}
    for key, value in obj.items():
        if strip_private and _is_private(key):
            continue
        elif strip_empty:
            # NOTE: checking type is required to prevent stripping "False", "-1", etc.
            if _is_empty(value):
                continue
            if isinstance(value, dict):
                # Recurse, then drop the key entirely if nothing survived
                stripped = strip_empty_and_private(value, strip_empty, strip_private)
                if strip_empty and not stripped:
                    continue
                else:
                    new[key] = stripped
            elif isinstance(value, list):
                # NOTE: lists are not recursively stripped
                # (only dicts directly inside the list are processed,
                # nested lists are kept as-is)
                new[key] = [
                    (
                        strip_empty_and_private(v, strip_empty, strip_private)
                        if isinstance(v, dict)
                        else v
                    )
                    for v in value
                    if not _is_empty(v)
                ]
            elif isinstance(value, set):
                # This mutates the caller's set object
                for empty_val in ["", None]:
                    if empty_val in value:
                        value.remove(empty_val)
                # Sets left empty after cleaning are dropped
                if value:
                    new[key] = value
            else:
                new[key] = value
        else:
            # Not stripping empties: the value is kept unchanged
            new[key] = value
    return new
def _is_empty(v: Any | None) -> bool: return bool(v is None or (isinstance(v, (str, bytes, dict, list, set)) and not v)) def _is_private(key: Any) -> bool: return bool(isinstance(key, str) and key.startswith("_"))
def strip_key(obj: dict, bad_key: str) -> dict:
    """
    Build a copy of a :class:`dict` with every occurrence of a key removed,
    descending into nested :class:`dict` values.

    .. warning::
       This will NOT strip values out of a :class:`list` of :class:`dict`!

    Args:
        obj: dictionary to filter
        bad_key: key to remove at every nesting level

    Returns:
        A new :class:`dict` without ``bad_key``
    """
    return {
        key: strip_key(value, bad_key) if isinstance(value, dict) else value
        for key, value in obj.items()
        if key != bad_key
    }
def only_include_keys(obj: dict, allowed_keys: str | list[str]) -> dict:
    """
    Keep only the top-level keys that appear in the allowed list of keys.

    Args:
        obj: dictionary to filter
        allowed_keys: a single key, or a list of keys, to keep

    Returns:
        A new :class:`dict` containing only the allowed keys
    """
    # A lone string is a single key, not a sequence of characters
    wanted = [allowed_keys] if isinstance(allowed_keys, str) else allowed_keys
    return {key: value for key, value in obj.items() if key in wanted}
def compare_dicts(d1: dict | None, d2: dict | None, keys: list[str]) -> bool:
    """
    Check that two dicts agree on a set of keys.

    A key is only compared when both dicts have a non-:obj:`None` value
    for it; keys missing or :obj:`None` on either side are ignored.

    Args:
        d1: first dictionary
        d2: second dictionary
        keys: keys to compare

    Returns:
        :obj:`False` if any compared key differs, :obj:`True` otherwise

    Raises:
        PeatError: any of the arguments is empty or :obj:`None`
    """
    if not (d1 and d2 and keys):
        raise PeatError("bad compare_dicts args")

    for key in keys:
        left = d1.get(key)
        right = d2.get(key)

        # Nothing to compare unless both sides have a value
        if left is None or right is None:
            continue

        if left != right:
            return False

    return True
def dedupe_model_list(current: list[BaseModel]) -> list[BaseModel]:
    """
    Deduplicates a :class:`list` of :class:`~peat.data.base_model.BaseModel`
    objects while preserving the original order.

    Models that are a subset of another (contains some keys and values)
    will be merged together and their values combined.

    The input list itself is never modified: a list with fewer than two
    items is returned as-is, otherwise a new list is returned. The models
    in the list may be modified in place when they are merged.

    .. warning::
       This function is expensive to call, ~O(n^2 log n) algorithm
       (in)efficiency. Do not call more than absolutely needed!

    Args:
        current: list of models to deduplicate

    Returns:
        List of deduplicated items

    Raises:
        PeatError: the first item in the list is not a
            :class:`~peat.data.base_model.BaseModel`
    """
    # Don't bother with empty or single-element lists
    if not current or len(current) < 2:
        return current

    if not isinstance(current[0], BaseModel):
        raise PeatError(f"expected BaseModel for dedupe_model_list, got {type(current[0])}")

    # NOTE (cegoes, 02/21/2023)
    #
    # There is a lot going on here. This was originally an atrocious O(n^3) function
    # (actually, it was close to O(n^4) before my first set of optimizations).
    #
    # The main hotspots are:
    # - Nested for loops mean all operations are done twice, O(n^2)
    # - Dict comparisons (==, <) will compare every key and value in the dict, O(n)
    # - Pydantic model comparison is very slow, since it converts
    #   to a dict under the hood every time (yeah...so like O(n) or O(n log n))
    # - Function calls are expensive in Python, and that just adds to the cost of
    #   each iteration of n.
    #
    # Solutions:
    # - Convert all models to dicts at the start. This avoids the issues with Pydantic
    #   converting on every comparison. Additionally, this caches the id() of the
    #   model. The id is used to check if the model is a duplicate, since it's an
    #   int and can be stored in a set, which has O(1) lookups.
    #
    # - Two sets of dicts for the two loops. When a duplicate is found, or a merge occurs,
    #   then the duplicated/merged item is removed from the dict for the inner loop. This
    #   changes O(n^2) to O(n log n), since the inner loop shrinks as the algorithm progresses.
    #   In the case all items are duplicates, then this is close to O(n), while the case where
    #   all items are unique it's closer to O(n^2), but it's a good tradeoff, since we usually
    #   sit somewhere in the middle in PEAT.
    #
    #   When items are merged, the inner dict is updated with the new value, so it can be used
    #   for future comparisons. merge_models() is also called, which handles updating the actual
    #   underlying model in-place, which updates the ultimate result of this function (yay for
    #   classes and pass by reference).
    #
    # - For the subset comparison, use '<' to compare the dict items. dict.items() is a
    #   set-like view object, so it's as fast as we're going to get for the inherently slow
    #   operation of comparing every key and value between two dicts. '<=' is not needed
    #   since '==' is already done before entering the subset comparison section of the
    #   code, which is a minor but notable optimization (~15-20% faster).

    # hack to prevent recursive imports (data_utils.py/models.py)
    model_type = current[0].__repr_name__()  # type: str

    # ids of the models that were found to be duplicates or were merged into another
    duplicates = set()  # type: set[int]

    # Keyed by id(), so the same object appearing twice in the list collapses to one entry
    model_cache = {id(m): m for m in current}  # type: dict[int, BaseModel]
    outer_dicts = {id(m): m.dict(exclude_defaults=True, exclude_none=True) for m in current}  # type: dict[int, dict]
    inner_dicts = copy.deepcopy(outer_dicts)  # type: dict[int, dict]

    # NOTE(review): outer_dicts is not refreshed after a merge, only inner_dicts is,
    # so item_dict may predate values merged into that model - confirm this is intended.
    for item_id, item_dict in outer_dicts.items():
        if item_id in duplicates:
            continue  # outer loop

        for comp_id, comp_dict in inner_dicts.items():
            # Skip if it's in the excluded set or it's the same item
            if comp_id in duplicates or comp_id == item_id:
                continue  # inner loop

            # If they're equal, it's a duplicate
            # (deleting from inner_dicts while iterating over it is only safe
            # because the loop is exited immediately afterwards)
            elif item_dict == comp_dict:
                duplicates.add(item_id)  # add to set of duplicates
                del inner_dicts[item_id]  # remove from future comparisons
                break  # inner loop

            # If the item's key/value pairs are a proper subset of the other
            # model's, then merge the item into that model.
            # If it's a Service on the same port and protocol, also merge when the
            # other model's "status" is "verified", or is "open" while the item's
            # "status" is "closed".
            # Using subset with "dict.items()": https://stackoverflow.com/a/41579450
            elif (item_dict.items() < comp_dict.items()) or (
                model_type == "Service"
                and (
                    comp_dict.get("status") == "verified"
                    or (comp_dict.get("status") == "open" and item_dict.get("status") == "closed")
                )
                and compare_dicts(item_dict, comp_dict, ["port", "protocol"])
            ):
                # Update the underlying model which will be in the results
                merge_models(model_cache[comp_id], model_cache[item_id])

                # Update the cached dict value to use for remaining comparisons
                inner_dicts[comp_id] = model_cache[comp_id].dict(
                    exclude_defaults=True, exclude_none=True
                )

                # Model was merged, so remove it from future checks
                duplicates.add(item_id)  # add to set of duplicates
                del inner_dicts[item_id]  # remove from future comparisons
                break  # inner loop

    # Create a de-duplicated list of objects
    # by excluding those that were marked as duplicate
    deduped = [model for model_id, model in model_cache.items() if model_id not in duplicates]  # type: list[BaseModel]

    if duplicates and config.DEBUG:
        log.trace(
            f"Removed {len(duplicates)} duplicates from list of {len(current)} "
            f"{model_type} items ({len(deduped)} items remaining in list)"
        )

    return deduped
def none_aware_attrgetter(attrs: tuple[str]) -> Callable:
    """
    Build a sort key function like ``operator.attrgetter()`` that copes
    with attribute values that may be :obj:`None`.

    For every attribute the key contains a ``(value is None, value)`` pair,
    flattened into a single tuple.

    Args:
        attrs: names of the attributes to read, in order of significance

    Returns:
        Function that takes an object and returns its key :class:`tuple`
    """

    def g(obj) -> tuple:
        flattened = ()
        for name in attrs:
            current = getattr(obj, name)
            # The flag is placed ahead of the value it describes
            flattened += (current is None, current)
        return flattened

    return g
def sort_model_list(model_list: list[BaseModel]) -> None:
    """
    Sort a :class:`list` of models in place.

    The fields to sort by are read from the ``_sort_by_fields`` attribute
    of the first model in the list. Lists with fewer than two items
    are left untouched.

    Args:
        model_list: list of models to sort

    Raises:
        PeatError: invalid type in list or ``_sort_by_fields``
            is undefined on the model being sorted
    """
    # Nothing to order with fewer than two items
    if not model_list or len(model_list) < 2:
        return

    head = model_list[0]
    if not isinstance(head, BaseModel):
        raise PeatError(f"expected BaseModel for sort_model_list, got {type(head)}")

    sort_fields = getattr(head, "_sort_by_fields", None)
    if not sort_fields:
        raise PeatError(
            f"No '_sort_by_fields' attribute on model class "
            f"'{head.__repr_name__()}' to use for sorting"
        )

    if config.DEBUG >= 3:
        log.debug(f"Sorting '{head.__repr_name__()}' list with {len(model_list)} items")

    key_func = none_aware_attrgetter(sort_fields)
    model_list.sort(key=key_func)
def merge_models(dest: BaseModel, source: BaseModel) -> None:
    """
    Copy values from one model to another.

    Values from ``source`` are merged into ``dest`` in place, field by field:

    - :obj:`None` values, values equal to the destination's, and fields
      that are still at their default on ``source`` are skipped
    - nested models are merged recursively
    - dicts are merged with ``utils.merge()``, sets are updated
    - lists of models are combined, deduplicated and sorted, other lists
      have the missing items appended
    - other values are copied if the destination field is still at its
      default, or if overwriting is enabled (``Service`` models whose
      source ``status`` is ``"verified"``, or ``"open"`` while the
      destination's is ``"closed"``)

    Args:
        dest: model to copy values into (modified in place)
        source: model to copy values from

    Raises:
        PeatError: ``source`` is not a model, the two models are of
            different types, or ``dest`` lacks a field present on ``source``
    """
    if not dest or not source:
        return

    if not isinstance(source, BaseModel):
        raise PeatError(f"non-model source: {source}")

    dst_type = dest.__repr_name__()  # type: str
    src_type = source.__repr_name__()  # type: str

    if dst_type != src_type:
        raise PeatError(f"merge_models: '{dst_type}' != '{src_type}'")

    # TODO: hack to make DeviceData merging work
    if dst_type == "DeviceData":
        if source.module:
            for mod_to_merge in source.module:
                # If there's an existing module in same slot
                # (or with the same serial number), merge the contents
                for curr_mod in dest.module:
                    if (
                        curr_mod.slot and mod_to_merge.slot and curr_mod.slot == mod_to_merge.slot
                    ) or (
                        curr_mod.serial_number
                        and mod_to_merge.serial_number
                        and curr_mod.serial_number == mod_to_merge.serial_number
                    ):
                        merge_models(curr_mod, mod_to_merge)
                        break
                # No existing module matched: append the module
                else:
                    dest.module.append(mod_to_merge)

            dest.module.sort(key=attrgetter("slot"))  # Sort modules by Slot ID

    # WARNING: do NOT call source.dict(...) here!
    # dict(source) converts just the top-level model to a dict, not sub-models.
    # source.dict(...) will convert all sub-models to dicts, which is no bueno.
    source_dict = dict(source)

    # Whether existing non-default values on the destination may be replaced
    overwrite = False
    if dst_type == "Service" and (
        source_dict.get("status") == "verified"
        or (source_dict.get("status") == "open" and dest.status == "closed")
    ):
        overwrite = True

    for attr, new_value in source_dict.items():
        # If it's None for some reason (e.g. a default), we don't care
        if new_value is None:
            continue

        if not hasattr(dest, attr):
            raise PeatError(f"No attribute for key '{attr}'. Value: {new_value}")

        # !! hack to make DeviceData merging work !!
        # (modules were already merged above)
        if dst_type == "DeviceData" and attr == "module":
            continue

        current_value = getattr(dest, attr)

        # Skip if the values match
        # Skip if the source is a default model
        if new_value == current_value or (
            isinstance(source, BaseModel) and source.is_default(attr)
        ):
            continue

        # If they're models (e.g., "Hardware"), use merge_models to handle the merging
        elif isinstance(current_value, BaseModel):
            merge_models(current_value, new_value)

        # Merge dicts, preserving existing values
        # NOTE: this is usually ".extra" fields
        elif isinstance(current_value, dict):
            utils.merge(current_value, new_value, no_copy=True)

        # Sets automatically remove duplicate values
        elif isinstance(current_value, set):
            current_value.update(new_value)

        # Combine, deduplicate, and sort lists
        elif isinstance(current_value, list):
            if current_value and new_value:
                if isinstance(current_value[0], BaseModel):
                    current_value.extend(new_value)
                    # dedupe_model_list() returns the deduplicated list instead
                    # of modifying its argument, so the result has to be written
                    # back. Slice assignment keeps the same list object that
                    # the destination model is holding.
                    current_value[:] = dedupe_model_list(current_value)
                    sort_model_list(current_value)
                else:
                    for new_item in new_value:
                        if not any(new_item == c for c in current_value):
                            current_value.append(new_item)
            elif new_value:
                setattr(dest, attr, new_value)

        # If the destination value is a default value, then copy the value
        # This won't overwrite existing values on destination
        # NOTE: using setattr() will also trigger value validation by Pydantic
        elif dest.is_default(attr):
            setattr(dest, attr, new_value)

        elif overwrite:
            msg = (
                f"Changed existing value for field '{attr}' with "
                f"'{new_value}' (old value: '{current_value}')"
            )

            # Status changes are expected, anything else is worth a warning
            if attr == "status":
                log.debug(msg)
            else:
                log.warning(msg)

            setattr(dest, attr, new_value)

        elif config.DEBUG >= 4:
            log.warning(
                f"Skipping merge of existing non-default attribute '{attr}' for "
                f"'{dest.__class__.__name__}' model (new_value={new_value} "
                f"current_value={current_value})"
            )