Coverage for src\pytribeam\utilities.py: 78%
232 statements
« prev ^ index » next coverage.py v7.5.1, created at 2025-03-04 17:41 -0800
1#!/usr/bin/python3
3# Default python modules
4import os
5from pathlib import Path
6import time
7import warnings
8import math
9from typing import Dict, NamedTuple, Tuple, Any, List
10from enum import Enum, IntEnum
11import platform
12import pytest
13from functools import singledispatch
14import shutil
16# # Autoscript modules
17import yaml
18import contextlib
19import sys
20from pandas import json_normalize
22# # # 3rd party module
23# from schema import Schema, And, Use, Optional, SchemaError
25# # Local scripts
26import pytribeam.types as tbt
28# import pytribeam.constants as cs
29from pytribeam.constants import Constants
@singledispatch
def beam_type(beam) -> property:
    """Return the microscope's beam property object for a given beam.

    Ion and electron beams share the same internal hierarchy, so the
    registered overloads simply select the matching attribute on the
    microscope. This base implementation is reached only when the beam's
    type has no registered overload.
    """
    raise NotImplementedError()
@beam_type.register
def _(beam: tbt.ElectronBeam, microscope: tbt.Microscope) -> property:
    """Overload of beam_type for electron beams.

    Returns the electron beam property object from the microscope; the
    ``beam`` argument is used only for singledispatch type selection.
    """
    return microscope.beams.electron_beam
@beam_type.register
def _(beam: tbt.IonBeam, microscope: tbt.Microscope) -> property:
    """Overload of beam_type for ion beams.

    Returns the ion beam property object from the microscope; the
    ``beam`` argument is used only for singledispatch type selection.
    """
    return microscope.beams.ion_beam
def connect_microscope(
    microscope: tbt.Microscope,
    quiet_output: bool = True,
    connection_host: str = None,
    connection_port: int = None,
):
    """Connects to the microscope with option to suppress printout.

    Args:
        microscope: The microscope object to connect.
        quiet_output: If True, suppress stdout emitted during connection.
        connection_host: Optional host address; if None, a default
            connection is attempted.
        connection_port: Optional port; only used together with a host.

    Returns:
        True if the connection succeeded.

    Raises:
        ConnectionError: If no server host is set after the connection attempt.
    """

    # Inner helper closes over the arguments directly (previously it
    # re-declared them as parameters, which the TODO asked to clean up).
    def connect() -> None:
        # Pick the connect() arity that matches the provided arguments.
        if connection_port is not None:
            microscope.connect(connection_host, connection_port)
        elif connection_host is not None:
            microscope.connect(connection_host)
        else:
            microscope.connect()

    if quiet_output:
        with nostdout():
            connect()
    else:
        connect()

    if microscope.server_host is not None:
        return True
    raise ConnectionError(
        f"Connection failed with connection_host of '{connection_host}' and connection_port of '{connection_port}' microscope not connected."
    )
def dict_to_yml(db: dict, file_path: Path) -> Path:
    """Serialize ``db`` to YAML at ``file_path`` and return that path.

    Keys are written in insertion order (``sort_keys=False``) using block
    style (``default_flow_style=False``).
    """
    with open(file_path, "w", encoding="utf-8") as out_file:
        yaml.dump(db, out_file, default_flow_style=False, sort_keys=False)
    return file_path
def disconnect_microscope(
    microscope: tbt.Microscope,
    quiet_output: bool = True,
):
    """Disconnects from the microscope with option to suppress printout.

    Returns True on success; raises ConnectionError if the microscope still
    reports a server host after disconnecting.
    """
    if quiet_output:
        with nostdout():
            microscope.disconnect()
    else:
        microscope.disconnect()

    if microscope.server_host is not None:
        raise ConnectionError("Disconnection failed, microscope still connected")
    return True
def general_settings(exp_settings: dict, yml_format: tbt.YMLFormat) -> dict:
    """Return the general-experiment-settings subsection of ``exp_settings``,
    located via the format's general-section key."""
    return exp_settings[yml_format.general_section_key]
def step_type(settings: dict, yml_format: tbt.YMLFormat) -> tbt.StepType:
    """Determine the step type for a specific step-settings dictionary."""
    general_section = settings[yml_format.step_general_key]
    return tbt.StepType(general_section[yml_format.step_type_key])
def in_interval(val: float, limit: tbt.Limit, type: tbt.IntervalType) -> bool:
    """Tests whether a value is within an interval, with interval type
    (closed, open, half-open, etc.) defined by the enumerated IntervalType.

    Args:
        val: The input value to be compared against min and max.
        limit: The bounds of the interval.
        type: The type of interval.

    Returns:
        True if within the interval, False otherwise.

    Raises:
        NotImplementedError: If ``type`` is not a supported IntervalType.
    """
    if type == tbt.IntervalType.OPEN:
        return (val > limit.min) and (val < limit.max)
    if type == tbt.IntervalType.CLOSED:
        return (val >= limit.min) and (val <= limit.max)
    if type == tbt.IntervalType.LEFT_OPEN:
        return (val > limit.min) and (val <= limit.max)
    if type == tbt.IntervalType.RIGHT_OPEN:
        return (val >= limit.min) and (val < limit.max)
    # Previously an unrecognized interval type fell through and returned
    # None (silently falsy); fail loudly instead.
    raise NotImplementedError(f"Unsupported interval type: {type}")
def gen_dict_extract(key, var):
    """Yield every value stored under ``key`` anywhere inside ``var``.

    Recursively walks nested dictionaries and lists of dictionaries;
    non-mapping inputs yield nothing.
    """
    if not hasattr(var, "items"):
        return
    for current_key, current_val in var.items():
        if current_key == key:
            yield current_val
        if isinstance(current_val, dict):
            yield from gen_dict_extract(key, current_val)
        elif isinstance(current_val, list):
            for element in current_val:
                yield from gen_dict_extract(key, element)
176# def list_enum_to_string(list: List[Enum]) -> List[str]:
177# """Converts list of enum to strings"""
178# return [str(i.value) for i in list]
def nested_dictionary_location(d: dict, key: str, value: Any) -> List[str]:
    """Find the nested location of a key-value pair in a dictionary.

    Returns the list of keys, ordered from the highest to the lowest level
    of nesting, that leads to the pair.

    Raises:
        KeyError: If the key-value pair is not present anywhere in ``d``.
    """
    path = nested_find_key_value_pair(d=d, key=key, value=value)
    if path is None:
        raise KeyError(
            f'Key : value pair of "{key} : {value}" not found in the provided dictionary.'
        )
    return path
def nested_find_key_value_pair(d: dict, key: str, value: Any) -> List[str]:
    """Search a nested dictionary for a key-value pair.

    Returns the list of keys from the highest to the lowest nesting level
    leading to the first match, or None when no match exists.
    """
    for current_key, current_val in d.items():
        if current_key == key and current_val == value:
            return [current_key]
        if isinstance(current_val, dict):
            sub_path = nested_find_key_value_pair(current_val, key, value)
            if sub_path:
                return [current_key] + sub_path
206def _flatten(dictionary: dict) -> dict:
207 """Flattens a dictionary using pandas, which can be slow on large dictionaries.
209 From https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
210 """
211 data_frame = json_normalize(dictionary, sep="_")
212 db_flat = data_frame.to_dict(orient="records")[0]
213 return db_flat
def none_value_dictionary(dictionary: dict) -> bool:
    """Return True if all values in the dictionary are None, False otherwise.

    The dictionary is flattened first so nested values are inspected too.
    """
    return all(value is None for value in _flatten(dictionary).values())
@contextlib.contextmanager
def nostdout():
    """Temporarily redirect sys.stdout to a dummy file to suppress output.

    stdout is restored in a ``finally`` block so the redirection does not
    leak when the wrapped body raises (the original implementation left
    sys.stdout pointing at the dummy file on exceptions).
    """
    save_stdout = sys.stdout
    sys.stdout = tbt.DummyFile()
    try:
        yield
    finally:
        sys.stdout = save_stdout
def step_count(
    exp_settings: dict,
    yml_format: tbt.YMLFormatVersion,
):
    """Determine maximum step number from a settings dictionary, as specified by the step_number_key."""
    step_number_key = yml_format.step_number_key
    non_step_sections = yml_format.non_step_section_count

    # All steps must live in a single section, so the dict must contain
    # exactly one more section than the non-step section count.
    total_sections = len(exp_settings)
    if total_sections != non_step_sections + 1:
        raise ValueError(
            f"Invalid .yml file, {total_sections} sections were found but the input .yml should have {non_step_sections + 1} sections. Please verify that all top-level keys in the .yml have unique strings and that all steps are contained in a single top-level section."
        )

    expected_step_count = exp_settings[yml_format.general_section_key][
        yml_format.step_count_key
    ]

    # Probe ascending step numbers until one is missing; the last number
    # that was found is the step count.
    found_step_count = 0
    candidate = 1
    while True:
        try:
            nested_dictionary_location(
                d=exp_settings,
                key=step_number_key,
                value=candidate,
            )
        except KeyError:
            break
        found_step_count = candidate
        candidate += 1

    # validate number of steps found with steps read by YAML loader
    # TODO YAML safeloader will ignore duplicate top level keys, so this check relies on unique step numbers in ascending order (no gaps) to be found.
    if expected_step_count != found_step_count:
        raise ValueError(
            f"Invalid .yml file, {found_step_count} steps were found but the input .yml should have {expected_step_count} steps from the general setting key '{yml_format.step_count_key}' within the '{yml_format.general_section_key}' section. Please verify that all step_name keys in the .yml have unique strings and that step numbers are continuously-increasing positive integers starting at 1."
        )

    return found_step_count
def step_settings(
    exp_settings: dict,
    step_number_key: str,
    step_number_val: int,
    yml_format: tbt.YMLFormatVersion,
) -> Tuple[str, dict]:
    """Grab a specific step's settings from an experiment dictionary.

    Returns the user-defined step name together with that step's settings
    dictionary.
    """
    nesting = nested_dictionary_location(
        d=exp_settings,
        key=step_number_key,
        value=step_number_val,
    )
    # Index 0 is the top-level section key; the step name is nested one
    # level below it (index 1).
    step_name = nesting[1]
    settings = exp_settings[yml_format.step_section_key][step_name]
    return step_name, settings
def valid_microscope_connection(host: str, port: str) -> bool:
    """Determines if microscope connection can be made, disconnects if a connection can be made"""
    microscope = tbt.Microscope()
    connected = connect_microscope(
        microscope=microscope,
        quiet_output=True,
        connection_host=host,
        connection_port=port,
    )
    if connected and disconnect_microscope(
        microscope=microscope,
        quiet_output=True,
    ):
        return True
    return False
def enable_external_device(oem: tbt.ExternalDeviceOEM) -> bool:
    """Determines whether to enable External Device Control.

    The device must be a member of the ExternalDeviceOEM enum and not equal
    to ExternalDeviceOEM.NONE for control to be enabled.
    """
    if not isinstance(oem, tbt.ExternalDeviceOEM):
        raise NotImplementedError(
            f"Unsupported type of {type(oem)}, only 'ExternalDeviceOEM' types are supported."
        )
    return oem != tbt.ExternalDeviceOEM.NONE
def valid_enum_entry(obj: Any, check_type: Enum) -> bool:
    """Determine whether ``obj`` equals the value of any member of ``check_type``.

    Uses the public Enum iteration API instead of the private
    ``_value2member_map_`` attribute; this also tolerates unhashable
    ``obj`` values, which previously raised TypeError on the dict lookup.
    """
    return any(obj == member.value for member in check_type)
def yml_format(version: float) -> tbt.YMLFormatVersion:
    """Return the YMLFormatVersion member whose version matches ``version``.

    Raises NotImplementedError for unsupported versions.
    """
    for fmt in tbt.YMLFormatVersion:
        if fmt.version == version:
            return fmt
    raise NotImplementedError(
        f'Unsupported YML file version for version "{version}". Valid formats include: {[i.value for i in tbt.YMLFormatVersion]}'
    )
def yml_to_dict(
    *, yml_path_file: Path, version: float, required_keys: Tuple[str, ...]
) -> Dict:
    """Given a valid Path to a yml input file, read it in and return the
    result as a dictionary.

    Args:
        yml_path_file: The fully pathed location to the input file.
        version: The version of the yml in x.y format.
        required_keys: The key(s) that must be in the yml file for conversion
            to a dictionary to occur.

    Returns:
        The .yml file represented as a dictionary.

    Raises:
        TypeError: If the file suffix is not .yaml or .yml.
        OSError: If the file cannot be opened or decoded as YAML.
        KeyError: If any of ``required_keys`` is missing.
        ValueError: If the file's version does not match ``version``.
    """
    # Compared to the lower() method, the casefold() method is stronger.
    # It will convert more characters into lower case, and will find more
    # matches on comparison of two strings that are both are converted
    # using the casefold() method.
    file_type = yml_path_file.suffix.casefold()

    supported_types = (".yaml", ".yml")
    if file_type not in supported_types:
        raise TypeError("Only file types .yaml, and .yml are supported.")

    try:
        with open(file=yml_path_file, mode="r", encoding="utf-8") as stream:
            # See deprecation warning for plain yaml.load(input) at
            # https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
            db = yaml.load(stream, Loader=yaml.SafeLoader)
    except yaml.YAMLError as error:
        print(f"Error with YAML file: {error}")
        print(f"Could not open or decode: {yml_path_file}")
        raise OSError from error

    # check keys found in input file against required keys
    found_keys = tuple(db.keys())
    keys_exist = tuple(map(lambda x: x in found_keys, required_keys))
    has_required_keys = all(keys_exist)
    if not has_required_keys:
        raise KeyError(f"Input files must have these keys defined: {required_keys}")

    version_specified = db["config_file_version"]
    version_requested = version

    if version_specified != version_requested:
        # BUGFIX: the two fragments previously joined with no separating
        # space, producing "...was X,requested is Y".
        ee = f"Version mismatch: specified in file was {version_specified}, "
        ee += f"requested is {version_requested}"
        raise ValueError(ee)

    return db
def yml_version(
    file: Path,
    key_name="config_file_version",
) -> float:
    """Return the version of a yml file if the proper key exists.

    Args:
        file: Path to the .yml file.
        key_name: The key holding the version value.

    Raises:
        KeyError: If ``key_name`` is not present in the file.
        ValueError: If the stored version cannot be parsed as a float.
    """
    # utf-8 is specified explicitly for consistency with dict_to_yml and
    # yml_to_dict (the platform default encoding is not guaranteed).
    with open(file, "r", encoding="utf-8") as stream:
        data = yaml.load(stream, Loader=yaml.SafeLoader)

    try:
        version = data[key_name]
    except KeyError as error:
        raise KeyError(
            f"Error with version key, '{key_name}' key not found in {file}."
        ) from error
    try:
        version = float(version)
    except ValueError as error:
        raise ValueError(
            f"Could not find valid version in {file} for key {key_name}, found '{version}' which is not a float."
        ) from error
    return version
def yes_no(question):
    """Simple Yes/No Function."""
    prompt = f"{question} (y/n): "
    while True:
        ans = input(prompt).strip().lower()
        if ans in ("y", "n"):
            return ans == "y"
        print(f"{ans} is invalid, please try again...")
def remove_directory(directory: Path):
    """Recursively delete ``directory`` and everything inside it."""
    shutil.rmtree(directory)
def split_list(data: List, chunk_size: int) -> List:
    """Split ``data`` into consecutive chunks of at most ``chunk_size`` items."""
    return [
        data[start : start + chunk_size]
        for start in range(0, len(data), chunk_size)
    ]
def tabular_list(
    data: List,
    num_columns: int = Constants.default_column_count,
    column_width: int = Constants.default_column_width,
) -> str:
    """Format ``data`` as rows of centered, fixed-width columns.

    Each row starts on a new line and holds up to ``num_columns`` items,
    each centered in a field of ``column_width`` characters.
    """
    rows = split_list(data, chunk_size=num_columns)
    rendered_rows = [
        "\n" + "".join(f"{item:^{column_width}}" for item in row) for row in rows
    ]
    return "".join(rendered_rows)
461### Custom Decorators ###
def hardware_movement(func):
    """Test decorator: run ``func`` only on a microscope machine (via
    run_on_microscope_machine) AND only when hardware-movement testing is
    enabled in Constants; otherwise the test is skipped."""

    @run_on_microscope_machine
    def wrapper_func():
        if not Constants.test_hardware_movement:
            pytest.skip("Run only when hardware testing is enabled")
        func()

    return wrapper_func
def run_on_standalone_machine(func):
    """Test decorator: skip ``func`` unless the current host is one of the
    offline-license machines listed in Constants."""

    def wrapper_func():
        host_name = platform.uname().node.lower()
        allowed = [machine.lower() for machine in Constants().offline_machines]
        if host_name not in allowed:
            pytest.skip("Run on Offline License Machine Only.")
        func()

    return wrapper_func
def run_on_microscope_machine(func):
    """Test decorator: skip ``func`` unless the current host is one of the
    microscope machines listed in Constants."""

    def wrapper_func():
        host_name = platform.uname().node.lower()
        allowed = [machine.lower() for machine in Constants().microscope_machines]
        if host_name not in allowed:
            pytest.skip("Run on Microscope Machine Only.")
        func()

    return wrapper_func