Coverage for src\pytribeam\utilities.py: 78%

232 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2025-03-04 17:41 -0800

1#!/usr/bin/python3 

2 

3# Default python modules 

4import os 

5from pathlib import Path 

6import time 

7import warnings 

8import math 

9from typing import Dict, NamedTuple, Tuple, Any, List 

10from enum import Enum, IntEnum 

11import platform 

12import pytest 

13from functools import singledispatch 

14import shutil 

15 

16# # Autoscript modules 

17import yaml 

18import contextlib 

19import sys 

20from pandas import json_normalize 

21 

22# # # 3rd party module 

23# from schema import Schema, And, Use, Optional, SchemaError 

24 

25# # Local scripts 

26import pytribeam.types as tbt 

27 

28# import pytribeam.constants as cs 

29from pytribeam.constants import Constants 

30 

31 

@singledispatch
def beam_type(beam, microscope=None) -> property:
    """Return the beam property object of ``microscope`` matching ``beam``'s class.

    This is the singledispatch fallback. It runs only when ``beam`` is not an
    instance of a registered beam class. Ion and electron beams share the
    same internal hierarchy, so each registered overload just returns the
    matching attribute of ``microscope.beams``.

    Args:
        beam: Object whose class selects the overload.
        microscope: The microscope to read the beam from. It is accepted here
            so that the usual two-argument call with an unsupported beam
            reaches the ``NotImplementedError`` below. Without it, the call
            would fail with a ``TypeError`` about the argument count.

    Raises:
        NotImplementedError: Always, because no overload matched ``beam``.
    """
    _ = microscope  # unused in the fallback
    raise NotImplementedError(
        f"Unsupported beam type '{type(beam).__name__}'; no beam_type overload is registered for it."
    )

37 

38 

@beam_type.register
def _(beam: tbt.ElectronBeam, microscope: tbt.Microscope) -> property:
    """Electron-beam overload: return ``microscope.beams.electron_beam``.

    ``beam`` only selects this overload through singledispatch; its value is
    never read.
    """
    available_beams = microscope.beams
    return available_beams.electron_beam

43 

44 

@beam_type.register
def _(beam: tbt.IonBeam, microscope: tbt.Microscope) -> property:
    """Ion-beam overload: return ``microscope.beams.ion_beam``.

    ``beam`` only selects this overload through singledispatch; its value is
    never read.
    """
    available_beams = microscope.beams
    return available_beams.ion_beam

49 

50 

def connect_microscope(
    microscope: tbt.Microscope,
    quiet_output: bool = True,
    connection_host: str = None,
    connection_port: int = None,
):
    """Connect ``microscope`` to its server, optionally suppressing printout.

    The arguments forwarded to ``microscope.connect`` depend on what is given:
    - a port is supplied: both host and port are passed (host may be None);
    - only a host is supplied: just the host is passed;
    - neither is supplied: ``connect`` is called with no arguments.

    Args:
        microscope: The microscope object to connect.
        quiet_output: If True, anything written to ``sys.stdout`` during the
            connection attempt is discarded.
        connection_host: Optional host passed to ``microscope.connect``.
        connection_port: Optional port passed to ``microscope.connect``
            together with ``connection_host``.

    Returns:
        True once ``microscope.server_host`` is set.

    Raises:
        ConnectionError: If ``microscope.server_host`` is still None after
            the connection attempt.
    """
    if connection_port is not None:
        connect_args = (connection_host, connection_port)
    elif connection_host is not None:
        connect_args = (connection_host,)
    else:
        connect_args = ()

    # nostdout() is only created when output must be suppressed; otherwise
    # the call runs under a no-op context manager.
    output_guard = nostdout() if quiet_output else contextlib.nullcontext()
    with output_guard:
        microscope.connect(*connect_args)

    if microscope.server_host is not None:
        return True
    raise ConnectionError(
        f"Connection failed with connection_host of '{connection_host}' and connection_port of '{connection_port}' microscope not connected."
    )

92 

93 

def dict_to_yml(db: dict, file_path: Path) -> Path:
    """Write ``db`` to ``file_path`` as block-style YAML.

    Keys are written in insertion order (``sort_keys=False``). Nested
    collections use block style rather than inline flow style. The file is
    overwritten and encoded as UTF-8.

    Returns:
        The same ``file_path`` that was passed in.
    """
    dump_options = dict(default_flow_style=False, sort_keys=False)
    with open(file_path, "w", encoding="utf-8") as handle:
        yaml.dump(db, handle, **dump_options)
    return file_path

107 

108 

def disconnect_microscope(
    microscope: tbt.Microscope,
    quiet_output: bool = True,
):
    """Disconnect ``microscope`` from its server, optionally hiding printout.

    Args:
        microscope: The connected microscope object.
        quiet_output: If True, anything written to ``sys.stdout`` during
            ``microscope.disconnect()`` is discarded.

    Returns:
        True when ``microscope.server_host`` is None afterwards.

    Raises:
        ConnectionError: If ``microscope.server_host`` is still set.
    """
    output_guard = nostdout() if quiet_output else contextlib.nullcontext()
    with output_guard:
        microscope.disconnect()

    if microscope.server_host is not None:
        raise ConnectionError("Disconnection failed, microscope still connected")
    return True

124 

125 

def general_settings(exp_settings: dict, yml_format: tbt.YMLFormat) -> dict:
    """Return the general-settings section of an experiment dictionary.

    Args:
        exp_settings: Experiment settings, e.g. as loaded from a .yml file.
        yml_format: Format descriptor whose ``general_section_key`` names the
            section to return.

    Returns:
        ``exp_settings[yml_format.general_section_key]``. A missing section
        raises ``KeyError``.
    """
    return exp_settings[yml_format.general_section_key]

130 

131 

def step_type(settings: dict, yml_format: tbt.YMLFormat) -> tbt.StepType:
    """Determine the step type declared in one step's settings dictionary.

    The raw value is read from
    ``settings[yml_format.step_general_key][yml_format.step_type_key]`` and
    converted with ``tbt.StepType(...)``. An unrecognised value raises
    whatever that constructor raises (presumably ``ValueError`` if
    ``StepType`` is a standard Enum; confirm).
    """
    general_block = settings[yml_format.step_general_key]
    raw_type = general_block[yml_format.step_type_key]
    return tbt.StepType(raw_type)

139 

140 

def in_interval(val: float, limit: tbt.Limit, type: tbt.IntervalType) -> bool:
    """Test whether a value lies within an interval.

    The interval kind is given by the enumerated ``tbt.IntervalType``:

    * ``OPEN``:       limit.min <  val <  limit.max
    * ``CLOSED``:     limit.min <= val <= limit.max
    * ``LEFT_OPEN``:  limit.min <  val <= limit.max
    * ``RIGHT_OPEN``: limit.min <= val <  limit.max

    Args:
        val: The value to compare against ``limit.min`` and ``limit.max``.
        limit: The bounds of the interval.
        type: The kind of interval. The name shadows the ``type`` builtin
            but is kept so keyword callers keep working.

    Returns:
        True if ``val`` is within the interval, False otherwise.

    Raises:
        NotImplementedError: If ``type`` is none of the four kinds above.
            Failing loudly avoids returning None, which callers would
            silently treat as "outside".
    """
    if type == tbt.IntervalType.OPEN:
        return (val > limit.min) and (val < limit.max)
    if type == tbt.IntervalType.CLOSED:
        return (val >= limit.min) and (val <= limit.max)
    if type == tbt.IntervalType.LEFT_OPEN:
        return (val > limit.min) and (val <= limit.max)
    if type == tbt.IntervalType.RIGHT_OPEN:
        return (val >= limit.min) and (val < limit.max)
    raise NotImplementedError(
        f"Unsupported interval type '{type}', expected one of OPEN, CLOSED, LEFT_OPEN, RIGHT_OPEN."
    )

159 

160 

def gen_dict_extract(key, var):
    """Recursively yield every value stored under ``key`` in a nested structure.

    - Mappings (anything with an ``items`` method) are scanned key by key.
      When a key matches, its value is yielded first. Nested ``dict`` and
      ``list`` values are then searched in turn.
    - Lists are searched element by element, including lists nested directly
      inside other lists.
    - Any other object yields nothing.

    Args:
        key: The key to look for.
        var: A mapping, a list, or any other object.

    Yields:
        Each value found under ``key``, in depth-first order.
    """
    if hasattr(var, "items"):
        for k, v in var.items():
            if k == key:
                yield v
            if isinstance(v, (dict, list)):
                yield from gen_dict_extract(key, v)
    elif isinstance(var, list):
        for item in var:
            yield from gen_dict_extract(key, item)

174 

175 

176# def list_enum_to_string(list: List[Enum]) -> List[str]: 

177# """Converts list of enum to strings""" 

178# return [str(i.value) for i in list] 

179 

180 

def nested_dictionary_location(d: dict, key: str, value: Any) -> List[str]:
    """Locate a key/value pair anywhere in a nested dictionary, or raise.

    This is a checked wrapper around ``nested_find_key_value_pair``: where
    that helper returns None on no match, this raises instead.

    Args:
        d: Dictionary to search. Only nested ``dict`` values are descended into.
        key: Key to look for.
        value: Value that ``key`` must map to.

    Returns:
        The keys leading to the match, from outermost to innermost. The last
        element is ``key`` itself.

    Raises:
        KeyError: If no matching key/value pair exists.
    """
    path = nested_find_key_value_pair(d=d, key=key, value=value)
    if path is not None:
        return path
    raise KeyError(
        f'Key : value pair of "{key} : {value}" not found in the provided dictionary.'
    )

191 

192 

def nested_find_key_value_pair(d: dict, key: str, value: Any) -> List[str]:
    """Find the first nested occurrence of ``key`` mapped to ``value``.

    The search is depth-first and in dictionary order. Only values that are
    themselves ``dict`` instances are descended into.

    Returns:
        The keys leading to the match, from outermost to innermost (ending
        with ``key``), or None when no match exists.
    """
    for current_key, current_value in d.items():
        if current_key == key and current_value == value:
            return [current_key]
        if not isinstance(current_value, dict):
            continue
        sub_path = nested_find_key_value_pair(current_value, key, value)
        if sub_path:
            return [current_key, *sub_path]
    return None

204 

205 

206def _flatten(dictionary: dict) -> dict: 

207 """Flattens a dictionary using pandas, which can be slow on large dictionaries. 

208 

209 From https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys 

210 """ 

211 data_frame = json_normalize(dictionary, sep="_") 

212 db_flat = data_frame.to_dict(orient="records")[0] 

213 return db_flat 

214 

215 

def none_value_dictionary(dictionary: dict) -> bool:
    """Return True if every value in ``dictionary`` is None, False otherwise.

    Nested dictionaries are flattened first (via ``_flatten``), so only leaf
    values are checked.
    """
    flat_values = _flatten(dictionary).values()
    return all(leaf is None for leaf in flat_values)

221 

222 

@contextlib.contextmanager
def nostdout():
    """Temporarily replace ``sys.stdout`` with a ``tbt.DummyFile`` to hide output.

    The original stream is restored when the ``with`` block exits, including
    when the block raises. Without the ``try/finally``, an exception inside
    the block (e.g. from ``microscope.connect``) would leave ``sys.stdout``
    redirected for the rest of the process.
    """
    saved_stdout = sys.stdout
    sys.stdout = tbt.DummyFile()
    try:
        yield
    finally:
        sys.stdout = saved_stdout

230 

231 

def step_count(
    exp_settings: dict,
    yml_format: tbt.YMLFormatVersion,
):
    """Count the numbered steps in a settings dictionary and validate the count.

    Steps are counted by searching for ``yml_format.step_number_key`` with the
    values 1, 2, 3, ... in turn, stopping at the first number not found.

    Args:
        exp_settings: Experiment settings loaded from a .yml file.
        yml_format: Format descriptor providing the key names and the number
            of non-step top-level sections.

    Returns:
        The number of consecutively numbered steps found.

    Raises:
        ValueError: If the number of top-level sections is not
            ``non_step_section_count + 1``, or if the number of steps found
            differs from the count declared in the general section.
        KeyError: If the general section or its step-count key is missing.
    """
    step_number_key = yml_format.step_number_key
    non_step_sections = yml_format.non_step_section_count

    # All steps must live in one top-level section, so exactly one section
    # beyond the non-step ones is expected.
    total_sections = len(exp_settings)
    if total_sections != non_step_sections + 1:
        raise ValueError(
            f"Invalid .yml file, {total_sections} sections were found but the input .yml should have {non_step_sections + 1} sections. Please verify that all top-level keys in the .yml have unique strings and that all steps are contained in a single top-level section."
        )

    general_section = exp_settings[yml_format.general_section_key]
    expected_step_count = general_section[yml_format.step_count_key]

    def _step_exists(number: int) -> bool:
        """True if some nested entry maps ``step_number_key`` to ``number``."""
        try:
            nested_dictionary_location(
                d=exp_settings,
                key=step_number_key,
                value=number,
            )
        except KeyError:
            return False
        return True

    found_step_count = 0
    while _step_exists(found_step_count + 1):
        found_step_count += 1

    # validate number of steps found with steps read by YAML loader
    # TODO YAML safeloader will ignore duplicate top level keys, so this check relies on unique step numbers in ascending order (no gaps) to be found.
    if expected_step_count != found_step_count:
        raise ValueError(
            f"Invalid .yml file, {found_step_count} steps were found but the input .yml should have {expected_step_count} steps from the general setting key '{yml_format.step_count_key}' within the '{yml_format.general_section_key}' section. Please verify that all step_name keys in the .yml have unique strings and that step numbers are continuously-increasing positive integers starting at 1."
        )

    return found_step_count

274 

275 

def step_settings(
    exp_settings: dict,
    step_number_key: str,
    step_number_val: int,
    yml_format: tbt.YMLFormatVersion,
) -> Tuple[str, dict]:
    """Return the user-defined name and the settings of one numbered step.

    Args:
        exp_settings: Experiment settings loaded from a .yml file.
        step_number_key: Key that holds each step's number.
        step_number_val: Number of the step to fetch.
        yml_format: Format descriptor whose ``step_section_key`` names the
            top-level section holding all steps.

    Returns:
        ``(step_name, settings)``, where ``settings`` is
        ``exp_settings[yml_format.step_section_key][step_name]``.

    Raises:
        KeyError: If no entry maps ``step_number_key`` to ``step_number_val``.
    """
    location = nested_dictionary_location(
        d=exp_settings,
        key=step_number_key,
        value=step_number_val,
    )
    # location[0] is the top-level section key; the step name is the key
    # nested directly inside it.
    # NOTE(review): location[0] is not checked against step_section_key.
    step_name = location[1]
    steps_section = exp_settings[yml_format.step_section_key]
    return step_name, steps_section[step_name]

294 

295 

def valid_microscope_connection(host: str, port: str) -> bool:
    """Check that a microscope connection can be made, then disconnect again.

    A fresh ``tbt.Microscope`` is connected to ``host``/``port``. If that
    succeeds, it is immediately disconnected. Printout is suppressed both
    times.

    Returns:
        True when both connecting and disconnecting succeed.

    NOTE(review): ``connect_microscope`` and ``disconnect_microscope`` raise
    ``ConnectionError`` on failure rather than returning False. A failed
    attempt therefore surfaces as that exception, and the False return is
    effectively unreachable. ``port`` is annotated ``str`` here but ``int``
    in ``connect_microscope``. Confirm caller expectations before changing
    either.
    """
    microscope = tbt.Microscope()
    connected = connect_microscope(
        microscope=microscope,
        quiet_output=True,
        connection_host=host,
        connection_port=port,
    )
    if not connected:
        return False
    return bool(disconnect_microscope(microscope=microscope, quiet_output=True))

311 

312 

def enable_external_device(oem: tbt.ExternalDeviceOEM) -> bool:
    """Decide whether External Device Control should be enabled.

    Args:
        oem: The external device manufacturer; must be a
            ``tbt.ExternalDeviceOEM`` member.

    Returns:
        True for any member other than ``ExternalDeviceOEM.NONE``, else False.

    Raises:
        NotImplementedError: If ``oem`` is not an ``ExternalDeviceOEM``.
    """
    if isinstance(oem, tbt.ExternalDeviceOEM):
        return oem != tbt.ExternalDeviceOEM.NONE
    raise NotImplementedError(
        f"Unsupported type of {type(oem)}, only 'ExternalDeviceOEM' types are supported."
    )

324 

325 

def valid_enum_entry(obj: Any, check_type: Enum) -> bool:
    """Determine whether ``obj`` equals the value of some member of an Enum class.

    Args:
        obj: Candidate value, e.g. a raw value read from a .yml file.
        check_type: The Enum class to check against.

    Returns:
        True if some member of ``check_type`` (aliases included) has a value
        equal to ``obj``, False otherwise. Unhashable candidates are compared
        too, instead of raising ``TypeError``.
    """
    # Scan the public ``__members__`` mapping instead of the private
    # ``_value2member_map_`` because:
    # - it works for unhashable ``obj``;
    # - it includes aliases;
    # - its contents do not depend on which Flag pseudo-members happen to
    #   have been created earlier at runtime.
    return any(obj == member.value for member in check_type.__members__.values())

329 

330 

def yml_format(version: float) -> tbt.YMLFormatVersion:
    """Return the ``tbt.YMLFormatVersion`` member matching a file version.

    Args:
        version: Version number, compared with each member's ``version``
            attribute.

    Returns:
        The first member whose ``version`` equals ``version``.

    Raises:
        NotImplementedError: If no member has that version.
    """
    for candidate in tbt.YMLFormatVersion:
        if candidate.version == version:
            return candidate
    raise NotImplementedError(
        f'Unsupported YML file version for version "{version}". Valid formats include: {[i.value for i in tbt.YMLFormatVersion]}'
    )

341 

342 

def yml_to_dict(
    *, yml_path_file: Path, version: float, required_keys: Tuple[str, ...]
) -> Dict:
    """Given a valid Path to a yml input file, read it in and return the
    result as a dictionary.

    Args:
        yml_path_file: The fully pathed location to the input file.
        version: The version of the yml in x.y format. It must equal the
            file's ``config_file_version`` entry.
        required_keys: The top-level key(s) that must be in the yml file for
            conversion to a dictionary to occur.

    Returns:
        The .yml file represented as a dictionary.

    Raises:
        TypeError: If the file suffix is not .yaml or .yml (case-insensitive).
        OSError: If the file content cannot be parsed as YAML.
        KeyError: If the file does not hold a top-level mapping containing
            every key in ``required_keys``, or it lacks ``config_file_version``.
        ValueError: If the file's ``config_file_version`` differs from
            ``version``.
    """
    # casefold() is a stronger lower(): it normalizes more characters, so
    # suffixes such as ".YML" compare equal to ".yml".
    file_type = yml_path_file.suffix.casefold()

    supported_types = (".yaml", ".yml")
    if file_type not in supported_types:
        raise TypeError("Only file types .yaml, and .yml are supported.")

    try:
        with open(file=yml_path_file, mode="r", encoding="utf-8") as stream:
            # SafeLoader instead of plain yaml.load(input), see
            # https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
            db = yaml.load(stream, Loader=yaml.SafeLoader)
    except yaml.YAMLError as error:
        print(f"Error with YAML file: {error}")
        print(f"Could not open or decode: {yml_path_file}")
        # Keep raising OSError for existing callers, but with a message.
        raise OSError(f"Could not open or decode: {yml_path_file}") from error

    # An empty file loads as None and a top-level list/scalar has no keys;
    # report both as missing required keys rather than an AttributeError.
    if not isinstance(db, dict) or any(key not in db for key in required_keys):
        raise KeyError(f"Input files must have these keys defined: {required_keys}")

    version_specified = db["config_file_version"]
    if version_specified != version:
        raise ValueError(
            f"Version mismatch: specified in file was {version_specified}, "
            f"requested is {version}"
        )

    return db

398 

399 

def yml_version(
    file: Path,
    key_name="config_file_version",
) -> float:
    """Return the version recorded in a yml file under ``key_name``.

    Args:
        file: Path to the yml file.
        key_name: Top-level key holding the version.

    Returns:
        The version converted to ``float``.

    Raises:
        KeyError: If the file does not hold a top-level mapping containing
            ``key_name``. This includes an empty file.
        ValueError: If the stored version cannot be converted to ``float``.
    """
    # Read as UTF-8, matching yml_to_dict, rather than the platform default.
    with open(file, "r", encoding="utf-8") as stream:
        data = yaml.load(stream, Loader=yaml.SafeLoader)

    if not isinstance(data, dict) or key_name not in data:
        raise KeyError(f"Error with version key, '{key_name}' key not found in {file}.")
    version = data[key_name]

    try:
        return float(version)
    except (TypeError, ValueError) as error:
        # TypeError covers non-numeric YAML values such as null or lists.
        raise ValueError(
            f"Could not find valid version in {file} for key {key_name}, found '{version}' which is not a float."
        ) from error

420 

421 

def yes_no(question):
    """Ask a yes/no question on the console until a valid answer is given.

    Answers are stripped and lower-cased, so ``" Y "`` counts as yes. Any
    other answer prints a message and asks again.

    Args:
        question: Text shown before the ``(y/n): `` suffix.

    Returns:
        True for "y", False for "n".
    """
    prompt = f"{question} (y/n): "
    # Loop rather than recurse, so repeated invalid answers cannot exhaust
    # the recursion limit.
    while True:
        ans = input(prompt).strip().lower()
        if ans in ("y", "n"):
            return ans == "y"
        print(f"{ans} is invalid, please try again...")

432 

433 

def remove_directory(directory: Path):
    """Delete ``directory`` together with everything inside it.

    Errors raised by ``shutil.rmtree`` (for example for a missing directory)
    propagate to the caller.
    """
    shutil.rmtree(path=directory)

437 

438 

def split_list(data: List, chunk_size: int) -> List:
    """Split a sequence into consecutive chunks of at most ``chunk_size`` items.

    The last chunk holds the remainder and may be shorter. Each chunk is a
    slice of ``data``, so a list yields lists and a string yields strings.

    Args:
        data: Sliceable sequence to split.
        chunk_size: Maximum number of items per chunk; must be positive.

    Returns:
        The chunks in order; an empty list when ``data`` is empty.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Without this check, a
            negative size would silently produce no chunks and drop all data.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")
    return [data[start : start + chunk_size] for start in range(0, len(data), chunk_size)]

445 

446 

def tabular_list(
    data: List,
    num_columns: int = Constants.default_column_count,
    column_width: int = Constants.default_column_width,
) -> str:
    """Lay out ``data`` as a text table with ``num_columns`` items per row.

    Every row is preceded by a newline character. Each item is centered in a
    field ``column_width`` characters wide. Empty ``data`` gives an empty
    string.
    """
    rows = split_list(data, chunk_size=num_columns)
    return "".join(
        "\n" + "".join(f"{cell:^{column_width}}" for cell in row) for row in rows
    )

459 

460 

461### Custom Decorators ### 

462 

463 

def hardware_movement(func):
    """Decorate a test so it runs only when hardware-movement testing is enabled.

    The test is skipped (via ``pytest.skip``) unless both hold:
    - it runs on a microscope machine (checked by ``run_on_microscope_machine``);
    - ``Constants.test_hardware_movement`` is truthy.

    Arguments and the return value of ``func`` are passed through, and
    ``functools.wraps`` preserves its name, docstring and signature.
    """
    import functools

    @run_on_microscope_machine
    @functools.wraps(func)
    def wrapper_func(*args, **kwargs):
        if not Constants.test_hardware_movement:
            pytest.skip("Run only when hardware testing is enabled")
        return func(*args, **kwargs)

    return wrapper_func

472 

473 

def run_on_standalone_machine(func):
    """Decorate a test so it runs only on an offline-license machine.

    When the decorated function is called, the current host name
    (``platform.uname().node``) is compared case-insensitively with
    ``Constants().offline_machines``. On any other host the test is skipped
    via ``pytest.skip``.

    Arguments and the return value of ``func`` are passed through, and
    ``functools.wraps`` preserves its name, docstring and signature.
    """
    import functools

    @functools.wraps(func)
    def wrapper_func(*args, **kwargs):
        current_machine = platform.uname().node.lower()
        test_machines = {machine.lower() for machine in Constants().offline_machines}
        if current_machine not in test_machines:
            pytest.skip("Run on Offline License Machine Only.")
        return func(*args, **kwargs)

    return wrapper_func

483 

484 

def run_on_microscope_machine(func):
    """Decorate a test so it runs only on a microscope machine.

    When the decorated function is called, the current host name
    (``platform.uname().node``) is compared case-insensitively with
    ``Constants().microscope_machines``. On any other host the test is
    skipped via ``pytest.skip``.

    Arguments and the return value of ``func`` are passed through, and
    ``functools.wraps`` preserves its name, docstring and signature.
    """
    import functools

    @functools.wraps(func)
    def wrapper_func(*args, **kwargs):
        current_machine = platform.uname().node.lower()
        test_machines = {machine.lower() for machine in Constants().microscope_machines}
        if current_machine not in test_machines:
            pytest.skip("Run on Microscope Machine Only.")
        return func(*args, **kwargs)

    return wrapper_func