Source code for PyNucleus_base.utilsFem

###################################################################################
# Copyright 2021 National Technology & Engineering Solutions of Sandia,           #
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the           #
# U.S. Government retains certain rights in this software.                        #
# If you want to use this code, please refer to the README.rst and LICENSE files. #
###################################################################################

import numpy as np
import logging
import os
import atexit
import sys
import traceback
import re
import argparse
import mpi4py
mpi4py.rc.initialize = False
from mpi4py import MPI
from collections import OrderedDict
from copy import deepcopy
import inspect
from . myTypes import INDEX, REAL
from . performanceLogger import PLogger, LoggingTimer, Timer, LoggingPLogger
from . blas import uninitialized, uninitialized_like

_syncDefault = False


[docs] def setSyncDefault(sync): global _syncDefault _syncDefault = sync
[docs] class TimerManager:
[docs] def __init__(self, logger, comm=None, print_rank=0, prefix='', myPLogger=None, memoryProfiling=False, loggingSubTimers=False): self.logger = logger self.comm = comm self.print_rank = print_rank self.prefix = prefix self.memoryProfiling = memoryProfiling if self.memoryProfiling: from psutil import Process self.process = Process(os.getpid()) else: self.process = None if myPLogger is None: if loggingSubTimers: self.PLogger = LoggingPLogger(self.logger, logging.INFO, self.process) else: self.PLogger = PLogger(self.process) self.isSubManager = False self.totalTimer = Timer('total', self.PLogger, forceMemRegionOff=True) self.totalTimer.__enter__() else: self.isSubManager = True self.PLogger = myPLogger
[docs] def getTimer(self, FinalMessage=None, StartMessage=None, level=logging.INFO, overrideComm=None, sync=_syncDefault): if overrideComm is None: comm = self.comm else: comm = overrideComm if comm is None or comm.rank == self.print_rank: t = LoggingTimer(self.logger, level, FinalMessage, self.PLogger) else: t = Timer(FinalMessage, self.PLogger) return t
[docs] def __call__(self, *args, **kwargs): return self.getTimer(*args, **kwargs)
[docs] def setOutputGroup(self, rank, oG): def mergeOrdered(a_list, b_list): keys = [] while (len(a_list) > 0) and (len(b_list) > 0): if a_list[0] in b_list: if a_list[0] == b_list[0]: keys.append(a_list[0]) a_list.pop(0) b_list.pop(0) else: keys.append(b_list[0]) a_list.remove(b_list[0]) b_list.pop(0) elif b_list[0] in a_list: keys.append(a_list[0]) a_list.pop(0) else: keys.append(a_list[0]) a_list.pop(0) keys += a_list keys += b_list return keys if not self.isSubManager: self.totalTimer.__exit__(None, None, None) data = self.PLogger.values data2 = OrderedDict() for key in data.keys(): val = data[key] # (number of calls, min over calls, mean over calls, med over calls, max over calls) # on rank try: data2[key] = (len(val), np.min(val), np.mean(val), np.median(val), np.max(val)) except: pass data = data2 # gather data for all ranks if self.comm is not None: data = self.comm.gather(data, root=rank) else: data = [data] if self.comm is None or self.comm.rank == rank: if self.comm is None: commSize = 1 else: commSize = self.comm.size assert len(data) == commSize keys = list(data[0].keys()) for i in range(1, len(data)): keys = mergeOrdered(keys, list(data[i].keys())) pData = {} for i in range(len(data)): for key in keys: d = data[i].get(key, (0, np.inf, 0., 0., -np.inf)) try: pData[key].append(d) except KeyError: pData[key] = [d] for key in pData: assert len(pData[key]) == commSize, key oG.add(key, pData[key])
[docs] def getSubManager(self, logger): return TimerManager(logger, self.comm, self.print_rank, self.prefix, self.PLogger)
[docs] def getLoggingTimer(logger, comm=None, print_rank=0, prefix='', rootOutput=False): def getTimer(FinalMessage='', StartMessage='', level=logging.INFO, overrideComm=None, sync=_syncDefault): from . performanceLogger import FakePLogger # return Timer(prefix=prefix, FinalMessage=FinalMessage, StartMessage=StartMessage, # logger=logger, level=level, comm=comm, overrideComm=overrideComm, print_rank=print_rank, sync=sync) if StartMessage != '': StartMessage = prefix+StartMessage return LoggingTimer(logger, level, prefix+FinalMessage, FakePLogger(), StartMessage=StartMessage, comm=comm, sync=sync) return getTimer
display_available = ("DISPLAY" in os.environ and "SSH_CONNECTION" not in os.environ)
[docs] def computeErrors(levels, solutions, norm, timeVals=None): assert len(levels) == len(solutions) if timeVals is not None: for lvlNo in range(len(levels)): assert solutions[lvlNo].shape[0] == len(timeVals[lvlNo]) for lvlNo in range(1, len(levels)): assert solutions[lvlNo].shape[1] == levels[lvlNo]['P'].shape[0] assert solutions[lvlNo-1].shape[1] == levels[lvlNo]['P'].shape[1] errors = [] uFine = solutions[-1] for lvlNo in range(len(levels)-1): u = solutions[lvlNo] for j in range(lvlNo+1, len(levels)): u2 = uninitialized((u.shape[0], levels[j]['P'].shape[0]), dtype=REAL) for k in range(u.shape[0]): u2[k, :] = levels[j]['P']*u[k, :] u = u2 if timeVals is not None: # linear interpolation in time uNew = uninitialized_like(uFine) uNew[0, :] = u[0, :] uNew[-1, :] = u[-1, :] for k in range(1, uFine.shape[0]-1): t = timeVals[-1][k] j = 0 while timeVals[lvlNo][j+1] < t: j += 1 t0 = timeVals[lvlNo][j] t1 = timeVals[lvlNo][j+1] assert t0 <= t <= t1 uNew[k, :] = (t1-t)/(t1-t0)*u[j, :] + (t-t0)/(t1-t0)*u[j+1, :] u = uNew errors.append(norm(u-uFine)) return errors
[docs] def roc(idx, val, FillUp=False, exp=False): "Calculates the rate of convergence." idx, val = np.atleast_2d(idx), np.atleast_2d(val) if idx.shape[0] == 1: idx = idx.transpose() if idx.shape[0] != val.shape[0]: val = val.transpose() if idx.shape[0] != val.shape[0]: raise Exception('Arrays of sizes {} and {} not compatible.'.format(idx.shape[0], val.shape[0])) if exp: rate = np.log(val[0:-1, :]/val[1:, :])/(idx[0:-1, :]-idx[1:, :]) else: rate = np.log(val[0:-1, :]/val[1:, :])/np.log(idx[0:-1, :]/idx[1:, :]) if FillUp: return np.vstack([rate, [np.nan]]) else: return rate
[docs] class exitHandler(object):
[docs] def __init__(self, comm): self.comm = comm self.exit_code = None self.exception = None self.exc_type = None self._orig_exit = sys.exit sys.exit = self.exit sys.excepthook = self.exc_handler atexit.register(self.atExitHandler)
[docs] def exit(self, code=0): self.exit_code = code self._orig_exit(code)
[docs] def exc_handler(self, exc_type, exc, *args): self.exc_type = exc_type self.exception = exc
[docs] def atExitHandler(self): if self.exit_code is not None and self.exit_code != 0: logging.error("death by sys.exit(%d)" % self.exit_code) self.comm.Abort(self.exit_code) elif self.exception is not None: lines = traceback.format_exception(self.exc_type, self.exception, tb=self.exception.__traceback__) msg = ''.join(['{}: {}'.format(self.comm.rank, line) for line in lines]) logging.error('\n'+msg) self.comm.Abort(1234)
[docs] def saveDictToHDF5(params, f, ignore=set()): import h5py for key, val in params.items(): if key in ignore: continue if isinstance(val, dict): g = f.create_group(key) saveDictToHDF5(val, g) elif isinstance(val, np.ndarray): f.create_dataset(key, data=val) elif isinstance(val, list): try: if isinstance(val[0], list) and isinstance(val[0][0], (int, float, INDEX, REAL)): raise ValueError() f.create_dataset(key, data=np.array(val)) except: if isinstance(val[0], list) and isinstance(val[0][0], (int, float, INDEX, REAL)): g = f.create_group(key) g.attrs['type'] = 'compressedList' listItems = 0 for i in range(len(val)): listItems += len(val[i]) indptr = uninitialized((len(val)+1), dtype=INDEX) if isinstance(val[0][0], (int, INDEX)): data = uninitialized((listItems), dtype=INDEX) else: data = uninitialized((listItems), dtype=REAL) listItems = 0 for i in range(len(val)): indptr[i] = listItems data[listItems:listItems+len(val[i])] = val[i] listItems += len(val[i]) indptr[-1] = listItems g.create_dataset('indptr', data=indptr) g.create_dataset('data', data=data) elif isinstance(val[0], str): f.create_dataset(key, data=np.array(val, dtype=np.string_)) else: g = f.create_group(key) g.attrs['type'] = 'list' for k in range(len(val)): g.attrs[str(k)] = val[k] elif val is None: try: f.attrs[key] = h5py.Empty(np.dtype("f")) except AttributeError: print('Failed to write \'{}\' because h5py is too old.'.format(key)) elif hasattr(val, 'HDF5write') and callable(val.HDF5write): g = f.create_group(key) val.HDF5write(g) elif hasattr(val, 'toarray') and callable(val.toarray): f.create_dataset(key, data=val.toarray()) else: try: f.attrs[key] = val except: try: import pickle f.attrs[key] = np.void(pickle.dumps(val)) except: print('Failed to write \'{}\''.format(key)) f.attrs[key] = str(val)
[docs] def loadDictFromHDF5(f): import h5py from . linear_operators import LinearOperator from PyNucleus_fem.DoFMaps import DoFMap params = {} for key in f.attrs: if isinstance(f.attrs[key], h5py.Empty): params[key] = None else: try: import pickle params[key] = pickle.loads(f.attrs[key]) except: params[key] = f.attrs[key] for key in f: if isinstance(f[key], h5py.Group): if 'type' in f[key].attrs: if f[key].attrs['type'] == 'list': myList = [] for k in range(len(f[key].attrs)-1): myList.append(f[key].attrs[str(k)]) params[key] = myList elif f[key].attrs['type'] == 'compressedList': myCompressedList = [] indptr = np.array(f[key]['indptr'], dtype=INDEX) if isinstance(f[key]['data'], (int, INDEX)): data = np.array(f[key]['data'], dtype=INDEX) else: data = np.array(f[key]['data'], dtype=REAL) for i in range(len(indptr)-1): myCompressedList.append(data[indptr[i]:indptr[i+1]].tolist()) params[key] = myCompressedList elif f[key].attrs['type'] == 'series': d = loadDictFromHDF5(f[key]) grp = seriesOutputGroup(key) grp.fromDict(d) params[key] = grp elif f[key].attrs['type'] == 'DoFMap': params[key] = DoFMap.HDF5read(f[key]) elif f[key].attrs['type'] == 'h2': from PyNucleus_nl.clusterMethodCy import H2Matrix params[key] = H2Matrix.HDF5read(f[key]) else: params[key] = LinearOperator.HDF5read(f[key]) elif 'vertices' in f[key] and 'cells' in f[key]: from PyNucleus_fem.mesh import meshNd params[key] = meshNd.HDF5read(f[key]) else: params[key] = loadDictFromHDF5(f[key]) else: params[key] = np.array(f[key]) try: myList = [] for i in range(len(params[key])): myList.append(params[key][i].decode('utf-8')) params[key] = myList except: pass return params
[docs] def processDictForYaml(params): from PyNucleus_fem.functions import function paramsNew = {} for key in params: if isinstance(params[key], dict): paramsNew[key] = processDictForYaml(params[key]) elif isinstance(params[key], REAL): paramsNew[key] = float(params[key]) elif isinstance(params[key], np.ndarray): if params[key].dtype == REAL: if params[key].ndim == 1: paramsNew[key] = params[key].tolist() for i in range(len(paramsNew[key])): paramsNew[key][i] = float(paramsNew[key][i]) elif params[key].ndim == 2: paramsNew[key] = params[key].tolist() for i in range(len(paramsNew[key])): for j in range(len(paramsNew[key][i])): paramsNew[key][i][j] = float(paramsNew[key][i][j]) else: raise NotImplementedError() else: paramsNew[key] = params[key].tolist() elif isinstance(params[key], list): paramsNew[key] = params[key] for i in range(len(paramsNew[key])): if isinstance(paramsNew[key][i], REAL): paramsNew[key][i] = float(paramsNew[key][i]) elif isinstance(params[key], function): paramsNew[key] = str(params[key]) else: paramsNew[key] = params[key] return paramsNew
[docs] def updateFromDefaults(params, defaults): for key in defaults: if key not in params: params[key] = defaults[key] elif isinstance(defaults[key], dict): updateFromDefaults(params[key], defaults[key])
[docs] def getMPIinfo(grp, verbose=False): from sys import modules if 'mpi4py.MPI' in modules: import mpi4py mpi4py.initialize = False from mpi4py import MPI if not MPI.Is_initialized(): return t = {MPI.THREAD_SINGLE: 'single', MPI.THREAD_FUNNELED: 'funneled', MPI.THREAD_SERIALIZED: 'serialized', MPI.THREAD_MULTIPLE: 'multiple'} hosts = MPI.COMM_WORLD.gather(MPI.Get_processor_name()) if MPI.COMM_WORLD.rank == 0: hosts = ','.join(set(hosts)) grp.add('MPI library', '{}'.format(MPI.Get_library_version()[:-1])) if verbose: for label, value in [('MPI standard supported', MPI.Get_version()), ('Vendor', MPI.get_vendor()), ('Level of thread support', t[MPI.Query_thread()]), ('Is threaded', MPI.Is_thread_main()), ('Threads requested', mpi4py.rc.threads), ('Thread level requested', mpi4py.rc.thread_level)]: grp.add(label, value) for label, value in [('Hosts', hosts), ('Communicator size', MPI.COMM_WORLD.size)]: grp.add(label, value)
[docs] def getEnvVariables(grp, envVars=[('OMP_NUM_THREADS', True)]): from os import environ s = [] for var, printNotSet in envVars: if var in environ: varVal = environ[var] elif printNotSet: varVal = 'not set' else: continue grp.add(var, varVal) return '\n'.join(s)
[docs] def getSystemInfo(grp, argv=None, envVars=[('OMP_NUM_THREADS', True)]): from sys import executable if argv is not None: grp.add('Running', executable + ' ' + ' '.join(argv)) else: grp.add('Running', executable) import mpi4py mpi4py.initialize = False from mpi4py import MPI if MPI.Is_initialized(): getMPIinfo(grp) getEnvVariables(grp, envVars) import pkg_resources from PyNucleus import subpackages versions = {} for pkg in ['numpy', 'scipy', 'mpi4py', 'cython']: version = pkg_resources.get_distribution(pkg).version try: versions[version].append(pkg) except KeyError: versions[version] = [pkg] for version in versions: grp.add(','.join(versions[version]), version) import importlib versions = {} for pkg in sorted(subpackages.keys()): version = pkg_resources.get_distribution('PyNucleus_'+pkg).version module = importlib.import_module('PyNucleus_'+pkg+'.config') sha = module.gitSHA try: versions[(version, sha)].append(pkg) except KeyError: versions[(version, sha)] = [pkg] for version, sha in versions: grp.add('PyNucleus_'+(','.join(versions[(version, sha)])), '{}, {}'.format(version, sha))
[docs] class MPIFileHandler(logging.Handler): """ A handler class which writes formatted logging records to disk files. """
[docs] def __init__(self, filename, comm, mode=MPI.MODE_WRONLY | MPI.MODE_CREATE): from pathlib import Path # filename = os.fspath(filename) # keep the absolute path, otherwise derived classes which use this # may come a cropper when the current directory changes self.baseFilename = os.path.abspath(filename) assert len(self.baseFilename) <= 245, ('The length of the log file path \"{}\" is too long ' + 'and will probably crash MPI. Try running with \"--disableFileLog\"').format(self.baseFilename) if Path(self.baseFilename).exists() and comm.rank == 0: from os import remove remove(self.baseFilename) self.mpiFile = MPI.File.Open(comm, self.baseFilename, mode) self.mpiFile.Set_atomicity(True) logging.Handler.__init__(self)
[docs] def emit(self, record): try: msg = self.format(record)+'\n' mv = memoryview(bytes(msg, encoding='utf-8')) self.mpiFile.Write_shared((mv, len(mv), MPI.BYTE)) self.flush() except Exception: self.handleError(record)
[docs] def sync(self): self.mpiFile.Sync()
def __repr__(self): level = logging.getLevelName(self.level) return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
[docs] def close(self): """ Closes the stream. """ self.acquire() try: self.mpiFile.Close() finally: self.release()
[docs] def columns(lines, returnColWidth=False, colWidth=0): if colWidth == 0: for line, _, _ in lines: colWidth = max(len(line), colWidth) s = [] for line, f, v in lines: if isinstance(f, str): lf = '{:<'+str(colWidth+2)+'}'+f s.append(lf.format(line+':', v)) else: lf = '{:<'+str(colWidth+2)+'}'+'{}' s.append(lf.format(line+':', f(v))) s = '\n'.join(s) if returnColWidth: return s, colWidth else: return s
[docs] class outputParam:
[docs] def __init__(self, label, value, format=None, aTol=None, rTol=None, tested=False): self.label = label if format is None: if isinstance(value, bool): format = '{}' elif isinstance(value, (float, REAL)): format = '{:.3}' elif isinstance(value, (int, INDEX)): format = '{:,}' elif isinstance(value, np.ndarray): formatter = {'float_kind': lambda x: '{:.3}'.format(x)} format = lambda s: np.array2string(s, formatter=formatter, max_line_width=200) else: format = '{}' self.format = format self.value = value self.aTol = aTol self.rTol = rTol self.tested = tested
[docs] class outputGroup:
[docs] def __init__(self, aTol=None, rTol=None, tested=False, driver=None): self.entries = [] self.tested = tested self.aTol = aTol self.rTol = rTol self.driver = driver
[docs] def add(self, label, value, format=None, aTol=None, rTol=None, tested=None): if aTol is None: aTol = self.aTol if rTol is None: rTol = self.rTol if tested is None: tested = self.tested p = outputParam(label, value, format, aTol, rTol, tested) self.entries.append(p)
def __repr__(self): lines = [(p.label, p.format, p.value) for p in self.entries] return columns(lines)
[docs] def log(self): if self.driver is not None: self.driver.logger.info('\n'+str(self)) else: raise NotImplementedError()
def __add__(self, other): c = outputGroup() from copy import deepcopy d = deepcopy(self.entries) d += other.entries c.entries = d return c
[docs] def toDict(self, tested=False): if not tested: return {p.label: p.value for p in self.entries} else: return {p.label: p.value for p in self.entries if p.tested}
[docs] def fromDict(self, d): for key, value in d.items(): self.add(key, value)
def __getattr__(self, key): for p in self.entries: if p.label == key: return p.value raise KeyError(key)
[docs] def diff(self, d): result = {} d = deepcopy(d) for p in self.entries: if p.tested: if p.label in d: aTol = p.aTol if p.aTol is not None else 1e-12 rTol = p.rTol if p.rTol is not None else 1e-12 if isinstance(p.value, (np.ndarray, list)): if len(p.value) == len(d[p.label]): if not np.allclose(p.value, d[p.label], rtol=rTol, atol=aTol): result[p.label] = (p.value, d[p.label]) else: result[p.label] = (p.value, d[p.label]) elif isinstance(p.value, (int, INDEX, REAL, float)): if not np.allclose(p.value, d[p.label], rtol=rTol, atol=aTol) and not (np.isnan(p.value) and np.isnan(d[p.label])): print(p.label, p.value, d[p.label], rTol, aTol, p.rTol, p.aTol) result[p.label] = (p.value, d[p.label]) else: if p.value != d[p.label]: result[p.label] = (p.value, d[p.label]) d.pop(p.label) else: result[p.label] = (p.value, 'Not available') for key in d: result[key] = ('Not available', d[key]) return result
[docs] class statisticOutputGroup(outputGroup):
[docs] def __init__(self, comm, driver=None): super(statisticOutputGroup, self).__init__(driver=driver) self.comm = comm self.doSum = {}
[docs] def add(self, label, value, format=None, aTol=None, rTol=None, tested=None, sumOverRanks=True): value = self.comm.gather(value) if self.comm.rank == 0: self.doSum[label] = sumOverRanks super(statisticOutputGroup, self).add(label, value, format=format, aTol=aTol, rTol=rTol, tested=tested)
def __repr__(self): lines = [] header = ['quantity', 'min', 'mean', 'med', 'max', 'sum'] for p in self.entries: key = p.label data = p.value if self.doSum[key]: lines.append((key, np.min(data), np.mean(data), np.median(data), np.max(data), np.sum(data))) else: lines.append((key, np.min(data), np.mean(data), np.median(data), np.max(data), None)) from tabulate import tabulate return tabulate(lines, headers=header)
[docs] class timerOutputGroup(outputGroup):
[docs] def __init__(self, driver=None): super(timerOutputGroup, self).__init__(driver=driver)
def __repr__(self): lines = [] if len(self.entries) > 0 and len(self.entries[0].value) > 1: header = ['timer', 'numCalls', 'minCall', 'meanCall', 'maxCall', 'minSum', 'meanSum', 'medSum', 'maxSum'] else: header = ['timer', 'numCalls', 'minCall', 'meanCall', 'maxCall', 'sum'] for p in self.entries: key = p.label data = p.value numCalls = np.array([p[0] for p in data]) minNumCalls = np.min(numCalls) meanNumCalls = np.mean(numCalls) medNumCalls = np.median(numCalls) maxNumCalls = np.max(numCalls) # min over min call counts minCall = np.min([p[1] for p in data]) # (\sum_{rank} numCalls*meanPerCall) / (\sum_{rank} numCalls) meanCall = np.sum([p[0]*p[2] for p in data])/numCalls.sum() # max over max call counts maxCall = np.max([p[4] for p in data]) # total time per rank sums = [p[0]*p[2] for p in data] if len(sums) > 1: minSum = np.min(sums) meanSum = np.mean(sums) medSum = np.median(sums) maxSum = np.max(sums) if minNumCalls != maxNumCalls: calls = (minNumCalls, meanNumCalls, medNumCalls, maxNumCalls) else: calls = maxNumCalls lines.append((key, calls, minCall, meanCall, maxCall, minSum, meanSum, medSum, maxSum)) else: lines.append((key, meanNumCalls, minCall, meanCall, maxCall, sums[0])) from tabulate import tabulate return tabulate(lines, headers=header)
[docs] class seriesOutputGroup:
[docs] def __init__(self, name, aTol=None, rTol=None, tested=False, driver=None): self.name = name self.aTol = aTol self.rTol = rTol self.tested = tested self.groups = {} self.driver = driver
[docs] def addGroup(self, label): label = str(label) if label in self.groups: group = self.groups[label] else: group = outputGroup(aTol=self.aTol, rTol=self.rTol, tested=self.tested, driver=self.driver) self.groups[label] = group return group
[docs] def get(self, keyName, valueNames=[], sortBy=None, reverse=False): if sortBy is None: sortBy = keyName if not isinstance(valueNames, (list, tuple)): valueNames = [valueNames] keys = [] values = {valueName: [] for valueName in valueNames} sortKeys = [] for label in sorted(self.groups): try: key = getattr(self.groups[label], keyName) sortKey = getattr(self.groups[label], sortBy) v = {} for valueName in valueNames: v[valueName] = getattr(self.groups[label], valueName) keys.append(key) for valueName in valueNames: values[valueName].append(v[valueName]) sortKeys.append(sortKey) except KeyError: pass idx = np.argsort(sortKeys) if reverse: idx = idx[::-1] keys = np.array(keys)[idx] for valueName in valueNames: values[valueName] = np.array(values[valueName])[idx] if len(valueNames) > 0: return keys, tuple([values[valueName] for valueName in valueNames]) else: return keys
[docs] def getPair(self, keyName, valueName, sortBy=None, reverse=False): if sortBy is None: sortBy = keyName keys = [] values = [] sortKeys = [] for label in sorted(self.groups): try: key = getattr(self.groups[label], keyName) value = getattr(self.groups[label], valueName) sortKey = getattr(self.groups[label], sortBy) keys.append(key) values.append(value) sortKeys.append(sortKey) except KeyError: pass idx = np.argsort(sortKeys) if reverse: idx = idx[::-1] keys = np.array(keys)[idx] values = np.array(values)[idx] return keys, values
[docs] def roc(self, keyName, valueName, reverse=False): keys, values = self.get(keyName, [valueName], reverse=reverse) return roc(keys, values).flatten()
[docs] def toDict(self, tested=False): d = {'type': 'series'} for label in self.groups: d[label] = self.groups[label].toDict(tested) return d
[docs] def fromDict(self, d): for label in d: if label != 'type': group = self.addGroup(label) group.fromDict(d[label])
[docs] def getTable(self, keyName, valueNames=[], reverse=False): rocs = [] if isinstance(valueNames, list): assert len(valueNames) > 0 newValueNames = [] for k in range(len(valueNames)): if valueNames[k][:3] == 'roc': rocs.append(('RoC '+valueNames[k][3:], np.concatenate(([None], self.roc(keyName, valueNames[k][3:], reverse=reverse))), k)) else: newValueNames.append(valueNames[k]) valueNames = newValueNames keys, values = self.get(keyName, valueNames, reverse=reverse) values = list(values) for label, value, pos in rocs: valueNames.insert(pos, label) values.insert(pos, value) header = [keyName]+valueNames lines = np.vstack((keys, *values)).T from tabulate import tabulate return tabulate(lines, headers=header)
[docs] def plot(self, keyName, valueName, bnd=None, **kwargs): import matplotlib.pyplot as plt key, (value, ) = self.get(keyName, valueName) plt.plot(key, value, **kwargs) if bnd is not None: assert isinstance(bnd, dict) exponent = bnd.pop('exponent') pos = bnd.pop('pos', None) y = key**exponent if pos is None: pos = value.argmin() y *= value[pos]/y[pos] plt.plot(key, y, **bnd)
[docs] def diff(self, d): result = {} for label in self.groups: p = self.groups[label].diff(d[label]) if len(p) > 0: result[label] = p return result
[docs] class driverArgGroup:
[docs] def __init__(self, parent, group): self.parent = parent self.group = group
[docs] def add(self, *args, **kwargs): if self.parent is not None: kwargs['group'] = self.group self.parent.add(*args, **kwargs)
dependencyLogger = logging.getLogger('Dependencies')
[docs] class driver:
[docs] def __init__(self, comm=None, setCommExitHandler=True, masterRank=0, description=None): self.comm = comm self._identifier = '' self.processHook = [] self.masterRank = masterRank self.isMaster = (self.comm is None or self.comm.rank == self.masterRank) self.argGroups = {} self.outputGroups = {} self._logger = None self._timer = None self._figures = {} self._display_available = None if self.comm is not None and setCommExitHandler: exitHandler(self.comm) if self.isMaster: # self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) from os import environ width = int(environ.get('COLUMNS', 200)) self.parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=50, width=width), description=description) self.mainGroup = self.parser.add_argument_group('main') io = self.addGroup('input/output') io.add('disableHeader', False, help='Disable verbose header') io.add('displayConfig', False, help='Display configuration') io.add('displayRanks', False, help='Display MPI ranks in log') io.add('disableFileLog', False, help='Disable logging to file') io.add('logDependencies', False, help='log dependencies') io.add('disableTimeStamps', False, help='Disable time stamps in log') io.add('showTimers', True, help='Display timers') io.add('showMemory', False, help='Show memory info in timers') io.add('test', False, help='Run in test mode') io.add('yamlInput', '', help='YAML config file') io.add('hdf5Input', '', help='HDF5 config file') io.add('yamlOutput', '', help='YAML output file') io.add('hdf5Output', '', help='HDF5 output file') io.add('testCache', '', help='YAML cache file') io.add('overwriteCache', False, help='Overwrite the test cache file')
[docs] def setIdentifier(self, identifier): self._identifier = identifier
@property def identifier(self): return self._identifier
[docs] def setLogFile(self, filename=None): if filename is None: from pathlib import Path as path logs = path('logs') assert self._identifier != '' filename = logs/(self._identifier+'.log') filename.parent.mkdir(exist_ok=True, parents=True) if self.comm is not None and self.comm.size > 1: fileHandler = MPIFileHandler(filename, self.comm) else: fileHandler = logging.FileHandler(filename, 'w') logging.getLogger().addHandler(fileHandler) fmt = '{asctime} {name:40} {message}' fileHandler.setFormatter(logging.Formatter(fmt, style='{', datefmt="%Y-%m-%d %H:%M:%S"))
[docs] def addGroup(self, name): if name in self.argGroups: return self.argGroups[name] if self.isMaster: argGroup = driverArgGroup(self, self.parser.add_argument_group(name)) self.argGroups[name] = argGroup return argGroup else: return driverArgGroup(None, None)
[docs] def add(self, name, defaultValue=None, acceptedValues=[], help='No help defined', argInterpreter=None, group=None): if self.isMaster: if group is None: group = self.mainGroup if len(acceptedValues) > 0: if defaultValue is None: defaultValue = acceptedValues[0] else: if defaultValue not in acceptedValues: acceptedValues.insert(0, defaultValue) else: acceptedValues = None if isinstance(defaultValue, bool): if defaultValue is True: action = 'store_false' flagname = 'no-'+name elif defaultValue is False: action = 'store_true' flagname = name if len(flagname) == 1: flagname = ('-'+flagname, '--'+flagname) else: flagname = ('--'+flagname, ) group.add_argument(*flagname, action=action, help=help, dest=name) else: if acceptedValues is not None: types = [a for a in acceptedValues if type(a) == type] if len(types) > 0 and argInterpreter is None: acceptedValues2 = [a for a in acceptedValues if type(a) != type] def argInterpreter(s): from ast import literal_eval if s in acceptedValues2: return s for t in types: try: x = literal_eval(s) if isinstance(x, t): return x except Exception as e: print(e) raise argparse.ArgumentTypeError() acceptedValues = None if argInterpreter is None: argInterpreter = type(defaultValue) if len(name) == 1: name = ('-'+name, '--'+name) else: name = ('--'+name, ) group.add_argument(*name, default=defaultValue, type=argInterpreter, choices=acceptedValues, help=help)
[docs] def addPositional(self, name, nargs=1, group=None): if self.isMaster: if group is None: group = self.mainGroup group.add_argument(name, nargs=nargs)
[docs] def addToProcessHook(self, fun): self.processHook.append(fun)
[docs] def process(self, override={}): if self.isMaster: self.parser.set_defaults(**override) doTerminate = False if self.isMaster: if 'plots' in self.argGroups: io = self.addGroup('plots') io.add('plotFolder', '', help='folder for saving plots') io.add('plotFormat', acceptedValues=['pdf', 'png', 'jpeg', 'eps', 'ps', 'svg'], help='File format for saving plots') try: args, unknown = self.parser.parse_known_args() except SystemExit: doTerminate = True if self.comm: doTerminate = self.comm.bcast(doTerminate, root=0) if doTerminate: exit(0) if self.isMaster: if len(unknown) > 0: self.logger.warning('Unknown args: {}'.format(unknown)) params = vars(args) if params['yamlInput'] != '': import yaml yaml_filename = params['yamlInput'] conf = yaml.load(open(yaml_filename, 'r'), Loader=yaml.FullLoader) params.update(conf) if params['hdf5Input'] != '': import h5py hdf5_filename = params['hdf5Input'] f = h5py.File(hdf5_filename, 'r') conf = loadDictFromHDF5(f) f.close() params.update(conf) if params['test']: params['displayConfig'] = True if self.comm: params['mpiGlobalCommSize'] = self.comm.size else: params['mpiGlobalCommSize'] = 1 else: params = {} if self.comm: params = self.comm.bcast(params, root=0) self.params = params self._timer = TimerManager(self.logger, comm=self.comm, memoryProfiling=params['showMemory']) for fun in self.processHook: fun(self.params) if self._identifier != '' and not self.params['disableFileLog']: self.setLogFile() prefix = '' if params['displayRanks'] and self.comm is not None and self.comm.size > 1: prefix = '{}: '.format(self.comm.rank) if not params['disableTimeStamps']: fmt = prefix+'{asctime} {name:40} {message}' else: fmt = prefix+'{name:40} {message}' formatter = logging.Formatter(fmt=fmt, style='{', datefmt="%Y-%m-%d %H:%M:%S") if self.isMaster: logging.getLogger().setLevel(logging.INFO) else: logging.getLogger().setLevel(logging.WARN) for handler in logging.getLogger().handlers: handler.setFormatter(formatter) if params['displayConfig']: from pprint import pformat self.logger.info('\n'+pformat(params)) from sys import argv sysInfo = self.addOutputGroup('sysInfo') getSystemInfo(argv=argv, grp=sysInfo) if not params['disableHeader']: self.logger.info('\n'+str(sysInfo)) if params['logDependencies']: dependencyLogger.setLevel(logging.DEBUG) else: dependencyLogger.setLevel(logging.INFO) return params
[docs] def set(self, key, value): if hasattr(self, 'params'): self.params[key] = value else: raise KeyError
[docs] def getLogger(self): if self._logger is None: import logging if self.isMaster: level = logging.INFO else: level = logging.WARNING fmt = '{asctime} {name:40} {message}' logging.basicConfig(level=level, format=fmt, style='{', datefmt="%Y-%m-%d %H:%M:%S") self._logger = logging.getLogger('__main__') return self._logger
@property def logger(self): return self.getLogger()
[docs] def getTimer(self): return self._timer
@property def timer(self): return self.getTimer()
[docs] def addOutputGroup(self, name, group=None, aTol=None, rTol=None, tested=False): if name in self.outputGroups: group = self.outputGroups[name] else: if group is None: group = outputGroup(tested=tested, aTol=aTol, rTol=rTol, driver=self) self.outputGroups[name] = group assert group.tested == tested assert group.aTol == aTol assert group.rTol == rTol return group
[docs] def addStatsOutputGroup(self, name, group=None, aTol=None, rTol=None, tested=False): if name in self.outputGroups: group = self.outputGroups[name] assert isinstance(group, statisticOutputGroup) return group else: return self.addOutputGroup(name, statisticOutputGroup(comm=self.comm, driver=self))
[docs] def addOutputSeries(self, name, aTol=None, rTol=None, tested=False): group = seriesOutputGroup(name, aTol, rTol, tested, driver=self) group = self.addOutputGroup(name, group, aTol, rTol, tested) return group
[docs] def outputToDict(self, tested=False): d = {} for group in self.outputGroups: d[group] = self.outputGroups[group].toDict(tested=tested) return d
[docs] def timerReport(self): t = self.addOutputGroup('Timers', timerOutputGroup()) self.timer.setOutputGroup(self.masterRank, t) self.logger.info('\n'+str(t))
[docs] def saveOutput(self): if self.isMaster: failAfterOutput = False if self.params['testCache'] != '': try: import yaml cache = yaml.load(open(self.params['testCache'], 'r'), Loader=yaml.FullLoader) diff = {} for name in self.outputGroups: diff[name] = self.outputGroups[name].diff(cache.get(name, {})) if len(diff[name]) == 0: diff.pop(name) from pprint import pformat if len(diff) > 0: if self.params['overwriteCache']: failAfterOutput = True self.params['yamlOutput'] = self.params['testCache'] self.logger.info('No match (observed, expected)\n' + str(pformat(diff))) else: assert False, 'No match (observed, expected)\n' + str(pformat(diff)) else: self.logger.info('\nAll matched') except FileNotFoundError: self.params['yamlOutput'] = self.params['testCache'] failAfterOutput = True if self.params['hdf5Output'] == 'auto' and self._identifier != '': self.params['hdf5Output'] = self._identifier + '.hdf5' if self.params['yamlOutput'] == 'auto' and self._identifier != '': self.params['yamlOutput'] = self._identifier + '.yaml' if self.params['hdf5Output'] != '' or self.params['yamlOutput'] != '': d = self.outputToDict(tested=self.params['test']) if not self.params['test']: d.update(self.params) from pathlib import Path if self.params['hdf5Output'] != '': import h5py self.logger.info('Saving to {}'.format(self.params['hdf5Output'])) Path(self.params['hdf5Output']).parent.mkdir(exist_ok=True, parents=True) f = h5py.File(self.params['hdf5Output'], 'w') saveDictToHDF5(d, f) f.close() if self.params['yamlOutput'] != '': import yaml self.logger.info('Saving to {}'.format(self.params['yamlOutput'])) Path(self.params['yamlOutput']).parent.mkdir(exist_ok=True, parents=True) d = processDictForYaml(d) yaml.dump(d, open(self.params['yamlOutput'], 'w')) assert not failAfterOutput, 'No cache file'
@property def display_available(self): if self._display_available is None: from os import environ available = ("DISPLAY" in environ and "SSH_CONNECTION" not in environ and not self.params['skipPlots']) if available: try: import matplotlib.pyplot as plt except ImportError: self.logger.warn('No Matplotlib') available = False self._display_available = available return self._display_available
[docs] def declareFigure(self, name, description='No help defined', default=True): if self.isMaster: addSkipOption = 'plots' not in self.argGroups plots = self.addGroup('plots') if addSkipOption: plots.add('skipPlots', False, help='Do not plot anything') plots.add('plot_'+name, default, help=description) self._figures[name] = None
[docs] def willPlot(self, name): key = 'plot_'+name return key in self.params and self.params[key] and (self.display_available or self.params['plotFolder'] != '')
[docs] def startPlot(self, name, **kwargs): if self.isMaster: if (('plot_'+name not in self.params or self.params['plot_'+name]) and (self.display_available or self.params['plotFolder'] != '')): import matplotlib.pyplot as plt from . plot_utils import latexOptions MPLconf = latexOptions(**kwargs) plt.rcParams.update(MPLconf) if (name not in self._figures) or (self._figures[name] is None): fig = plt.figure() self._figures[name] = fig plt.get_current_fig_manager().set_window_title(name) else: plt.figure(self._figures[name].number) return self._figures[name] elif 'plot_'+name in self.params: del self._figures[name] return None else: return None
[docs] def savePlot(self, name, filenameSuffix='', **kwargs): if self._figures[name] is not None: if self._identifier != '': filename = self._identifier+'_'+name+filenameSuffix filename = filename.replace('_', '-') filename = filename.replace(' ', '-') filename = filename.replace('=', '') else: filename = name+filenameSuffix self._figures[name].tight_layout() from pathlib import Path Path(self.params['plotFolder']+'/'+filename+'.'+self.params['plotFormat']).parent.mkdir(exist_ok=True, parents=True) self._figures[name].savefig(self.params['plotFolder']+'/'+filename+'.'+self.params['plotFormat'], bbox_inches='tight', **kwargs) else: self.logger.warn('Figure \'{}\' not created'.format(name))
[docs] def finishPlots(self, **kwargs): newFigures = {} for name in self._figures: if self._figures[name] is not None: newFigures[name] = self._figures[name] self._figures = newFigures if len(self._figures) > 0: if self.params['plotFolder'] != '': for name in self._figures: self.savePlot(name, **kwargs) else: import matplotlib.pyplot as plt plt.show()
[docs] def finish(self, **kwargs): t = self.addOutputGroup('Timers', timerOutputGroup()) self.timer.setOutputGroup(self.masterRank, t) if self.params['showTimers'] and self.isMaster: self.logger.info('\n'+str(t)) self.saveOutput() self.finishPlots(**kwargs)
def __getattr__(self, name): if name in self.params: return self.params[name] else: return getattr(self, name)
[docs] def diffDict(d1, d2, aTol, relTol): diff = {} for key in d1: if isinstance(d1[key], dict): if key not in d2: p = diffDict(d1[key], {}, aTol, relTol) if len(p) > 0: diff[key] = p else: p = diffDict(d1[key], d2[key], aTol, relTol) if len(p) > 0: diff[key] = p else: if key not in d2: diff[key] = (d1[key], 'Not available') else: if isinstance(d1[key], (int, INDEX, REAL, float, np.ndarray, list)): if not np.allclose(d1[key], d2[key], rtol=relTol, atol=aTol): diff[key] = (d1[key], d2[key]) elif d1[key] != d2[key]: diff[key] = (d1[key], d2[key]) for key in d2: if isinstance(d2[key], dict): if key not in d1: p = diffDict({}, d2[key], aTol, relTol) if len(p) > 0: diff[key] = p else: if key not in d1: diff[key] = ('Not available', d2[key]) return diff
[docs] def runDriver(path, py, python=None, timeout=900, ranks=None, cacheDir='', overwriteCache=False, aTol=1e-12, relTol=1e-2, extra=None): from subprocess import Popen, PIPE, TimeoutExpired import logging import os from pathlib import Path logger = logging.getLogger('__main__') if not isinstance(py, (list, tuple)): py = [py] autotesterOutput = Path('/home/caglusa/autotester/html') if autotesterOutput.exists(): plotDir = autotesterOutput/('test-plots/'+''.join(py)+'/') else: extra = None if cacheDir != '': cache = cacheDir+'/cache_' + ''.join(py) runOutput = cacheDir+'/run_' + ''.join(py) if ranks is not None: cache += str(ranks) runOutput += str(ranks) py += ['--test', '--testCache={}'.format(cache)] if 'OVERWRITE_CACHE' in os.environ: overwriteCache = True if overwriteCache: py += ['--overwriteCache'] else: py += ['--test'] py += ['--disableFileLog'] if extra is not None: plotDir.mkdir(exist_ok=True, parents=True) py += ['--plotFolder={}'.format(plotDir), '--plotFormat=png'] else: py += ['--skipPlots'] assert (Path(path)/py[0]).exists(), 'Driver \"{}\" does not exist'.format(Path(path)/py[0]) if ranks is None: ranks = 1 if python is None: import sys python = sys.executable cmd = [python] + py if 'MPIEXEC_FLAGS' in os.environ: mpi_flags = str(os.environ['MPIEXEC_FLAGS']) else: mpi_flags = '--bind-to none' cmd = ['mpiexec'] + mpi_flags.split(' ') + ['-n', str(ranks)]+cmd logger.info('Launching "{}" from "{}"'.format(' '.join(cmd), path)) my_env = {} for key in os.environ: if key.find('OMPI') == -1: my_env[key] = os.environ[key] proc = Popen(cmd, cwd=path, stdout=PIPE, stderr=PIPE, universal_newlines=True, env=my_env) try: stdout, stderr = proc.communicate(timeout=timeout) except TimeoutExpired: proc.kill() raise if len(stdout) > 0: logger.info(stdout) if len(stderr) > 0: logger.error(stderr) assert proc.returncode == 0, stderr+'\n\n'+stdout if extra is not None: from pytest_html import extras for img in plotDir.glob('*.png'): filename = img.relative_to(autotesterOutput) filename = 'http://geminga.sandia.gov:8080/'+str(filename) extra.append(extras.png(str(filename)))
[docs] class parametrizedArg:
[docs] def __init__(self, name, params=[]): if not isinstance(params, (list, tuple)): params = [params] self.name = name self.params = params fields = [] for p in self.params: if p == str: fields.append('[a-zA-Z]+') elif p == int: fields.append('[+-]?[0-9]+') elif p == float: fields.append('[+-]?[0-9]+\.[0-9]*') elif p == bool: fields.append('True|False') else: raise NotImplementedError() self.regexp = re.compile(name+'\(?'+','.join(['\s*(' + f + ')\s*' for f in fields])+'\)?')
[docs] def match(self, s): return self.regexp.match(s) is not None
[docs] def interpret(self, s): m = self.regexp.match(s) return [p(v) for v, p in zip(m.groups(), self.params)]
def __repr__(self): params = [] for p in self.params: params.append(p.__name__) return "{}({})".format(self.name, ','.join(params))
[docs] class propertyBuilder:
[docs] def __init__(self, baseObj, fun): self.baseObj = baseObj self.fun = fun self.baseName = fun.__name__ self.requiredProperties = inspect.getfullargspec(self.fun).args[1:] self.generatedProperties = set() self.cached_args = {} self.logLevel = logging.DEBUG
[docs] def declareGeneratedProperty(self, prop): self.generatedProperties.add(prop)
[docs] def declareGeneratedProperties(self, props): for prop in props: self.declareGeneratedProperty(prop)
[docs] def __call__(self): from PyNucleus_fem.DoFMaps import fe_vector, complex_fe_vector cached_args = {} args = [] needToBuild = False for prop in self.requiredProperties: try: newValue = getattr(self.baseObj, prop) oldValue = self.cached_args.get(prop, None) args.append(newValue) # TODO: keep hash? try: if isinstance(newValue, np.ndarray): cached_args[prop] = newValue.copy() if (newValue != oldValue).any(): dependencyLogger.log(self.logLevel, 'Values for {} differ: \'{}\' != \'{}\', calling \'{}\''.format(prop, oldValue, newValue, self.fun.__name__)) needToBuild = True else: dependencyLogger.log(self.logLevel, 'Values for {} are identical: \'{}\' == \'{}\''.format(prop, oldValue, newValue)) elif isinstance(newValue, (fe_vector, complex_fe_vector)): cached_args[prop] = newValue.copy() if oldValue is None or (newValue.toarray() != oldValue.toarray()).any(): dependencyLogger.log(self.logLevel, 'Values for {} differ: \'{}\' != \'{}\', calling \'{}\''.format(prop, oldValue, newValue, self.fun.__name__)) needToBuild = True else: dependencyLogger.log(self.logLevel, 'Values for {} are identical: \'{}\' == \'{}\''.format(prop, oldValue, newValue)) elif newValue != oldValue: cached_args[prop] = newValue dependencyLogger.log(self.logLevel, 'Values for {} differ: \'{}\' != \'{}\', calling \'{}\''.format(prop, oldValue, newValue, self.fun.__name__)) needToBuild = True else: dependencyLogger.log(self.logLevel, 'Values for {} are identical: \'{}\' == \'{}\''.format(prop, oldValue, newValue)) except Exception as e: dependencyLogger.log(logging.WARN, 'Cannot compare values {}, {} for property \'{}\', exception {}, force call \'{}\''.format(oldValue, newValue, prop, e, self.fun.__name__)) needToBuild = True except AttributeError: raise AttributeError('Method \'{}\' has unsatisfied dependency on \'{}\''.format(self.fun.__name__, prop)) if needToBuild: self.cached_args = cached_args self.fun(*args) else: dependencyLogger.log(self.logLevel, 'Skipping call to \'{}\''.format(self.fun.__name__)) for prop in self.generatedProperties: self.baseObj.setState(prop, VALID)
VALID = 0 INVALID = 1
[docs] def generates(properties): def wrapper(fun): clsName = fun.__qualname__.split('.')[-2] if not isinstance(properties, (list, tuple, set)): props = [properties] else: props = properties try: generatorMethodsToProperties[clsName] except KeyError: generatorMethodsToProperties[clsName] = {} generatorMethodsToProperties[clsName][fun] = props return fun return wrapper
generatorMethodsToProperties = {}
[docs] class classWithComputedDependencies: generates = generates
[docs] def __init__(self): # maps properties to the builder functions that depend on them self.requiredPropToBuilder = {} # maps builder functions to their produced properties self.builderToProps = {} # map properties to their builder self.generatedPropToBuilder = {} self.remoteRequiredProperties = {} self.remoteGeneratedProperties = {} self.methodToBuilder = {} clsNames = [] classes = [self.__class__] while len(classes) > 0: newClasses = [] for cls in classes: clsNames.append(cls.__name__) for parent in cls.__bases__: if parent not in (classWithComputedDependencies, problem): newClasses.append(parent) classes = newClasses registeredGenerators = set() for clsName in clsNames: if clsName in generatorMethodsToProperties: for clsMethod in generatorMethodsToProperties[clsName]: methodName = clsMethod.__name__ if methodName in dir(self): if methodName in registeredGenerators: continue registeredGenerators.add(methodName) method = getattr(self, methodName) props = generatorMethodsToProperties[clsName][clsMethod] self.addProperties(props, method)
[docs] def getState(self, prop): stateFlag = '__state_'+prop return getattr(self, stateFlag)
[docs] def setState(self, prop, state): stateFlag = '__state_'+prop setattr(self, stateFlag, state)
[docs] def getValue(self, prop): valueFlag = '__value_'+prop return getattr(self, valueFlag)
[docs] def setValue(self, prop, value): valueFlag = '__value_'+prop setattr(self, valueFlag, value)
[docs] def directlyGetWithoutChecks(self, prop): return self.getValue(prop)
[docs] def directlySetWithoutChecks(self, prop, value): self.setState(prop, VALID) self.setValue(prop, value)
@property def allProperties(self): return set(self.generatedPropToBuilder.keys()) | set(self.requiredPropToBuilder.keys())
[docs] def addProperty(self, prop, method=None, postProcess=None): """Adds property that is generated by calling "method" The generated value is then cached.""" if method is not None: assert inspect.ismethod(method), "Need the builder to be a method." if method in self.methodToBuilder: builder = self.methodToBuilder[method] else: builder = propertyBuilder(self, method) self.methodToBuilder[method] = builder try: self.builderToProps[builder].add(prop) except KeyError: self.builderToProps[builder] = set() self.builderToProps[builder].add(prop) builder.declareGeneratedProperty(prop) for reqProp in builder.requiredProperties: try: self.requiredPropToBuilder[reqProp].add(builder) except KeyError: self.requiredPropToBuilder[reqProp] = set() self.requiredPropToBuilder[reqProp].add(builder) else: builder = None self.generatedPropToBuilder[prop] = builder self.setValue(prop, None) self.setState(prop, INVALID) def getter(self): builder = self.generatedPropToBuilder[prop] state = self.getState(prop) if builder is not None: if state == INVALID: dependencyLogger.log(builder.logLevel, 'Calling \'{}\''.format(builder.baseName)) builder() assert self.getState(prop) == VALID, 'Calling \'{}\' did not result in \'{}\' being set.'.format(builder.baseName, prop) else: assert state == VALID, 'Property \'{}\' needs to be set before it can be used.'.format(prop) return self.getValue(prop) def setter(self, value): # TODO: check whether we set from builder or not and don't allow setting generated properties directly dependencyLogger.log(logging.DEBUG, 'Setting \'{}\''.format(prop)) self.invalidateDependencies(prop) if prop in self.remoteGeneratedProperties: for obj in self.remoteGeneratedProperties[prop]: obj.invalidateDependencies(prop) self.setValue(prop, value) self.setState(prop, VALID) if postProcess is not None: postProcess() setattr(self.__class__, prop, property(getter, setter))
[docs] def invalidateDependencies(self, prop): """Invalidate all properties that get generated by builders that depend on "prop".""" if prop in self.requiredPropToBuilder: for builder in self.requiredPropToBuilder[prop]: for genProp in builder.generatedProperties: if self.getState(genProp) == VALID: dependencyLogger.log(builder.logLevel, 'Changing {}.{} => invalidating {}.{}'.format(self.__class__.__name__, prop, self.__class__.__name__, genProp)) self.setState(genProp, INVALID) self.invalidateDependencies(genProp) if prop in self.remoteGeneratedProperties: for obj in self.remoteGeneratedProperties[prop]: obj.invalidateDependencies(prop)
[docs] def addProperties(self, properties, method=None, postProcess=None): """Adds multiple properties that are generated by calling "method".""" for prop in properties: self.addProperty(prop, method, postProcess=postProcess)
[docs] def addRemoteRequiredProperty(self, obj, prop): """Adds a property of another object to self.""" self.remoteRequiredProperties[prop] = obj obj.addRemoteGeneratedProperty(self, prop) def getter(self): obj = self.remoteRequiredProperties[prop] return getattr(obj, prop) # def setter(self, value): # setattr(obj, prop, value) # setattr(cls, prop, property(getter, setter)) setattr(self.__class__, prop, property(getter))
[docs] def addRemoteRequiredProperties(self, obj, properties): """Adds multiple properties of another object.""" for prop in properties: self.addRemoteRequiredProperty(obj, prop)
[docs] def addRemoteGeneratedProperty(self, obj, prop): try: self.remoteGeneratedProperties[prop].add(obj) except KeyError: self.remoteGeneratedProperties[prop] = set([obj])
[docs] def addRemote(self, obj): for prop in obj.allProperties: self.addRemoteRequiredProperty(obj, prop)
[docs] def getGraph(self, includeRemote=False): import networkx as nx G = nx.DiGraph() for prop in self.allProperties: G.add_node(prop, color='blue') for builder in self.builderToProps: G.add_node(builder.baseName, color='green') for reqProp in builder.requiredProperties: G.add_edge(reqProp, builder.baseName) for prop in builder.generatedProperties: G.add_edge(builder.baseName, prop) if includeRemote: objs = set() for remoteProp in self.remoteRequiredProperties: G.add_node(remoteProp, color='blue') objs.add(self.remoteRequiredProperties[remoteProp]) for remoteProp in self.remoteGeneratedProperties: G.add_node(remoteProp, color='blue') for obj in self.remoteGeneratedProperties[remoteProp]: objs.add(obj) for obj in objs: G = nx.compose(G, obj.getGraph()) return G
[docs] def plot(self, includeRemote=False, layout='dot'): import networkx as nx import matplotlib.pyplot as plt G = self.getGraph(includeRemote) colors = [node[1]['color'] for node in G.nodes(data=True)] nx.draw(G, with_labels=True, # node_size=0, arrowsize=20, node_color=colors, pos=nx.nx_agraph.pygraphviz_layout(G, layout)) plt.show()
[docs] def changeLogLevel(self, properties, logLevel): for prop in properties: try: builder = self.generatedPropToBuilder[prop] builder.logLevel = logLevel except KeyError: pass try: for builder in self.requiredPropToBuilder[prop]: builder.logLevel = logLevel except KeyError: pass
[docs] class driverAddon:
[docs] def __init__(self, driver): self.driver = driver self._timer = None self.__parametrized_args__ = {} self.flags = [] self.setDriverArgs() self.driver.addToProcessHook(self.process) try: self.driver.addGroup('input/output').add('identifier', 'auto', help='identifier used as prefix for output ("auto" sets the identifier based on invocation)') except argparse.ArgumentError: pass
[docs] def addParametrizedArg(self, name, params=[]): self.__parametrized_args__[name] = parametrizedArg(name, params)
[docs] def parametrizedArg(self, name): return self.__parametrized_args__[name]
[docs] def argInterpreter(self, parametrizedArgs, acceptedValues=[]): from argparse import ArgumentTypeError def interpreter(v): if v in acceptedValues: return v for p in parametrizedArgs: if self.parametrizedArg(p).match(v): return v raise ArgumentTypeError(("\"{}\" is not in list of accepted values {} " + "or cannot be interpreted as parametrized arg {}.").format(v, acceptedValues, [repr(self.parametrizedArg(p)) for p in parametrizedArgs])) return interpreter
[docs] def conversionInterpreter(self): def interpreter(v): print(v) try: return int(v) except: pass try: return float(v) except: pass return v return interpreter
[docs] def setDriverFlag(self, *args, **kwargs): flag = args[0] self.addProperty(flag) if 'group' in kwargs: group = kwargs.pop('group') group.add(*args, **kwargs) else: self.driver.add(*args, **kwargs) self.flags.append(flag)
[docs] def setDriverArgs(self): pass
@property def timer(self): if self._timer is None: self._timer = self.driver.getTimer() return self._timer
[docs] def processCmdline(self, params): pass
[docs] def getIdentifier(self, params): from sys import argv identifier = '-'.join(argv) pos = identifier.rfind('/') if pos >= 0: return identifier[pos+1:] else: return identifier
[docs] def process(self, params): self.processCmdline(params) if params['identifier'] == 'auto': self.driver.setIdentifier(self.getIdentifier(params)) else: self.driver.setIdentifier(params['identifier'])
[docs] class problem(classWithComputedDependencies, driverAddon):
[docs] def __init__(self, driver): classWithComputedDependencies.__init__(self) driverAddon.__init__(self, driver) try: self.driver.addGroup('input/output').add('showDependencyGraph', False, help="Show dependency graph of problem classes.") self.driver.addGroup('input/output').add('logProperties', '', help='log select properties') except argparse.ArgumentError: pass
[docs] def processCmdline(self, params): driverAddon.processCmdline(self, params) for key in self.flags: self.directlySetWithoutChecks(key, params[key]) for key in self.flags: params[key] = getattr(self, key) if params['logProperties'] != '': propertiesToLog = params['logProperties'].split(',') self.changeLogLevel(propertiesToLog, logging.INFO)
[docs] def process(self, params): driverAddon.process(self, params) if self.driver.showDependencyGraph: if self.driver.isMaster: classWithComputedDependencies.plot(self, True) if self.driver.comm is not None: self.driver.comm.Barrier() exit(0)
def __repr__(self): return str(type(self))
# lines = [(label, '{}', e) for label, e in self.__values__.items()] # return columns(lines)