Coverage for src/pytribeam/GUI/config_ui/pipeline_model.py: 0%
271 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-06-16 18:30 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-06-16 18:30 +0000
1"""Pipeline configuration data model.
3This module provides data structures for representing and manipulating
4pipeline configurations, separating data concerns from UI concerns.
5"""
7from copy import deepcopy
8from dataclasses import dataclass, field
9from pathlib import Path
10from typing import Any, Optional, Dict, List, Tuple
12import pytribeam.GUI.config_ui.lookup as lut
13import pytribeam.utilities as ut
16def _check_value_type(value: Any, dtype: type) -> Any:
17 """Convert value to correct type based on dtype.
19 Handles conversion from string representations to proper Python types.
20 Converts '', 'null', 'None' to None, 'True'/'true' to True, etc.
22 Args:
23 value: Value to convert (usually a string)
24 dtype: Target data type
26 Returns:
27 Converted value with correct type
28 """
29 # Handle None values
30 if value in ["", "null", "None", None]:
31 return None
33 # If value is already the correct type, return as-is (especially for booleans)
34 if dtype is not None and isinstance(value, dtype):
35 return value
37 if isinstance(value, str):
38 value = value.strip()
40 # If no dtype specified, return as-is
41 if dtype is None:
42 return value
44 # Handle boolean strings
45 if value in ["True", "true"]:
46 return True
47 elif value in ["False", "false"]:
48 return False
50 # Convert to target type
51 try:
52 return dtype(value)
53 except (ValueError, TypeError):
54 # If conversion fails, return original value
55 return value
def _apply_type_conversion(
    params: Dict[str, Any], step_type: str, version: float
) -> Dict[str, Any]:
    """Coerce each parameter value to the dtype declared in the step's LUT.

    Args:
        params: Flattened parameter dictionary (keys use "/" separators)
        step_type: Step type (e.g., "general", "image", "fib"); lower-cased
            before the LUT lookup
        version: Configuration version used to select the LUT

    Returns:
        A new dictionary with the same keys. Values whose key is in the LUT
        are passed through ``_check_value_type`` with that entry's dtype;
        other values are carried over unchanged (callers filter them later).
        If the LUT cannot be obtained or flattened, ``params`` itself is
        returned untouched.
    """
    try:
        table = lut.get_lut(step_type.lower(), version)
        table.flatten()
    except Exception:
        # Best effort: without a LUT there is nothing to convert against.
        return params

    known_keys = table.keys()
    return {
        key: _check_value_type(val, table[key].dtype) if key in known_keys else val
        for key, val in params.items()
    }
@dataclass
class StepConfig:
    """Configuration for a single pipeline step.

    Represents one step in the experiment pipeline with all its parameters.
    Parameters are stored in a flattened dictionary with '/' separators.
    ``set_param`` and ``update_params`` keep ``name`` in sync with the
    ``step_general/step_name`` parameter.

    Attributes:
        index: Position in pipeline (0 for general, 1+ for steps)
        step_type: Type of step (e.g., 'image', 'fib', 'laser')
        name: Unique name for this step
        parameters: Flattened dictionary of step parameters
        version: Configuration file version (defaults to the last entry of
            ``lut.VERSIONS``)
    """

    index: int
    step_type: str
    name: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    version: float = field(default=float(lut.VERSIONS[-1]))

    def get_param(self, path: str, default: Any = None) -> Any:
        """Get parameter value by path.

        Args:
            path: Parameter path with '/' separators (e.g., 'beam/voltage_kv')
            default: Value to return if parameter not found

        Returns:
            Parameter value or default
        """
        return self.parameters.get(path, default)

    def set_param(self, path: str, value: Any):
        """Set parameter value by path.

        Setting ``step_general/step_name`` also updates ``name``.

        Args:
            path: Parameter path with '/' separators
            value: Value to set
        """
        self.parameters[path] = value
        if path == "step_general/step_name":
            self.name = str(value)

    def has_param(self, path: str) -> bool:
        """Check if parameter exists.

        Args:
            path: Parameter path with '/' separators

        Returns:
            True if parameter exists
        """
        return path in self.parameters

    def get_all_params(self, flat: bool = True) -> Dict[str, Any]:
        """Get all parameters, type-converted according to the step's LUT.

        Args:
            flat: If True (default) return the flattened '/'-separated
                dictionary; otherwise return it nested via ``unflatten_dict``.

        Returns:
            A deep copy of ``parameters`` with values converted by
            ``_apply_type_conversion``; the stored parameters are not modified.
        """
        db = _apply_type_conversion(
            deepcopy(self.parameters), self.step_type, self.version
        )
        return db if flat else unflatten_dict(db, sep="/")

    def update_params(self, params: Dict[str, Any]):
        """Update multiple parameters at once.

        Like ``set_param``, an updated ``step_general/step_name`` also
        updates ``name`` so the two never disagree.

        Args:
            params: Dictionary (or iterable of key/value pairs) of parameters
                to update
        """
        # Materialize once so iterables of pairs are handled like dict.update
        # and can still be inspected afterwards.
        updates = dict(params)
        self.parameters.update(updates)
        if "step_general/step_name" in updates:
            self.name = str(updates["step_general/step_name"])

    def clear_params(self):
        """Remove all parameters."""
        self.parameters.clear()

    def __repr__(self) -> str:
        return (
            f"StepConfig(index={self.index}, type={self.step_type}, "
            f"name={self.name}, version={self.version})"
        )
@dataclass
class PipelineConfig:
    """Complete pipeline configuration.

    Represents the entire experiment configuration including general settings
    and all pipeline steps. The methods that add, remove, reorder, duplicate
    or load steps keep every step's ``index`` (and its
    ``step_general/step_number`` parameter) equal to its 1-based position in
    ``steps``. They also keep the general ``step_count`` parameter equal to
    ``len(steps)``.

    Attributes:
        version: Configuration file version
        general: General configuration step
        steps: List of pipeline steps
        file_path: Path to configuration file (if loaded from file)
    """

    version: float
    general: StepConfig
    steps: List[StepConfig] = field(default_factory=list)
    file_path: Optional[Path] = None

    @classmethod
    def create_new(cls, version: Optional[float] = None) -> "PipelineConfig":
        """Create new empty pipeline configuration.

        Initializes the general step with all parameters from the general LUT
        at their default values.

        Args:
            version: Config file version (uses latest if not specified)

        Returns:
            New PipelineConfig with only the general step. Both the pipeline
            and its general step carry ``version``.
        """
        if version is None:
            version = float(lut.VERSIONS[-1])

        pipeline = cls(
            version=version,
            general=StepConfig(
                index=0,
                step_type="general",
                name="general",
                parameters={},
                version=version,
            ),
            steps=[],
        )

        # Defaults come from the general LUT of the requested version. The
        # general step is reused rather than rebuilt, so it keeps that version
        # instead of StepConfig's latest-version default.
        general_params = pipeline._populate_default_parameters("general")
        # Ensure step_count is set to 0
        general_params["step_count"] = "0"
        pipeline.general.parameters = general_params
        return pipeline

    @staticmethod
    def _default_param_value(default: Any) -> Any:
        """Convert a LUT default into the form stored in ``parameters``.

        None becomes "", booleans stay ``bool`` (checkbuttons use real
        booleans), and every other value becomes its ``str()`` form.
        """
        if default is None:
            return ""
        if isinstance(default, bool):
            return default
        return str(default)

    def _update_step_count(self):
        """Update step_count parameter in general step to reflect current step count."""
        self.general.set_param("step_count", str(len(self.steps)))

    def _is_step_index(self, index: int) -> bool:
        """Return True if ``index`` is a valid 1-based position in ``steps``."""
        return 1 <= index <= len(self.steps)

    def _renumber_steps(self):
        """Set every step's index and step_number parameter to its 1-based position."""
        for position, step in enumerate(self.steps, 1):
            step.index = position
            step.set_param("step_general/step_number", str(position))

    def _unique_step_name(self, step_type: str) -> str:
        """Return an auto-generated ``"<step_type>_<n>"`` name not used by any step.

        Starts at one more than the number of existing steps of that type and
        counts upward until the name is free. Removing or renaming steps
        therefore cannot yield a duplicate name, which ``to_dict`` would
        otherwise collapse.
        """
        taken = {s.name for s in self.steps}
        n = sum(1 for s in self.steps if s.step_type == step_type) + 1
        while f"{step_type}_{n}" in taken:
            n += 1
        return f"{step_type}_{n}"

    def _populate_default_parameters(self, step_type: str) -> Dict[str, Any]:
        """Populate parameters with default values from LUT.

        Args:
            step_type: Type of step (e.g., 'general', 'image', 'fib')

        Returns:
            Dictionary of parameter paths to default values (preserves
            booleans); empty if the LUT lookup fails.
        """
        try:
            # Flatten a deep copy so the object returned by get_lut is not
            # modified in place.
            step_lut_flat = deepcopy(lut.get_lut(step_type.lower(), self.version))
            step_lut_flat.flatten()
            return {
                key: self._default_param_value(entry.default)
                for key, entry in step_lut_flat.items()
            }
        except Exception as e:
            # If LUT lookup fails, return empty dict
            print(f"Warning: Failed to get LUT defaults for {step_type}: {e}")
            return {}

    def add_step(self, step_type: str, name: Optional[str] = None) -> StepConfig:
        """Add new step to pipeline.

        Initializes step with all parameters from LUT with default values.

        Args:
            step_type: Type of step to add (e.g., 'image', 'fib')
            name: Optional custom name (a unique name is auto-generated if
                not provided)

        Returns:
            Newly created StepConfig
        """
        index = len(self.steps) + 1

        if name is None:
            name = self._unique_step_name(step_type)

        # Get all default parameters from LUT, then override step identity
        parameters = self._populate_default_parameters(step_type)
        parameters.update(
            {
                "step_general/step_type": step_type,
                "step_general/step_name": name,
                "step_general/step_number": str(index),
            }
        )

        step = StepConfig(
            index=index,
            step_type=step_type,
            name=name,
            parameters=parameters,
            version=self.version,
        )
        self.steps.append(step)
        self._update_step_count()
        return step

    def remove_step(self, index: int) -> bool:
        """Remove step at specified index.

        Re-indexes remaining steps to maintain sequential ordering.

        Args:
            index: 1-based position of step to remove (0, the general step,
                and out-of-range values are rejected)

        Returns:
            True if step was removed, False if index invalid
        """
        if not self._is_step_index(index):
            return False

        # Remove by position, consistent with get_step/move_step/duplicate_step
        del self.steps[index - 1]
        self._renumber_steps()
        self._update_step_count()
        return True

    def move_step(self, index: int, direction: int) -> bool:
        """Move step up or down in pipeline.

        Args:
            index: Index of step to move (1-based)
            direction: -1 to move up, +1 to move down

        Returns:
            True if step was moved, False if move invalid
        """
        if not self._is_step_index(index):
            return False

        new_index = index + direction
        if not self._is_step_index(new_index):
            return False

        # Swap steps in list (convert to 0-based indices)
        idx1, idx2 = index - 1, new_index - 1
        self.steps[idx1], self.steps[idx2] = self.steps[idx2], self.steps[idx1]

        # Update indices
        self.steps[idx1].index = index
        self.steps[idx2].index = new_index
        self.steps[idx1].set_param("step_general/step_number", str(index))
        self.steps[idx2].set_param("step_general/step_number", str(new_index))
        return True

    def duplicate_step(self, index: int) -> Optional[StepConfig]:
        """Duplicate an existing step.

        Creates a deep copy of the step and appends it to the end of the
        pipeline with a unique auto-generated name.

        Args:
            index: 1-based index of step to duplicate

        Returns:
            Newly created StepConfig or None if index invalid
        """
        if not self._is_step_index(index):
            return None

        original = self.steps[index - 1]
        new_step = deepcopy(original)
        new_step.index = len(self.steps) + 1

        new_step.name = self._unique_step_name(original.step_type)
        new_step.set_param("step_general/step_name", new_step.name)
        new_step.set_param("step_general/step_number", str(new_step.index))

        self.steps.append(new_step)
        self._update_step_count()
        return new_step

    def get_step(self, index: int) -> Optional[StepConfig]:
        """Get step by index.

        Args:
            index: Step index (0 for general, 1+ for pipeline steps)

        Returns:
            StepConfig or None if index invalid
        """
        if index == 0:
            return self.general
        if self._is_step_index(index):
            return self.steps[index - 1]
        return None

    def get_step_by_name(self, name: str) -> Optional[StepConfig]:
        """Get step by name.

        Args:
            name: Step name to search for ("general" returns the general step)

        Returns:
            First StepConfig with that name, or None if not found
        """
        if name == "general":
            return self.general
        return next((step for step in self.steps if step.name == name), None)

    def get_step_count(self) -> int:
        """Get number of steps (excluding general).

        Returns:
            Number of pipeline steps
        """
        return len(self.steps)

    def validate_step_names(self) -> Tuple[bool, List[str]]:
        """Check for duplicate step names.

        Returns:
            Tuple of (is_valid, list_of_duplicate_names); duplicates are
            listed once each, in order of first appearance.
        """
        counts: Dict[str, int] = {}
        for step in self.steps:
            counts[step.name] = counts.get(step.name, 0) + 1
        duplicates = [name for name, count in counts.items() if count > 1]
        return not duplicates, duplicates

    @classmethod
    def from_yaml(cls, yaml_path: Path) -> "PipelineConfig":
        """Load pipeline configuration from YAML file.

        Steps are ordered by their ``step_general/step_number`` value and then
        renumbered 1..N, so each step's ``index`` matches its list position.
        Gaps or duplicates in the file's step numbers are normalized.

        Args:
            yaml_path: Path (or path string) to YAML configuration file

        Returns:
            PipelineConfig instance

        Raises:
            FileNotFoundError: If YAML file doesn't exist
            ValueError: If YAML is invalid or missing required fields
        """
        yaml_path = Path(yaml_path)
        if not yaml_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {yaml_path}")

        # Read YAML file
        yml_version = ut.yml_version(yaml_path)
        db = ut.yml_to_dict(
            yml_path_file=yaml_path,
            version=yml_version,
            required_keys=("general", "config_file_version"),
        )
        version = float(yml_version)

        # Extract general settings
        general_dict = db["general"]
        general_dict["step_type"] = "general"
        general_flat = flatten_dict(general_dict, sep="/")
        general = StepConfig(
            index=0,
            step_type="general",
            name="general",
            parameters={k: _value_to_string(v) for k, v in general_flat.items()},
            version=version,
        )

        # Extract steps. An empty "steps:" section may load as None, so a
        # missing and an empty section are treated alike.
        steps_dict = db["steps"] if "steps" in db else None
        steps_list = []
        for step_name, step_data in (steps_dict or {}).items():
            step_general = step_data["step_general"]
            step_type = step_general["step_type"]
            step_number = step_general["step_number"]

            # Flatten and convert to strings
            flat_step = flatten_dict(step_data, sep="/")
            flat_step["step_general/step_name"] = step_name

            steps_list.append(
                StepConfig(
                    index=step_number,
                    step_type=step_type,
                    name=step_name,
                    parameters={k: _value_to_string(v) for k, v in flat_step.items()},
                    version=version,
                )
            )

        # Sort by step number
        steps_list.sort(key=lambda s: s.index)

        pipeline = cls(
            version=version,
            general=general,
            steps=steps_list,
            file_path=yaml_path,
        )

        # Make index/step_number match list position so the positional
        # methods (get_step, remove_step, move_step, duplicate_step) address
        # the intended step.
        pipeline._renumber_steps()
        # Ensure step_count in general is accurate
        pipeline._update_step_count()
        return pipeline

    def to_yaml(self, yaml_path: Path):
        """Save pipeline configuration to YAML file.

        Args:
            yaml_path: Path where YAML should be saved (also stored in
                ``file_path``)
        """
        config_dict = self.to_dict()
        ut.dict_to_yml(config_dict, str(yaml_path))
        self.file_path = yaml_path

    def to_dict(self) -> Dict:
        """Convert pipeline to dictionary suitable for YAML export.

        Applies type conversion based on LUT so parameters have their declared
        types (int, float, str, bool) instead of all being strings. It also
        drops parameters the LUT does not define.

        Returns:
            Dictionary with config_file_version, general, and steps
        """
        # General step: drop the "step_type" key added by from_yaml, convert
        # types, then keep only keys the general LUT defines.
        general_params = {
            k: v for k, v in self.general.parameters.items() if k != "step_type"
        }
        general_params = _apply_type_conversion(general_params, "general", self.version)

        general_lut = lut.get_lut("general", self.version)
        general_lut.flatten()
        general_keys = general_lut.keys()
        general_params = {
            k: v for k, v in general_params.items() if k in general_keys
        }
        general_dict = unflatten_dict(general_params, sep="/")

        # The step name becomes the mapping key and the step number comes
        # from step.index, so neither is copied from the stored parameters.
        excluded = ("step_general/step_name", "step_general/step_number")
        steps_dict = {}
        for step in self.steps:
            step_params = {
                k: v for k, v in step.parameters.items() if k not in excluded
            }
            step_params["step_general/step_number"] = step.index

            # Apply type conversion based on LUT
            step_params = _apply_type_conversion(
                step_params, step.step_type, self.version
            )

            # Remove any parameters not in LUT
            step_lut = lut.get_lut(step.step_type.lower(), self.version)
            step_lut.flatten()
            step_keys = step_lut.keys()
            step_params = {
                k: v
                for k, v in step_params.items()
                if k in step_keys or k == "step_general/step_number"
            }

            # NOTE: steps are keyed by name, so a duplicated name would
            # overwrite an earlier step here; validate_step_names() reports it.
            steps_dict[step.name] = unflatten_dict(step_params, sep="/")

        return {
            "config_file_version": self.version,
            "general": general_dict,
            "steps": steps_dict,
        }

    def set_version(self, new_version: float):
        """Update pipeline version and migrate all parameters to new version.

        This method:
        1. Updates version for pipeline and all steps
        2. Adds new parameters introduced in new version (with defaults)
        3. Removes parameters that don't exist in new version
        4. Preserves existing parameter values where applicable

        Args:
            new_version: Target version to migrate to
        """
        if new_version == self.version:
            return  # No change needed

        self._migrate_step_to_version(self.general, new_version)
        for step in self.steps:
            self._migrate_step_to_version(step, new_version)
        self.version = new_version

    def _migrate_step_to_version(self, step: StepConfig, new_version: float):
        """Migrate a single step to a new version (in place).

        Parameters defined by the new version's LUT keep their current value
        when present. Otherwise they get the LUT default, with booleans
        preserved as in ``_populate_default_parameters``. Parameters missing
        from the new LUT are dropped, except the step-identity keys below,
        which are carried over when present.

        Args:
            step: Step to migrate
            new_version: Target version
        """
        # Identity keys written by from_yaml/add_step; kept in case the LUT
        # does not list them, so get_param() still finds them after migration.
        identity_keys = (
            "step_type",
            "step_general/step_type",
            "step_general/step_name",
            "step_general/step_number",
        )

        try:
            new_lut = lut.get_lut(step.step_type.lower(), new_version)
            new_lut.flatten()
        except Exception:
            # If new LUT doesn't exist, can't migrate
            print(f"Warning: Cannot get LUT for {step.step_type} version {new_version}")
            step.version = new_version
            return

        current_params = deepcopy(step.parameters)

        # Parameters absent from the new LUT are dropped by omission.
        new_params = {}
        for param_key, entry in new_lut.items():
            if param_key in current_params:
                new_params[param_key] = current_params[param_key]
            else:
                new_params[param_key] = self._default_param_value(entry.default)

        for key in identity_keys:
            if key in current_params and key not in new_params:
                new_params[key] = current_params[key]

        step.parameters = new_params
        step.version = new_version

    def __repr__(self) -> str:
        return f"PipelineConfig(version={self.version}, steps={len(self.steps)})"
689def _value_to_string(value: Any) -> str:
690 """Convert value to string, handling None/null as empty string.
692 Special handling:
693 - None/null → empty string
694 - Booleans → keep as boolean (not converted to string)
695 - Everything else → string
697 Args:
698 value: Value to convert
700 Returns:
701 String representation (empty string for None, bool for booleans)
702 """
703 if value is None or value == "null":
704 return ""
705 # Preserve boolean type for checkbuttons
706 if isinstance(value, bool):
707 return value
708 return str(value)
def flatten_dict(d: Dict, parent_key: str = "", sep: str = "/") -> Dict:
    """Collapse a nested dictionary into one level of ``sep``-joined keys.

    Nested dicts are walked recursively. A nested dict with no entries
    contributes nothing to the result.

    Args:
        d: Dictionary to flatten
        parent_key: Prefix prepended (with ``sep``) to every key of ``d``
        sep: Separator placed between key components

    Returns:
        A new single-level dictionary

    Example:
        {'a': {'b': 1}} -> {'a/b': 1}
    """
    flat: Dict = {}
    for key, val in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if not isinstance(val, dict):
            flat[full_key] = val
            continue
        flat.update(flatten_dict(val, full_key, sep=sep))
    return flat
def unflatten_dict(d: Dict, sep: str = "/") -> Dict:
    """Rebuild a nested dictionary from ``sep``-joined keys.

    Each key is split on ``sep``. All components but the last name nested
    dicts (created on demand), and the last component receives the value.

    Args:
        d: Flattened dictionary
        sep: Separator used in keys

    Returns:
        A new nested dictionary

    Example:
        {'a/b': 1} -> {'a': {'b': 1}}
    """
    nested: Dict = {}
    for flat_key, val in d.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = val
    return nested