Coverage for src/pytribeam/GUI/config_ui/pipeline_model.py: 0%

271 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-06-16 18:30 +0000

1"""Pipeline configuration data model. 

2 

3This module provides data structures for representing and manipulating 

4pipeline configurations, separating data concerns from UI concerns. 

5""" 

6 

7from copy import deepcopy 

8from dataclasses import dataclass, field 

9from pathlib import Path 

10from typing import Any, Optional, Dict, List, Tuple 

11 

12import pytribeam.GUI.config_ui.lookup as lut 

13import pytribeam.utilities as ut 

14 

15 

16def _check_value_type(value: Any, dtype: type) -> Any: 

17 """Convert value to correct type based on dtype. 

18 

19 Handles conversion from string representations to proper Python types. 

20 Converts '', 'null', 'None' to None, 'True'/'true' to True, etc. 

21 

22 Args: 

23 value: Value to convert (usually a string) 

24 dtype: Target data type 

25 

26 Returns: 

27 Converted value with correct type 

28 """ 

29 # Handle None values 

30 if value in ["", "null", "None", None]: 

31 return None 

32 

33 # If value is already the correct type, return as-is (especially for booleans) 

34 if dtype is not None and isinstance(value, dtype): 

35 return value 

36 

37 if isinstance(value, str): 

38 value = value.strip() 

39 

40 # If no dtype specified, return as-is 

41 if dtype is None: 

42 return value 

43 

44 # Handle boolean strings 

45 if value in ["True", "true"]: 

46 return True 

47 elif value in ["False", "false"]: 

48 return False 

49 

50 # Convert to target type 

51 try: 

52 return dtype(value) 

53 except (ValueError, TypeError): 

54 # If conversion fails, return original value 

55 return value 

56 

57 

def _apply_type_conversion(
    params: Dict[str, Any], step_type: str, version: float
) -> Dict[str, Any]:
    """Return a copy of ``params`` with LUT-known values coerced to their dtype.

    Each key that appears in the flattened LUT for ``step_type``/``version``
    has its value passed through ``_check_value_type`` with the LUT entry's
    ``dtype``. Keys the LUT does not define are carried over unchanged.

    Args:
        params: Flattened parameter dictionary (keys joined with "/")
        step_type: Step type (e.g., "general", "image", "fib"); lower-cased
            before the LUT lookup
        version: Configuration version used to select the LUT

    Returns:
        New dictionary with converted values. If the LUT cannot be loaded or
        flattened, the input ``params`` object itself is returned untouched.
    """
    try:
        step_lut = lut.get_lut(step_type.lower(), version)
        step_lut.flatten()
    except Exception:
        # No usable LUT for this step type/version: leave values as they are.
        return params

    def _convert(key: str, value: Any) -> Any:
        # Parameters unknown to the LUT pass through; callers filter them out.
        if key not in step_lut.keys():
            return value
        return _check_value_type(value, step_lut[key].dtype)

    return {key: _convert(key, value) for key, value in params.items()}

90 

91 

@dataclass
class StepConfig:
    """Configuration for a single pipeline step.

    Represents one step in the experiment pipeline with all its parameters.
    Parameters are stored in a flattened dictionary with '/' separators.

    Attributes:
        index: Position in pipeline (0 for general, 1+ for steps)
        step_type: Type of step (e.g., 'image', 'fib', 'laser')
        name: Unique name for this step; kept in sync with the
            'step_general/step_name' parameter when that parameter is set
        parameters: Flattened dictionary of step parameters
        version: Configuration file version (defaults to the last entry of
            ``lut.VERSIONS``)
    """

    index: int
    step_type: str
    name: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    version: float = field(default=float(lut.VERSIONS[-1]))

    def get_param(self, path: str, default: Any = None) -> Any:
        """Get parameter value by path.

        Args:
            path: Parameter path with '/' separators (e.g., 'beam/voltage_kv')
            default: Value to return if parameter not found

        Returns:
            Parameter value or default
        """
        return self.parameters.get(path, default)

    def set_param(self, path: str, value: Any):
        """Set parameter value by path.

        Setting 'step_general/step_name' also updates ``self.name``.

        Args:
            path: Parameter path with '/' separators
            value: Value to set
        """
        self.parameters[path] = value
        if path == "step_general/step_name":
            # Keep the ``name`` attribute consistent with the stored parameter.
            self.name = str(value)

    def has_param(self, path: str) -> bool:
        """Check if parameter exists.

        Args:
            path: Parameter path with '/' separators

        Returns:
            True if parameter exists
        """
        return path in self.parameters

    def get_all_params(self, flat: bool = True) -> Dict[str, Any]:
        """Get all parameters with LUT-based type conversion applied.

        The stored parameters are deep-copied before conversion, so mutating
        the result does not affect this step.

        Args:
            flat: If True, return the flattened '/'-separated dictionary;
                otherwise return it unflattened into nested dictionaries.

        Returns:
            Type-converted copy of the parameters dictionary
        """
        db = _apply_type_conversion(
            deepcopy(self.parameters), self.step_type, self.version
        )
        if flat:
            return db
        return unflatten_dict(db, sep="/")

    def update_params(self, params: Dict[str, Any]):
        """Update multiple parameters at once.

        Each entry goes through ``set_param`` so that updating
        'step_general/step_name' keeps ``self.name`` in sync, exactly as a
        single ``set_param`` call would.

        Args:
            params: Dictionary of parameters to update
        """
        for path, value in params.items():
            self.set_param(path, value)

    def clear_params(self):
        """Remove all parameters (``name`` and other attributes are kept)."""
        self.parameters.clear()

    def __repr__(self) -> str:
        return (
            f"StepConfig(index={self.index}, type={self.step_type}, "
            f"name={self.name}, version={self.version})"
        )

175 

176 

@dataclass
class PipelineConfig:
    """Complete pipeline configuration.

    Represents the entire experiment configuration including general settings
    and all pipeline steps.

    Step indices are 1-based and kept equal to each step's position in
    ``steps`` (``steps[i].index == i + 1``). The mutating methods maintain
    this invariant, and ``from_yaml`` establishes it when loading.

    Attributes:
        version: Configuration file version
        general: General configuration step
        steps: List of pipeline steps
        file_path: Path to configuration file (if loaded from or saved to file)
    """

    version: float
    general: StepConfig
    steps: List[StepConfig] = field(default_factory=list)
    file_path: Optional[Path] = None

    @classmethod
    def create_new(cls, version: Optional[float] = None) -> "PipelineConfig":
        """Create new empty pipeline configuration.

        Initializes the general step with all parameters from the LUT at their
        default values, and sets its version to ``version``.

        Args:
            version: Config file version (uses latest if not specified)

        Returns:
            New PipelineConfig with only the general step
        """
        if version is None:
            version = float(lut.VERSIONS[-1])

        pipeline = cls(
            version=version,
            general=StepConfig(
                index=0,
                step_type="general",
                name="general",
                parameters={},
                version=version,
            ),
            steps=[],
        )

        # Defaults come from the general LUT for this pipeline's version.
        general_params = pipeline._populate_default_parameters("general")
        # A new pipeline has no steps yet.
        general_params["step_count"] = "0"
        pipeline.general.parameters = general_params

        return pipeline

    def _update_step_count(self):
        """Update step_count parameter in general step to reflect current step count."""
        self.general.set_param("step_count", str(len(self.steps)))

    def _is_valid_step_index(self, index: int) -> bool:
        """Return True if ``index`` (1-based) refers to an existing pipeline step."""
        return 1 <= index <= len(self.steps)

    def _reindex_steps(self):
        """Renumber steps 1..n in list order and sync their step_number params."""
        for position, step in enumerate(self.steps, 1):
            step.index = position
            step.set_param("step_general/step_number", str(position))

    def _generate_step_name(self, step_type: str) -> str:
        """Return an unused step name of the form ``<step_type>_<n>``.

        ``n`` starts at one more than the number of existing steps of this
        type. It is then incremented until the name does not clash with any
        existing step name, because ``to_dict`` keys steps by name and a
        clash would drop a step on export.
        """
        taken = {s.name for s in self.steps}
        suffix = sum(1 for s in self.steps if s.step_type == step_type) + 1
        while f"{step_type}_{suffix}" in taken:
            suffix += 1
        return f"{step_type}_{suffix}"

    @staticmethod
    def _default_param_value(default: Any) -> Any:
        """Convert a LUT default to the stored parameter representation.

        None becomes "", booleans stay bool, and everything else becomes str.
        """
        if default is None:
            return ""
        if isinstance(default, bool):
            return default
        return str(default)

    def _populate_default_parameters(self, step_type: str) -> Dict[str, Any]:
        """Populate parameters with default values from LUT.

        Args:
            step_type: Type of step (e.g., 'general', 'image', 'fib')

        Returns:
            Dictionary of parameter paths to default values (booleans kept as
            bool, other values as str). An empty dict is returned (and a
            warning printed) if the LUT lookup fails.
        """
        try:
            # Get LUT for this step type and version
            step_lut = lut.get_lut(step_type.lower(), self.version)
            # Flatten a copy so the object returned by get_lut is not modified.
            step_lut_flat = deepcopy(step_lut)
            step_lut_flat.flatten()

            return {
                key: self._default_param_value(lut_field.default)
                for key, lut_field in step_lut_flat.items()
            }
        except Exception as e:
            # If LUT lookup fails, return empty dict
            print(f"Warning: Failed to get LUT defaults for {step_type}: {e}")
            return {}

    def add_step(self, step_type: str, name: Optional[str] = None) -> StepConfig:
        """Add new step to the end of the pipeline.

        Initializes step with all parameters from LUT with default values.

        Args:
            step_type: Type of step to add (e.g., 'image', 'fib')
            name: Optional custom name. If not provided, a name that does not
                clash with existing steps is generated.

        Returns:
            Newly created StepConfig
        """
        index = len(self.steps) + 1

        if name is None:
            name = self._generate_step_name(step_type)

        # Get all default parameters from LUT
        parameters = self._populate_default_parameters(step_type)

        # Override with step-specific values
        parameters.update(
            {
                "step_general/step_type": step_type,
                "step_general/step_name": name,
                "step_general/step_number": str(index),
            }
        )

        step = StepConfig(
            index=index,
            step_type=step_type,
            name=name,
            parameters=parameters,
            version=self.version,
        )

        self.steps.append(step)
        self._update_step_count()

        return step

    def remove_step(self, index: int) -> bool:
        """Remove step at specified index.

        Re-indexes remaining steps to maintain sequential ordering.

        Args:
            index: Index of step to remove (1-based; 0 is general and cannot
                be removed)

        Returns:
            True if step was removed, False if index invalid
        """
        if not self._is_valid_step_index(index):
            return False

        # Remove by position so the result does not depend on stored indices.
        del self.steps[index - 1]

        self._reindex_steps()
        self._update_step_count()

        return True

    def move_step(self, index: int, direction: int) -> bool:
        """Move step up or down in pipeline.

        Args:
            index: Index of step to move (1-based)
            direction: -1 to move up, +1 to move down

        Returns:
            True if step was moved, False if move invalid
        """
        if not self._is_valid_step_index(index):
            return False

        new_index = index + direction
        if not self._is_valid_step_index(new_index):
            return False

        # Swap steps in list (convert to 0-based indices)
        idx1, idx2 = index - 1, new_index - 1
        self.steps[idx1], self.steps[idx2] = self.steps[idx2], self.steps[idx1]

        # Update indices
        self.steps[idx1].index = index
        self.steps[idx2].index = new_index
        self.steps[idx1].set_param("step_general/step_number", str(index))
        self.steps[idx2].set_param("step_general/step_number", str(new_index))

        return True

    def duplicate_step(self, index: int) -> Optional[StepConfig]:
        """Duplicate an existing step.

        Creates a deep copy of the step and appends it to the end of the
        pipeline with an auto-generated, non-clashing name.

        Args:
            index: Index of step to duplicate (1-based)

        Returns:
            Newly created StepConfig or None if index invalid
        """
        if not self._is_valid_step_index(index):
            return None

        original = self.steps[index - 1]
        new_step = deepcopy(original)
        new_step.index = len(self.steps) + 1

        new_step.name = self._generate_step_name(original.step_type)
        new_step.set_param("step_general/step_name", new_step.name)
        new_step.set_param("step_general/step_number", str(new_step.index))

        self.steps.append(new_step)
        self._update_step_count()

        return new_step

    def get_step(self, index: int) -> Optional[StepConfig]:
        """Get step by index.

        Args:
            index: Step index (0 for general, 1+ for pipeline steps)

        Returns:
            StepConfig or None if index invalid
        """
        if index == 0:
            return self.general
        if self._is_valid_step_index(index):
            return self.steps[index - 1]
        return None

    def get_step_by_name(self, name: str) -> Optional[StepConfig]:
        """Get step by name.

        Args:
            name: Step name to search for ("general" returns the general step)

        Returns:
            First StepConfig with that name, or None if not found
        """
        if name == "general":
            return self.general
        return next((step for step in self.steps if step.name == name), None)

    def get_step_count(self) -> int:
        """Get number of steps (excluding general).

        Returns:
            Number of pipeline steps
        """
        return len(self.steps)

    def validate_step_names(self) -> Tuple[bool, List[str]]:
        """Check for duplicate step names.

        Returns:
            Tuple of (is_valid, list_of_duplicate_names). Duplicates are
            listed once each, in order of first appearance.
        """
        counts: Dict[str, int] = {}
        for step in self.steps:
            counts[step.name] = counts.get(step.name, 0) + 1
        duplicates = [name for name, count in counts.items() if count > 1]
        return len(duplicates) == 0, duplicates

    @classmethod
    def from_yaml(cls, yaml_path: Path) -> "PipelineConfig":
        """Load pipeline configuration from YAML file.

        Steps are ordered by their ``step_general/step_number`` and then
        renumbered 1..n, so step indices always match list positions.

        Args:
            yaml_path: Path to YAML configuration file (str is also accepted)

        Returns:
            PipelineConfig instance

        Raises:
            FileNotFoundError: If YAML file doesn't exist
            ValueError: If YAML is invalid or missing required fields
        """
        yaml_path = Path(yaml_path)
        if not yaml_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {yaml_path}")

        # Read YAML file
        yml_version = ut.yml_version(yaml_path)
        db = ut.yml_to_dict(
            yml_path_file=yaml_path,
            version=yml_version,
            required_keys=("general", "config_file_version"),
        )
        version = float(yml_version)

        # Extract general settings (copied so ``db`` is not mutated)
        general_dict = dict(db["general"] or {})
        general_dict["step_type"] = "general"
        general_flat = flatten_dict(general_dict, sep="/")

        general = StepConfig(
            index=0,
            step_type="general",
            name="general",
            parameters={k: _value_to_string(v) for k, v in general_flat.items()},
            version=version,
        )

        # Extract steps; an empty ``steps:`` entry loads as None.
        steps_dict = db["steps"] if "steps" in db else None
        steps_list = []
        for step_name, step_data in (steps_dict or {}).items():
            step_type = step_data["step_general"]["step_type"]
            step_number = step_data["step_general"]["step_number"]

            # Flatten and convert to strings
            flat_step = flatten_dict(step_data, sep="/")
            flat_step["step_general/step_name"] = step_name

            steps_list.append(
                StepConfig(
                    index=step_number,
                    step_type=step_type,
                    name=step_name,
                    parameters={
                        k: _value_to_string(v) for k, v in flat_step.items()
                    },
                    version=version,
                )
            )

        # Sort by step number (stable, so ties keep file order)
        steps_list.sort(key=lambda s: s.index)

        pipeline = cls(
            version=version,
            general=general,
            steps=steps_list,
            file_path=yaml_path,
        )

        # Make indices match positions (no-op for files numbered 1..n), and
        # ensure step_count in general is accurate.
        pipeline._reindex_steps()
        pipeline._update_step_count()

        return pipeline

    def to_yaml(self, yaml_path: Path):
        """Save pipeline configuration to YAML file.

        Also records ``yaml_path`` as this pipeline's ``file_path``.

        Args:
            yaml_path: Path where YAML should be saved
        """
        config_dict = self.to_dict()
        ut.dict_to_yml(config_dict, str(yaml_path))
        self.file_path = yaml_path

    def to_dict(self) -> Dict:
        """Convert pipeline to dictionary suitable for YAML export.

        Applies type conversion based on LUT to ensure parameters have correct
        types (int, float, str, bool) instead of all being strings. Parameters
        not defined by the relevant LUT are dropped. The exceptions are each
        step's ``step_general/step_number``, which is always taken from the
        step's index, and its name, which becomes the key in "steps".

        Returns:
            Dictionary with config_file_version, general, and steps
        """
        # General section: drop the internal "step_type" marker, convert
        # types, then keep only keys the general LUT defines.
        general_params = {
            key: value
            for key, value in self.general.parameters.items()
            if key != "step_type"
        }
        general_params = _apply_type_conversion(general_params, "general", self.version)

        general_lut = lut.get_lut("general", self.version)
        general_lut.flatten()
        general_keys = general_lut.keys()
        general_params = {
            k: v for k, v in general_params.items() if k in general_keys
        }

        general_dict = unflatten_dict(general_params, sep="/")

        steps_dict = {}
        for step in self.steps:
            # The name becomes the dict key and step_number is re-derived from
            # step.index, so the stored copies are not exported.
            step_params = {
                key: value
                for key, value in step.parameters.items()
                if key not in ("step_general/step_name", "step_general/step_number")
            }
            step_params["step_general/step_number"] = step.index

            # Apply type conversion based on LUT
            step_params = _apply_type_conversion(
                step_params, step.step_type, self.version
            )

            # Remove any parameters not in LUT
            step_lut = lut.get_lut(step.step_type.lower(), self.version)
            step_lut.flatten()
            step_keys = step_lut.keys()
            step_params = {
                k: v
                for k, v in step_params.items()
                if k in step_keys or k == "step_general/step_number"
            }

            steps_dict[step.name] = unflatten_dict(step_params, sep="/")

        return {
            "config_file_version": self.version,
            "general": general_dict,
            "steps": steps_dict,
        }

    def set_version(self, new_version: float):
        """Update pipeline version and migrate all parameters to new version.

        This method:
        1. Updates version for pipeline and all steps
        2. Adds new parameters introduced in new version (with defaults)
        3. Removes parameters that don't exist in new version
        4. Preserves existing parameter values where applicable
        5. Re-sets the general step_count to the current number of steps

        Args:
            new_version: Target version to migrate to
        """
        if new_version == self.version:
            return  # No change needed

        self._migrate_step_to_version(self.general, new_version)
        for step in self.steps:
            self._migrate_step_to_version(step, new_version)

        self.version = new_version

        # Migration rebuilds general's parameters from the LUT; make sure the
        # step count still reflects the pipeline.
        self._update_step_count()

    def _migrate_step_to_version(self, step: StepConfig, new_version: float):
        """Migrate a single step to a new version.

        Parameters defined by the new LUT keep their current value if present,
        and otherwise get the LUT default (booleans stay bool, other values
        become str). Parameters the new LUT does not define are dropped. The
        exception is the step_general type/name/number keys this class
        maintains, which are carried over when present.

        If the new LUT cannot be loaded, a warning is printed and only
        ``step.version`` is updated.

        Args:
            step: Step to migrate
            new_version: Target version
        """
        step_type = step.step_type

        try:
            new_lut = lut.get_lut(step_type.lower(), new_version)
            new_lut.flatten()
        except Exception:
            # If new LUT doesn't exist, can't migrate
            print(f"Warning: Cannot get LUT for {step_type} version {new_version}")
            step.version = new_version
            return

        current_params = deepcopy(step.parameters)

        new_params = {}
        for param_key, lut_field in new_lut.items():
            if param_key in current_params:
                # Parameter exists in both - keep current value
                new_params[param_key] = current_params[param_key]
            else:
                # New parameter - use default from LUT
                new_params[param_key] = self._default_param_value(lut_field.default)

        # Bookkeeping keys set by add_step/duplicate_step/move_step are kept
        # even if the LUT does not list them.
        for key in (
            "step_general/step_type",
            "step_general/step_name",
            "step_general/step_number",
        ):
            if key in current_params and key not in new_params:
                new_params[key] = current_params[key]

        step.parameters = new_params
        step.version = new_version

    def __repr__(self) -> str:
        return f"PipelineConfig(version={self.version}, steps={len(self.steps)})"

687 

688 

689def _value_to_string(value: Any) -> str: 

690 """Convert value to string, handling None/null as empty string. 

691 

692 Special handling: 

693 - None/null → empty string 

694 - Booleans → keep as boolean (not converted to string) 

695 - Everything else → string 

696 

697 Args: 

698 value: Value to convert 

699 

700 Returns: 

701 String representation (empty string for None, bool for booleans) 

702 """ 

703 if value is None or value == "null": 

704 return "" 

705 # Preserve boolean type for checkbuttons 

706 if isinstance(value, bool): 

707 return value 

708 return str(value) 

709 

710 

def flatten_dict(d: Dict, parent_key: str = "", sep: str = "/") -> Dict:
    """Collapse a nested dictionary into one level, joining keys with ``sep``.

    Nested dictionaries are walked depth-first in insertion order. Non-dict
    values become leaves keyed by the joined path. When two paths produce the
    same joined key, the one visited later wins.

    Args:
        d: Dictionary to flatten
        parent_key: Prefix prepended (with ``sep``) to every key; empty means
            no prefix
        sep: Separator placed between key segments

    Returns:
        New flat dictionary

    Example:
        {'a': {'b': 1}} -> {'a/b': 1}
    """
    flat: Dict = {}
    for key, val in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if not isinstance(val, dict):
            flat[full_key] = val
            continue
        flat.update(flatten_dict(val, full_key, sep=sep))
    return flat

733 

734 

def unflatten_dict(d: Dict, sep: str = "/") -> Dict:
    """Expand ``sep``-joined keys back into nested dictionaries.

    Every key is split on ``sep``. All segments except the last name
    intermediate dictionaries, which are created on demand. The last segment
    receives the value.

    Args:
        d: Flattened dictionary (string keys)
        sep: Separator used in keys

    Returns:
        New nested dictionary

    Example:
        {'a/b': 1} -> {'a': {'b': 1}}
    """
    nested: Dict = {}
    for flat_key, leaf in d.items():
        *branch, leaf_name = flat_key.split(sep)
        node = nested
        for segment in branch:
            node = node.setdefault(segment, {})
        node[leaf_name] = leaf
    return nested