Coverage for /opt/hostedtoolcache/Python/3.11.14/x64/lib/python3.11/site-packages/rattlesnake/components/sine_sys_id_utilities.py: 6%

1322 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-27 18:22 +0000

1# -*- coding: utf-8 -*- 

2""" 

3Created on Mon Mar 31 10:19:37 2025 

4 

5@author: dprohe 

6""" 

7 

8import os 

9 

10# import time 

11 

12import numpy as np 

13from scipy.io import loadmat 

14from scipy.signal import windows 

15from scipy.sparse import linalg 

16from scipy import sparse 

17from scipy.signal import lfilter, lfiltic, butter 

18import pyqtgraph as pqtg 

19from qtpy import QtWidgets, uic 

20from qtpy.QtCore import Qt, QLocale # pylint: disable=no-name-in-module 

21 

22from .environments import ( 

23 sine_sweep_table_ui_path, 

24 filter_explorer_ui_path, 

25) 

26from .ui_utilities import VaryingNumberOfLinePlot 

27from .utilities import wrap 

28 

29 

# Module-level debug switch; False by default.
DEBUG = False

if DEBUG:
    # pickle is only imported when DEBUG is enabled.
    # NOTE(review): no use of pickle is visible in this chunk; presumably it
    # backs debug dumps elsewhere in the module — confirm.
    import pickle

34 

35 

def load_specification(spec_path):
    """Loads a sine specification from a .mat or .npz file

    Assumes the phases are represented in degrees

    Parameters
    ----------
    spec_path : str
        The path to the specification to load. A ``.mat`` extension
        (case-insensitive) is read with ``scipy.io.loadmat``; any other
        extension is read with ``numpy.load`` and is expected to be an
        ``.npz`` archive.

    Returns
    -------
    frequency : ndarray
        The frequency breakpoints of the specification, flattened to 1D.
    amplitude : ndarray
        The amplitudes at the breakpoints for each channel.
    phase : ndarray or None
        The phase in degrees of the specification, or None if not stored.
    sweep_type : ndarray or None
        The sweep types at each breakpoint, flattened to 1D.
    sweep_rate : ndarray or None
        The sweep rate at each breakpoint, flattened to 1D.
    warning : ndarray or None
        Upper, lower, left, and right warning levels at each
        breakpoint
    abort : ndarray or None
        Upper, lower, left, and right abort levels at each breakpoint.
    start_time : float or None
        The start time for the sine sweep (the first element of the stored
        array), or None if it is absent or empty.
    name : str or None
        The name of the sine sweep, or None if not stored.

    Raises
    ------
    KeyError
        If the required ``frequency`` or ``amplitude`` entry is missing.
    """
    _, extension = os.path.splitext(spec_path)
    if extension.lower() == ".mat":
        data = loadmat(spec_path)
    else:
        data = np.load(spec_path)
    try:
        frequency = data["frequency"].flatten()
        amplitude = data["amplitude"]
        phase = data["phase"] if "phase" in data else None  # Degrees
        sweep_rate = data["sweep_rate"].flatten() if "sweep_rate" in data else None
        sweep_type = data["sweep_type"].flatten() if "sweep_type" in data else None
        warning = data["warning"] if "warning" in data else None
        abort = data["abort"] if "abort" in data else None
        start_time = None
        if "start_time" in data:
            # Scalars come back as (1, 1) arrays from loadmat and as 0-d
            # arrays from np.load, so reduce to a plain float.
            start_values = np.asarray(data["start_time"]).reshape(-1)
            if start_values.size:
                start_time = float(start_values[0])
        name = None
        if "name" in data:
            # loadmat typically returns a string as a 1-element array and
            # np.savez stores it as a 0-d array; reduce both to a str.
            name_values = np.asarray(data["name"]).reshape(-1)
            name = str(name_values[0]) if name_values.size else ""
    finally:
        # np.load returns an NpzFile that keeps the archive open until it is
        # closed; loadmat returns a dict, which has no close method.
        close = getattr(data, "close", None)
        if close is not None:
            close()
    return (
        frequency,
        amplitude,
        phase,
        sweep_type,
        sweep_rate,
        warning,
        abort,
        start_time,
        name,
    )

115 

116 

def sine_sweep(
    dt,
    frequencies,
    sweep_rates,
    sweep_types,
    amplitudes=1,
    phases=0,
    return_argument=False,
    return_frequency=False,
    return_amplitude=False,
    return_phase=False,
    return_abscissa=False,
    only_breakpoints=False,
):
    """
    Generates a sweeping sine wave with linear or logarithmic sweep rate

    Parameters
    ----------
    dt : float
        The time step of the output signal
    frequencies : iterable
        A list of frequency breakpoints for the sweep. Can be ascending or
        descending or both. Frequencies are specified in Hz, not rad/s. At
        least two breakpoints are required.
    sweep_rates : iterable or float
        A list of sweep rates between the breakpoints. This array should have
        one fewer element than the `frequencies` array. The ith element of this
        array specifies the sweep rate between `frequencies[i]` and
        `frequencies[i+1]`. A scalar (including a NumPy scalar or 0-d array)
        is used for every segment. For a linear sweep, the rate is in Hz/s.
        For a logarithmic sweep, the rate is in octave/s. The sign of the rate
        must match the direction of the segment (negative when sweeping down).
    sweep_types : iterable or str
        The type of sweep to perform between each frequency breakpoint. Can be
        'lin', 'linear', 'log', or 'logarithmic' (case-insensitive). If a
        string is specified, it will be used for all segments. Otherwise it
        should be an array containing strings with one fewer element than
        that of the `frequencies` array.
    amplitudes : iterable or float, optional
        Amplitude of the cosine wave at each of the frequency breakpoints. Can
        be specified as a single value (including a NumPy scalar or 0-d
        array), or as an array with a value specified for each breakpoint.
        Amplitudes are linearly interpolated against frequency within each
        segment. The default is 1.
    phases : iterable or float, optional
        Phases in radians of the cosine wave at each of the frequency
        breakpoints. Can be specified as a single value (including a NumPy
        scalar or 0-d array), or as an array with a value specified for each
        breakpoint. Be aware that modifying the phase between breakpoints will
        effectively change the frequency of the signal, because the phase will
        change over time. The default is 0.
    return_argument : bool
        If True, return cosine argument over time
    return_frequency : bool
        If True, return the instantaneous frequency over time
    return_amplitude : bool
        If True, return the instantaneous amplitude over time
    return_phase : bool
        If True, return the instantaneous phase over time
    return_abscissa : bool
        If True, return the instantaneous abscissa over time
    only_breakpoints : bool
        If True, only returns data at breakpoints. Default is False

    Raises
    ------
    ValueError
        If fewer than two frequency breakpoints are given, if a sweep type is
        not recognized, if a sweep rate is zero, or if the sweep rate and
        start and end frequency would result in a negative sweep time, for
        example if the start frequency is above the end frequency and a
        positive sweep rate is specified.

    Returns
    -------
    ordinate : np.ndarray
        A numpy array consisting of the generated sine sweep signal. The length
        of the signal will be determined by the frequency breakpoints and sweep
        rates; each segment is truncated to ``floor(sweep_time / dt)`` steps.
    arg_over_time : np.ndarray
        A numpy array consisting of the argument to the cosine wave over time.
    freq_over_time : np.ndarray
        A numpy array consisting of the frequency of the cosine wave over time.
    amp_over_time : np.ndarray
        A numpy array consisting of the amplitude of the cosine wave over time.
    phs_over_time : np.ndarray
        A numpy array consisting of the added phase in radians of the cosine
        wave over time.
    abscissa : np.ndarray
        A numpy array consisting of the time value at each time step returned

    Only the arrays whose ``return_*`` flag is True are returned, in the order
    listed above, as a tuple after ``ordinate``. When no flag is set,
    ``ordinate`` alone is returned.
    """
    num_segments = len(frequencies) - 1
    if num_segments < 1:
        raise ValueError("At least two frequency breakpoints are required to define a sweep")

    def _breakpoint_pair(values, index):
        """Return the (start, end) values of segment ``index``.

        Scalars apply to every breakpoint. ``np.ndim`` is used rather than
        catching TypeError because NumPy scalars and 0-d arrays raise
        IndexError when indexed.
        """
        if np.ndim(values) == 0:
            return values, values
        return values[index], values[index + 1]

    last_phase = 0
    last_abscissa = 0
    abscissa = []
    ordinate = []
    arg_over_time = []
    freq_over_time = []
    amp_over_time = []
    phs_over_time = []

    # Go through each section
    for i in range(num_segments):
        # Extract the terms
        start_frequency = frequencies[i]
        end_frequency = frequencies[i + 1]
        omega_start = start_frequency * 2 * np.pi
        sweep_rate = sweep_rates if np.ndim(sweep_rates) == 0 else sweep_rates[i]
        sweep_type = sweep_types if isinstance(sweep_types, str) else sweep_types[i]
        start_amplitude, end_amplitude = _breakpoint_pair(amplitudes, i)
        start_phase, end_phase = _breakpoint_pair(phases, i)  # Radians
        sweep_kind = sweep_type.lower()
        if sweep_kind in ["lin", "linear"]:
            is_linear = True
        elif sweep_kind in ["log", "logarithmic"]:
            is_linear = False
        else:
            raise ValueError("Sweep type should be one of lin, linear, log, or logarithmic")
        if sweep_rate == 0:
            raise ValueError(f"Sweep rate for segment index {i} is zero.")
        # Compute the length of this portion of the signal
        if is_linear:
            sweep_time = +(end_frequency - start_frequency) / sweep_rate
        else:
            sweep_time = np.log(end_frequency / start_frequency) / (sweep_rate * np.log(2))
        if sweep_time < 0:
            raise ValueError(f"Sweep time for segment index {i} is negative. Check sweep rate.")
        sweep_samples = int(np.floor(sweep_time / dt))
        # Construct the abscissa
        if only_breakpoints:
            this_abscissa = np.array([0, sweep_samples * dt])
        else:
            this_abscissa = np.arange(sweep_samples + 1) * dt
        # Compute the phase over time (integral of the instantaneous
        # angular frequency, starting from zero at the segment start)
        if is_linear:
            this_argument = (1 / 2) * (
                sweep_rate * 2 * np.pi
            ) * this_abscissa**2 + omega_start * this_abscissa
            this_frequency = (sweep_rate) * this_abscissa + omega_start / (2 * np.pi)
        else:
            this_argument = 2 ** (sweep_rate * this_abscissa) * omega_start / (
                sweep_rate * np.log(2)
            ) - omega_start / (sweep_rate * np.log(2))
            this_frequency = 2 ** (sweep_rate * this_abscissa) * omega_start / (2 * np.pi)
        # np.interp needs increasing x-coordinates, so order the breakpoint
        # pair by frequency before interpolating phase and amplitude
        if end_frequency > start_frequency:
            freq_interp = [start_frequency, end_frequency]
            phase_interp = [start_phase, end_phase]
            amp_interp = [start_amplitude, end_amplitude]
        else:
            freq_interp = [end_frequency, start_frequency]
            phase_interp = [end_phase, start_phase]
            amp_interp = [end_amplitude, start_amplitude]
        this_phases = np.interp(this_frequency, freq_interp, phase_interp)
        # Compute the amplitude at each time step
        this_amplitudes = np.interp(this_frequency, freq_interp, amp_interp)
        this_ordinate = this_amplitudes * np.cos(this_argument + this_phases + last_phase)
        if i == num_segments - 1:
            last_index = None  # If it's the last segment, go up until the end
        else:
            # Otherwise, drop the last point because the first point of the
            # next segment will be this value
            last_index = -1
        arg_over_time.append(this_argument[:last_index] + last_phase)
        last_phase += this_argument[-1]
        abscissa.append(this_abscissa[:last_index] + last_abscissa)
        last_abscissa += this_abscissa[-1]
        ordinate.append(this_ordinate[:last_index])
        freq_over_time.append(this_frequency[:last_index])
        amp_over_time.append(this_amplitudes[:last_index])
        phs_over_time.append(this_phases[:last_index])
    return_vals = [np.concatenate(ordinate)]
    if return_argument:
        return_vals.append(np.concatenate(arg_over_time))
    if return_frequency:
        return_vals.append(np.concatenate(freq_over_time))
    if return_amplitude:
        return_vals.append(np.concatenate(amp_over_time))
    if return_phase:
        return_vals.append(np.concatenate(phs_over_time))
    if return_abscissa:
        return_vals.append(np.concatenate(abscissa))
    if len(return_vals) == 1:
        return return_vals[0]
    return tuple(return_vals)

307 

308 

class NoWheelSpinBox(QtWidgets.QDoubleSpinBox):
    """Double spin box that refuses mouse-wheel input.

    Wheel events are marked as ignored instead of being handled, so
    scrolling over the box never changes its value.
    """

    def wheelEvent(self, event):  # pylint: disable=invalid-name
        """Reject the incoming wheel event, leaving the value untouched."""
        event.ignore()

315 

316 

class AdaptiveNoWheelSpinBox(NoWheelSpinBox):
    """Wheel-less double spin box whose text drops insignificant digits.

    The displayed text is produced with Qt's general ("g") number format in a
    fixed US-English locale, using ``decimals()`` as the precision, so the
    text adapts to the value rather than always showing a fixed number of
    decimal places.

    NOTE(review): Qt interprets the "g" precision as significant digits, and
    parsing of typed text still goes through the default QDoubleSpinBox
    implementation — confirm both are intended.
    """

    # Shared by every instance so the rendered text is independent of the
    # user's system locale.
    localization = QLocale(QLocale.English, QLocale.UnitedStates)

    def __init__(self, parent=None):
        super().__init__(parent)
        # Allow up to ten decimals of stored precision.
        self.setDecimals(10)

    def textFromValue(self, value):  # pylint: disable=invalid-name
        """Render ``value`` with the shared locale in "g" format."""
        precision = self.decimals()
        formatter = AdaptiveNoWheelSpinBox.localization
        return formatter.toString(value, "g", precision)

330 

331 

class NoWheelComboBox(QtWidgets.QComboBox):
    """Combo box that does not react to the mouse wheel.

    Wheel events are marked as ignored rather than handled, so scrolling
    over the box does not change the selected item.
    """

    def wheelEvent(self, event):  # pylint: disable=invalid-name
        """Reject the incoming wheel event, leaving the selection untouched."""
        event.ignore()

338 

339 

340class SineSweepTable: 

341 """A class representing a breakpoint table defining a sine sweep""" 

342 

343 def __init__( 

344 self, 

345 parent_tabwidget: QtWidgets.QTabWidget, 

346 update_specification_function, 

347 remove_function, 

348 control_names, 

349 data_acquisition_parameters, 

350 ): 

351 """Initializes a sine sweep table to represent the breakpoints of a sine tone 

352 

353 Parameters 

354 ---------- 

355 parent_tabwidget : QtWidgets.QTabWidget 

356 The parent tabwidget in the sine ui class, which is needed to 

357 propogate changes in this widget back up to the main UI class 

358 update_specification_function : function 

359 The function to call to update the specification when the values 

360 in this table have changed 

361 remove_function : function 

362 The function to call when we remove a table from the tab widget 

363 control_names : list of str 

364 A list of strings to be used as the control channel names in the 

365 table 

366 data_acquisition_parameters : DataAcquisitionParameters 

367 The data acquisition parameters, including sample rate 

368 """ 

369 self.parent_tabwidget = parent_tabwidget 

370 self.update_specification_function = update_specification_function 

371 self.remove_function = remove_function 

372 self.control_names = control_names 

373 self.data_acquisition_parameters = data_acquisition_parameters 

374 self.widget = QtWidgets.QWidget() 

375 uic.loadUi(sine_sweep_table_ui_path, self.widget) 

376 self.index = self.parent_tabwidget.count() - 1 

377 self.parent_tabwidget.insertTab(self.index, self.widget, f"Sine {self.index+1}") 

378 self.widget.name_editor.setText(f"Sine {self.index+1}") 

379 self.parent_tabwidget.setCurrentIndex(self.index) 

380 self.connect_callbacks() 

381 self.clear_and_update_specification_table() 

382 

383 def connect_callbacks(self): 

384 """Connects the widgets in the UI to methods of the object""" 

385 self.widget.add_breakpoint_button.clicked.connect(self.add_breakpoint) 

386 self.widget.remove_breakpoint_button.clicked.connect(self.remove_breakpoint) 

387 self.widget.load_breakpoints_button.clicked.connect(self.load_specification) 

388 self.widget.name_editor.editingFinished.connect(self.update_name) 

389 self.widget.start_time_selector.valueChanged.connect(self.update_specification_function) 

390 self.widget.remove_tone_button.clicked.connect(self.remove_tone) 

391 

392 def add_breakpoint(self): 

393 """Adds a breakpoint to the table""" 

394 selected_indices = self.widget.breakpoint_table.selectedIndexes() 

395 if selected_indices: 

396 selected_row = selected_indices[0].row() 

397 else: 

398 # If no row is selected, add the row at the start 

399 selected_row = 0 

400 control_names = self.control_names 

401 self.widget.breakpoint_table.insertRow(selected_row) 

402 self.widget.warning_table.insertRow(selected_row) 

403 self.widget.abort_table.insertRow(selected_row) 

404 # Frequency display, Breakpoint Table 

405 spinbox = AdaptiveNoWheelSpinBox() 

406 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

407 spinbox.setSingleStep(1) 

408 spinbox.setValue(0) 

409 spinbox.setKeyboardTracking(False) 

410 spinbox.valueChanged.connect(self.update_specification_function) 

411 self.widget.breakpoint_table.setCellWidget(selected_row, 0, spinbox) 

412 # Frequency display, warning table 

413 spinbox = AdaptiveNoWheelSpinBox() 

414 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

415 spinbox.setSingleStep(1) 

416 spinbox.setValue(0) 

417 spinbox.setKeyboardTracking(False) 

418 spinbox.setReadOnly(True) 

419 spinbox.setButtonSymbols(AdaptiveNoWheelSpinBox.NoButtons) 

420 self.widget.warning_table.setCellWidget(selected_row, 0, spinbox) 

421 # Frequency display, abort table 

422 spinbox = AdaptiveNoWheelSpinBox() 

423 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

424 spinbox.setSingleStep(1) 

425 spinbox.setValue(0) 

426 spinbox.setKeyboardTracking(False) 

427 spinbox.setReadOnly(True) 

428 self.widget.abort_table.setCellWidget(selected_row, 0, spinbox) 

429 # Linear or logarithmic selector 

430 combobox = NoWheelComboBox() 

431 combobox.addItems(["Linear", "Logarithmic"]) 

432 combobox.setCurrentIndex(0) 

433 combobox.currentIndexChanged.connect(self.update_specification_function) 

434 self.widget.breakpoint_table.setCellWidget(selected_row, 1, combobox) 

435 # Rate selector 

436 spinbox = AdaptiveNoWheelSpinBox() 

437 spinbox.setRange(-1000000, 1000000) 

438 spinbox.setSingleStep(1) 

439 spinbox.setValue(1) 

440 spinbox.setSuffix(" Hz/s") 

441 spinbox.setKeyboardTracking(False) 

442 spinbox.valueChanged.connect(self.update_specification_function) 

443 self.widget.breakpoint_table.setCellWidget(selected_row, 2, spinbox) 

444 # All of the individual values 

445 for j in range(len(control_names)): 

446 spinbox = AdaptiveNoWheelSpinBox() 

447 spinbox.setRange(0, 1000000) 

448 spinbox.setSingleStep(1) 

449 spinbox.setValue(1) 

450 spinbox.setKeyboardTracking(False) 

451 spinbox.valueChanged.connect(self.update_specification_function) 

452 self.widget.breakpoint_table.setCellWidget(selected_row, 3 + j * 2, spinbox) 

453 spinbox = AdaptiveNoWheelSpinBox() 

454 spinbox.setRange(-1000000, 1000000) 

455 spinbox.setSingleStep(1) 

456 spinbox.setValue(0) 

457 spinbox.setKeyboardTracking(False) 

458 spinbox.valueChanged.connect(self.update_specification_function) 

459 self.widget.breakpoint_table.setCellWidget(selected_row, 4 + j * 2, spinbox) 

460 for k in range(4): 

461 if selected_row == 0 and k in (0, 2): 

462 item = self.widget.warning_table.item(selected_row, 1 + k + j * 4) 

463 if item is None: 

464 item = QtWidgets.QTableWidgetItem() 

465 self.widget.warning_table.setItem(selected_row, 1 + k + j * 4, item) 

466 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

467 item = self.widget.abort_table.item(selected_row, 1 + k + j * 4) 

468 if item is None: 

469 item = QtWidgets.QTableWidgetItem() 

470 self.widget.abort_table.setItem(selected_row, 1 + k + j * 4, item) 

471 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

472 spinbox = AdaptiveNoWheelSpinBox() 

473 spinbox.setRange(0, 1000000) 

474 spinbox.setSingleStep(1) 

475 spinbox.setValue(0) 

476 spinbox.setKeyboardTracking(False) 

477 spinbox.setSpecialValueText("Disabled") 

478 spinbox.valueChanged.connect(self.update_specification_function) 

479 self.widget.warning_table.setCellWidget( 

480 selected_row + (1 if selected_row == 0 and k in (0, 2) else 0), 

481 1 + k + j * 4, 

482 spinbox, 

483 ) 

484 spinbox = AdaptiveNoWheelSpinBox() 

485 spinbox.setRange(0, 1000000) 

486 spinbox.setSingleStep(1) 

487 spinbox.setValue(0) 

488 spinbox.setKeyboardTracking(False) 

489 spinbox.setSpecialValueText("Disabled") 

490 spinbox.valueChanged.connect(self.update_specification_function) 

491 self.widget.abort_table.setCellWidget( 

492 selected_row + (1 if selected_row == 0 and k in (0, 2) else 0), 

493 1 + k + j * 4, 

494 spinbox, 

495 ) 

496 self.update_specification_function() 

497 

498 def remove_breakpoint(self): 

499 """Removes a breakpoint from the table""" 

500 selected_indices = self.widget.breakpoint_table.selectedIndexes() 

501 if selected_indices: 

502 selected_row = selected_indices[0].row() 

503 else: 

504 # If no row is selected, remove the last row 

505 selected_row = self.widget.breakpoint_table.rowCount() - 1 

506 if selected_row == self.widget.breakpoint_table.rowCount() - 1: 

507 last_row = True 

508 else: 

509 last_row = False 

510 if selected_row == 0: 

511 first_row = True 

512 else: 

513 first_row = False 

514 self.widget.breakpoint_table.removeRow(selected_row) 

515 self.widget.warning_table.removeRow(selected_row) 

516 self.widget.abort_table.removeRow(selected_row) 

517 if last_row: 

518 new_last_row_index = self.widget.breakpoint_table.rowCount() - 1 

519 for column in [1, 2]: 

520 widget = self.widget.breakpoint_table.cellWidget(new_last_row_index, column) 

521 if widget: 

522 # Remove the widget from the cell 

523 self.widget.breakpoint_table.removeCellWidget(new_last_row_index, column) 

524 widget.deleteLater() 

525 item = self.widget.breakpoint_table.item(new_last_row_index, column) 

526 if item is None: 

527 item = QtWidgets.QTableWidgetItem() 

528 self.widget.breakpoint_table.setItem(new_last_row_index, column, item) 

529 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

530 for column in np.arange(2, self.widget.warning_table.columnCount(), 2): 

531 for table in [self.widget.warning_table, self.widget.abort_table]: 

532 widget = table.cellWidget(new_last_row_index, column) 

533 if widget: 

534 # Remove the widget from the cell 

535 table.removeCellWidget(new_last_row_index, column) 

536 widget.deleteLater() 

537 item = table.item(new_last_row_index, column) 

538 if item is None: 

539 item = QtWidgets.QTableWidgetItem() 

540 table.setItem(new_last_row_index, column, item) 

541 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

542 if first_row: 

543 for column in np.arange(1, self.widget.warning_table.columnCount(), 2): 

544 for table in [self.widget.warning_table, self.widget.abort_table]: 

545 widget = table.cellWidget(0, column) 

546 if widget: 

547 # Remove the widget from the cell 

548 table.removeCellWidget(0, column) 

549 widget.deleteLater() 

550 item = table.item(0, column) 

551 if item is None: 

552 item = QtWidgets.QTableWidgetItem() 

553 table.setItem(0, column, item) 

554 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

555 self.update_specification_function() 

556 

557 def load_specification(self, clicked, filename=None): # pylint: disable=unused-argument 

558 """Loads a breakpoint table using a dialog or the specified filename 

559 

560 Parameters 

561 ---------- 

562 clicked : 

563 The clicked event that triggered the callback. 

564 filename : 

565 File name defining the specification for bypassing the callback when 

566 loading from a file (Default value = None). 

567 

568 """ 

569 if filename is None: 

570 filename, _ = QtWidgets.QFileDialog.getOpenFileName( 

571 self.widget, 

572 "Select Specification File", 

573 filter="Numpy or Mat (*.npy *.npz *.mat)", 

574 ) 

575 if filename == "": 

576 return 

577 ( 

578 frequencies, 

579 amplitudes, 

580 phases, # Degrees 

581 sweep_types, 

582 sweep_rates, 

583 warnings, 

584 aborts, 

585 start_time, 

586 name, 

587 ) = load_specification(filename) 

588 self.clear_and_update_specification_table( 

589 frequencies, 

590 amplitudes, 

591 phases, # Degrees 

592 sweep_types, 

593 sweep_rates, 

594 warnings, 

595 aborts, 

596 start_time, 

597 name, 

598 ) 

599 self.update_specification_function() 

600 

601 def clear_and_update_specification_table( 

602 self, 

603 frequencies=None, 

604 amplitudes=None, 

605 phases=None, 

606 sweep_types=None, 

607 sweep_rates=None, 

608 warning_amplitudes=None, 

609 abort_amplitudes=None, 

610 start_time=None, 

611 sine_name=None, 

612 control_names=None, 

613 ): 

614 """Clears the table and updates it with the optional parameters supplied. 

615 

616 Parameters 

617 ---------- 

618 frequencies : ndarray, optional 

619 A 1D array containing the frequencies to use as the breakpoints. By 

620 default, a table consisting of two breakpoints will be specified. 

621 amplitudes : ndarray, optional 

622 A 2D array consisting of amplitudes for (channel, frequency) pairs. 

623 If not specified, an amplitude of zero will be used. 

624 phases : ndarray, optional 

625 A 2D array consisting of phases for (channel, frequency) pairs. 

626 If not specified, an amplitude of zero will be used. Phases 

627 are specified in degrees. 

628 sweep_types : ndarray or list of strings, optional 

629 A 1D array of strings to use as the sweep type. They should 

630 be one of lin, log, linear, or logarithmic. Linear is used 

631 if not specified. 

632 sweep_rates : ndarray, optional 

633 A 1D array of values to use as the sweep rate. They should 

634 be in Hz/s for linear sweeps or octave per minute for 

635 logarithmic sweeps. 

636 warning_amplitudes : ndarray, optional 

637 A 4D ndarray with shape 2, 2, num_channels, num_frequencies. 

638 The first dimension specifies upper and lower limits, the 

639 second dimension specifies frequencies greater or lower than 

640 the frequency breakpoint. If a value is not desired, it 

641 should be set to nan. 

642 abort_amplitudes : ndarray, optional 

643 A 4D ndarray with shape 2, 2, num_channels, num_frequencies. 

644 The first dimension specifies upper and lower limits, the 

645 second dimension specifies frequencies greater or lower than 

646 the frequency breakpoint. If a value is not desired, it 

647 should be set to nan. 

648 start_time : float, optional 

649 The starting time for the specified sine tone. If not specified, 

650 it will be set to 0 

651 sine_name : str, optional 

652 The name of the sine tone used in the software. 

653 control_names : array of str, optional 

654 Channel names to use in the table. If not specified, the 

655 existing channel names will be used. 

656 """ 

657 # print(f'Called clear_and_update_specification with {control_names=}') 

658 # print(f'Called clear_and_update_specification with {start_time=}') 

659 # print(f'Called clear_and_update_specification with {sine_name=}') 

660 if control_names is not None: 

661 self.control_names = control_names 

662 control_names = self.control_names 

663 if frequencies is None: 

664 num_rows = 2 

665 else: 

666 num_rows = frequencies.size 

667 self.widget.breakpoint_table.clear() 

668 self.widget.breakpoint_table.setRowCount(num_rows) 

669 self.widget.breakpoint_table.setColumnCount(3 + 2 * len(control_names)) 

670 self.widget.warning_table.setRowCount(num_rows) 

671 self.widget.warning_table.setColumnCount(1 + 4 * len(control_names)) 

672 self.widget.abort_table.setRowCount(num_rows) 

673 self.widget.abort_table.setColumnCount(1 + 4 * len(control_names)) 

674 breakpoint_header_labels = ["Frequency", "Sweep Type", "Sweep Rate"] 

675 other_header_labels = ["Frequency"] 

676 for name in control_names: 

677 breakpoint_header_labels.append(name + " Amplitude") 

678 breakpoint_header_labels.append(name + " Phase") 

679 other_header_labels.append(name + " Lower Left") 

680 other_header_labels.append(name + " Lower Right") 

681 other_header_labels.append(name + " Upper Left") 

682 other_header_labels.append(name + " Upper Right") 

683 self.widget.breakpoint_table.setHorizontalHeaderLabels(breakpoint_header_labels) 

684 self.widget.warning_table.setHorizontalHeaderLabels(other_header_labels) 

685 self.widget.abort_table.setHorizontalHeaderLabels(other_header_labels) 

686 # Set up widgets in the table 

687 for row in range(num_rows): 

688 # Frequency Breakpoint 

689 spinbox = AdaptiveNoWheelSpinBox() 

690 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

691 spinbox.setSingleStep(1) 

692 if frequencies is None: 

693 spinbox.setValue(0) 

694 else: 

695 spinbox.setValue(frequencies[row]) 

696 spinbox.setKeyboardTracking(False) 

697 spinbox.setDecimals(4) 

698 spinbox.valueChanged.connect(self.update_specification_function) 

699 self.widget.breakpoint_table.setCellWidget(row, 0, spinbox) 

700 # Frequency display, warning table 

701 spinbox = AdaptiveNoWheelSpinBox() 

702 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

703 spinbox.setSingleStep(1) 

704 if frequencies is None: 

705 spinbox.setValue(0) 

706 else: 

707 spinbox.setValue(frequencies[row]) 

708 spinbox.setKeyboardTracking(False) 

709 spinbox.setReadOnly(True) 

710 spinbox.setButtonSymbols(AdaptiveNoWheelSpinBox.NoButtons) 

711 self.widget.warning_table.setCellWidget(row, 0, spinbox) 

712 # Frequency display, abort table 

713 spinbox = AdaptiveNoWheelSpinBox() 

714 spinbox.setRange(0, self.data_acquisition_parameters.sample_rate / 2) 

715 spinbox.setSingleStep(1) 

716 if frequencies is None: 

717 spinbox.setValue(0) 

718 else: 

719 spinbox.setValue(frequencies[row]) 

720 spinbox.setKeyboardTracking(False) 

721 spinbox.setReadOnly(True) 

722 spinbox.setButtonSymbols(AdaptiveNoWheelSpinBox.NoButtons) 

723 self.widget.abort_table.setCellWidget(row, 0, spinbox) 

724 # Rate and type 

725 if row < num_rows - 1: 

726 combobox = NoWheelComboBox() 

727 combobox.addItems(["Linear", "Logarithmic"]) 

728 if sweep_types is not None: 

729 if str(sweep_types[row]).lower() in ["lin", "linear"]: 

730 combobox.setCurrentIndex(0) 

731 else: 

732 combobox.setCurrentIndex(1) 

733 combobox.currentIndexChanged.connect(self.update_specification_function) 

734 self.widget.breakpoint_table.setCellWidget(row, 1, combobox) 

735 spinbox = AdaptiveNoWheelSpinBox() 

736 spinbox.setRange(-1000000, 1000000) 

737 spinbox.setSingleStep(1) 

738 if sweep_rates is not None: 

739 spinbox.setValue(sweep_rates[row]) 

740 else: 

741 spinbox.setValue(1) 

742 if combobox.currentIndex() == 0: 

743 spinbox.setSuffix(" Hz/s") 

744 else: 

745 spinbox.setSuffix(" oct/min") 

746 spinbox.setKeyboardTracking(False) 

747 spinbox.valueChanged.connect(self.update_specification_function) 

748 self.widget.breakpoint_table.setCellWidget(row, 2, spinbox) 

749 else: 

750 item = self.widget.breakpoint_table.item(row, 1) 

751 if item is None: 

752 item = QtWidgets.QTableWidgetItem() 

753 self.widget.breakpoint_table.setItem(row, 1, item) 

754 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

755 item = self.widget.breakpoint_table.item(row, 2) 

756 if item is None: 

757 item = QtWidgets.QTableWidgetItem() 

758 self.widget.breakpoint_table.setItem(row, 2, item) 

759 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

760 # Amplitude and Phases 

761 for j in range(len(control_names)): 

762 spinbox = AdaptiveNoWheelSpinBox() 

763 spinbox.setRange(0, 1000000) 

764 spinbox.setSingleStep(1) 

765 if amplitudes is None: 

766 spinbox.setValue(1) 

767 else: 

768 spinbox.setValue(amplitudes[j, row]) 

769 spinbox.setKeyboardTracking(False) 

770 spinbox.valueChanged.connect(self.update_specification_function) 

771 self.widget.breakpoint_table.setCellWidget(row, 3 + j * 2, spinbox) 

772 spinbox = AdaptiveNoWheelSpinBox() 

773 spinbox.setRange(-1000000, 1000000) 

774 spinbox.setSingleStep(1) 

775 if phases is None: 

776 spinbox.setValue(0) 

777 else: 

778 spinbox.setValue(phases[j, row]) 

779 spinbox.valueChanged.connect(self.update_specification_function) 

780 spinbox.setKeyboardTracking(False) 

781 self.widget.breakpoint_table.setCellWidget(row, 4 + j * 2, spinbox) 

782 for k in range(4): 

783 spinbox = AdaptiveNoWheelSpinBox() 

784 spinbox.setRange(0, 1000000) 

785 spinbox.setSingleStep(1) 

786 if ( 

787 row == 0 and k in (0, 2) 

788 ) or ( # If first frequency and looking at left side 

789 row == num_rows - 1 and k in (1, 3) 

790 ): # or if last frequency and looking at right side 

791 item = self.widget.warning_table.item(row, 1 + k + j * 4) 

792 if item is None: 

793 item = QtWidgets.QTableWidgetItem() 

794 self.widget.warning_table.setItem(row, 1 + k + j * 4, item) 

795 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

796 item = self.widget.abort_table.item(row, 1 + k + j * 4) 

797 if item is None: 

798 item = QtWidgets.QTableWidgetItem() 

799 self.widget.abort_table.setItem(row, 1 + k + j * 4, item) 

800 item.setFlags(item.flags() & ~Qt.ItemIsEditable) 

801 else: 

802 if warning_amplitudes is None: 

803 spinbox.setValue(0) 

804 else: 

805 val = warning_amplitudes[np.unravel_index(k, (2, 2)) + (j, row)] 

806 spinbox.setValue(0 if np.isnan(val) else val) 

807 spinbox.setKeyboardTracking(False) 

808 spinbox.setSpecialValueText("Disabled") 

809 spinbox.valueChanged.connect(self.update_specification_function) 

810 self.widget.warning_table.setCellWidget(row, 1 + k + j * 4, spinbox) 

811 spinbox = AdaptiveNoWheelSpinBox() 

812 spinbox.setRange(0, 1000000) 

813 spinbox.setSingleStep(1) 

814 if abort_amplitudes is None: 

815 spinbox.setValue(0) 

816 else: 

817 val = abort_amplitudes[np.unravel_index(k, (2, 2)) + (j, row)] 

818 spinbox.setValue(0 if np.isnan(val) else val) 

819 spinbox.setKeyboardTracking(False) 

820 spinbox.setSpecialValueText("Disabled") 

821 spinbox.valueChanged.connect(self.update_specification_function) 

822 self.widget.abort_table.setCellWidget(row, 1 + k + j * 4, spinbox) 

823 if sine_name is not None: 

824 self.widget.name_editor.setText(sine_name) 

825 self.update_name() 

826 if start_time is not None: 

827 self.widget.start_time_selector.setValue(start_time) 

828 

829 def update_name(self): 

830 """Called when the name of the sine tone is changed""" 

831 self.parent_tabwidget.setTabText(self.index, self.widget.name_editor.text()) 

832 

833 def remove_tone(self): 

834 """Called when the remove button is pressed""" 

835 self.remove_function(self.index) 

836 

837 def get_specification(self): 

838 """Computes a sine sweep specification from the table""" 

839 num_control = (self.widget.breakpoint_table.columnCount() - 3) // 2 

840 spec = SineSpecification( 

841 self.widget.name_editor.text(), 

842 self.widget.start_time_selector.value(), 

843 num_control, 

844 self.widget.breakpoint_table.rowCount(), 

845 ) 

846 for row, spec_row in enumerate(spec.breakpoint_table): 

847 spec_row["frequency"] = self.widget.breakpoint_table.cellWidget(row, 0).value() 

848 if row < len(spec.breakpoint_table) - 1: 

849 spec_row["sweep_type"] = self.widget.breakpoint_table.cellWidget( 

850 row, 1 

851 ).currentIndex() 

852 spec_row["sweep_rate"] = self.widget.breakpoint_table.cellWidget(row, 2).value() 

853 for i in range(num_control): 

854 spec_row["amplitude"][i] = self.widget.breakpoint_table.cellWidget( 

855 row, 3 + 2 * i 

856 ).value() 

857 spec_row["phase"][i] = ( 

858 self.widget.breakpoint_table.cellWidget(row, 4 + 2 * i).value() * np.pi / 180 

859 ) # Convert degrees to radians for all calculations 

860 for k in range(4): 

861 ind = np.unravel_index(k, (2, 2)) 

862 if (row == 0 and k in (0, 2)) or ( 

863 row == len(spec.breakpoint_table) - 1 and k in (1, 3) 

864 ): 

865 spec_row["warning"][ind + (i,)] = np.nan 

866 spec_row["abort"][ind + (i,)] = np.nan 

867 else: 

868 val = self.widget.warning_table.cellWidget(row, 1 + k + i * 4).value() 

869 spec_row["warning"][ind + (i,)] = np.nan if val == 0 else val 

870 val = self.widget.abort_table.cellWidget(row, 1 + k + i * 4).value() 

871 spec_row["abort"][ind + (i,)] = np.nan if val == 0 else val 

872 return spec 

873 

874 

def digital_tracking_filter_generator(
    dt,
    cutoff_frequency_ratio=0.15,
    filter_order=2,
    phase_estimate=None,
    amplitude_estimate=None,
):
    """
    Computes amplitudes and phases using a digital tracking filter

    This is a coroutine-style generator. Prime it with ``next()``, which
    yields ``(None, None)``. Then call ``send((xi, fi, argsi))`` once per
    block of data; each ``send`` returns ``(amplitude, phase)`` for that
    block. The low-pass filter states are carried over between blocks, so
    consecutive blocks are filtered as one continuous signal.

    The signal is demodulated by multiplying it by ``cos(argsi)`` and
    ``-sin(argsi)``. For a signal ``A*cos(argsi + phi)``, low-pass filtering
    the two products leaves ``A/2*cos(phi)`` and ``A/2*sin(phi)``. The
    amplitude and phase are recovered from these.

    Parameters
    ----------
    dt : float
        The time step of the signal
    cutoff_frequency_ratio : float
        The cutoff frequency of the low-pass filter relative to the lowest
        frequency in each block. Default is 0.15.
    filter_order : int
        The order of the low-pass Butterworth filter. Default is 2.
    phase_estimate : float, optional
        An estimate of the initial phase, in radians, used to seed the
        low-pass filter. Defaults to 0.
    amplitude_estimate : float, optional
        An estimate of the initial amplitude used to seed the low-pass
        filter. Defaults to 0.

    Sends
    -----
    xi : iterable
        The next block of the signal to be filtered. An empty block
        produces empty outputs and leaves the filter state unchanged.
    fi : iterable
        The frequencies at the time steps in xi
    argsi : iterable
        The argument to a cosine function at the time steps in xi

    Yields
    ------
    amplitude : np.ndarray
        The amplitude at each time step
    phase : np.ndarray
        The phase at each time step in radians, as returned by np.arctan2
    """
    if phase_estimate is None:
        phase_estimate = 0
    if amplitude_estimate is None:
        amplitude_estimate = 0

    # Filtered in-phase / quadrature products from the previous block.
    # None means the filter has not been seeded yet.
    xi_0_filt = None
    xi_90_filt = None
    z0i = None
    z90i = None
    amplitude = None
    phase = None
    while True:
        xi, fi, argsi = yield amplitude, phase
        xi = np.array(xi)
        fi = np.array(fi)
        argsi = np.array(argsi)
        if fi.size == 0:
            # np.min cannot reduce an empty array. An empty block carries no
            # data, so report empty results and keep the filter state as-is.
            amplitude = np.zeros(xi.shape)
            phase = np.zeros(xi.shape)
            continue
        # The cutoff follows the lowest frequency in the block so the
        # double-frequency demodulation product is attenuated.
        b, a = butter(filter_order, cutoff_frequency_ratio * np.min(fi), fs=1 / dt)
        if xi_0_filt is None:
            # Build a short synthetic history consistent with the estimates
            # so the filter starts near steady state instead of at zero.
            past_ts = np.arange(-filter_order * 2 - 1, 0) * dt
            past_xs = amplitude_estimate * np.cos(2 * np.pi * fi[0] * past_ts + phase_estimate)
            xi_0 = np.cos(2 * np.pi * fi[0] * past_ts) * past_xs
            xi_90 = -np.sin(2 * np.pi * fi[0] * past_ts) * past_xs
            xi_0_filt = 0.5 * amplitude_estimate * np.cos(phase_estimate) * np.ones(xi_0.shape)
            xi_90_filt = 0.5 * amplitude_estimate * np.sin(phase_estimate) * np.ones(xi_90.shape)
            # lfiltic expects the most recent sample first, hence [::-1].
            z0i = lfiltic(b, a, xi_0_filt[::-1], xi_0[::-1])
            z90i = lfiltic(b, a, xi_90_filt[::-1], xi_90[::-1])
        # Demodulate against the reference and low-pass filter both channels.
        cola0 = np.cos(argsi)
        cola90 = -np.sin(argsi)
        xi_0 = cola0 * xi
        xi_90 = cola90 * xi
        xi_0_filt, z0i = lfilter(b, a, xi_0, zi=z0i)
        xi_90_filt, z90i = lfilter(b, a, xi_90, zi=z90i)
        phase = np.arctan2(xi_90_filt, xi_0_filt)
        # Each filtered channel holds half the amplitude, hence the factor 2.
        amplitude = 2 * np.sqrt(xi_0_filt**2 + xi_90_filt**2)

978 

979 

980class DefaultSineControlLaw: 

981 """A default control law for the sine environment""" 

982 

983 def __unpickleable_fields__(self): 

984 """Defines fields that can't be pickled in the case of an error""" 

985 return ["tracking_filters"] 

986 

987 def __getstate__(self): 

988 """Defines how the object is pickled""" 

989 state = self.__dict__.copy() 

990 for field in self.__unpickleable_fields__(): 

991 if field in state: 

992 del state[field] 

993 return state 

994 

995 def __setstate__(self, state): 

996 """Defines how the object is restored from a pickle""" 

997 self.__dict__.update(state) 

998 for field in self.__unpickleable_fields__(): 

999 setattr(self, field, None) 

1000 

1001 def __init__( 

1002 self, 

1003 sample_rate, # Sample Rate of the data acquisition 

1004 specifications, # Specification structured array 

1005 output_oversample, # Oversampling required for output 

1006 ramp_time, # Length of the ramp on the start and end of the signal 

1007 convergence_factor, # Scale factor on the convergence of the closed-loop control 

1008 block_size, # Size of writing blocks 

1009 buffer_blocks, # Number of write blocks to keep in the buffer 

1010 extra_control_parameters, # Required parameters 

1011 sysid_frequency_spacing, # Frequency Spacing 

1012 sysid_transfer_functions, # Transfer Functions 

1013 sysid_response_noise, # Noise levels and correlation 

1014 sysid_reference_noise, # from the system identification 

1015 sysid_response_cpsd, # Response levels and correlation 

1016 sysid_reference_cpsd, # from the system identification 

1017 sysid_coherence, # Coherence from the system identification 

1018 sysid_frames, # Number of frames in the FRF matrices 

1019 ): 

1020 """ 

1021 Initialize the sine control law 

1022 

1023 Parameters 

1024 ---------- 

1025 sample_rate : float 

1026 The sample rate of the data acquisition system. 

1027 specifications : list of SineSpecification 

1028 One SineSpecification object for each tone in the environment. 

1029 output_oversample : int 

1030 The oversample factor on the output. 

1031 ramp_time : float 

1032 The time to ramp up in level to start or ramp down to end. 

1033 convergence_factor : float 

1034 A value between 0 and 1 specifying how quickly the error correction 

1035 should take place. 

1036 block_size : int 

1037 The number of samples generated for each analysis and error 

1038 correction step 

1039 buffer_blocks : int 

1040 The number of blocks to keep in a buffer to ensure we don't run 

1041 out of data to generate. 

1042 extra_control_parameters : str 

1043 A string containing any extr information the control law needs 

1044 to know about. 

1045 sysid_frequency_spacing : float 

1046 The frequency spacing in the transfer function. 

1047 sysid_transfer_functions : ndarray 

1048 A 3D ndarray with shape num_freq, num_control, num_drive. 

1049 sysid_response_noise : ndarray 

1050 A 3D ndarray containing CPSDs with values populated from the 

1051 noise floor check. Shape is num_freq, num_control, num_control. 

1052 sysid_reference_noise : ndarray 

1053 A 3D ndarray containing CPSDs with the values populated from the 

1054 noise floor check. Shape is num_freq, num_drive, num_drive. 

1055 sysid_response_cpsd : ndarray 

1056 A 3D ndarray containing CPSDs with values populated from the 

1057 system identification. Shape is num_freq, num_control, num_control. 

1058 sysid_reference_cpsd : ndarray 

1059 A 3D ndarray containing CPSDs with the values populated from the 

1060 system identification. Shape is num_freq, num_drive, num_drive. 

1061 sysid_coherence : ndarray 

1062 Multiple coherence of the control channels from the system 

1063 identification 

1064 sysid_frames : int 

1065 NNumber of frames used in the system identification. 

1066 

1067 """ 

1068 # start_time = time.time() 

1069 self.block_size = block_size 

1070 self.sample_rate = sample_rate 

1071 self.buffer_blocks = buffer_blocks 

1072 self.specifications = [spec.copy() for spec in specifications] 

1073 self.output_oversample = output_oversample 

1074 self.extra_control_parameters = extra_control_parameters 

1075 self.ramp_samples = int(ramp_time * sample_rate) * output_oversample 

1076 self.convergence_factor = convergence_factor 

1077 # Loop through the different specifications and perform the initial control 

1078 ( 

1079 self.specified_response, 

1080 self.specified_tone_response, 

1081 self.specified_frequency, 

1082 self.specified_argument, 

1083 self.specified_amplitude, 

1084 self.specified_phase, # Radians 

1085 _, 

1086 _, 

1087 ) = SineSpecification.create_combined_signals( 

1088 self.specifications, 

1089 self.sample_rate * self.output_oversample, 

1090 self.ramp_samples, 

1091 ) 

1092 self.specified_phase = self.specified_phase # Radians 

1093 self.tone_slices = [] 

1094 for amp in self.specified_amplitude: 

1095 nonzero_indices = np.any(amp != 0, axis=0) 

1096 self.tone_slices.append( 

1097 slice( 

1098 np.argmax(nonzero_indices), 

1099 nonzero_indices.size - np.argmax(nonzero_indices[::-1]), 

1100 ) 

1101 ) 

1102 

1103 # System ID Parameters 

1104 self.frfs = None 

1105 self.frf_frequency_spacing = None 

1106 self.frf_frequencies = None 

1107 self.sysid_response_noise = None 

1108 self.sysid_reference_noise = None 

1109 self.sysid_response_cpsd = None 

1110 self.sysid_reference_cpsd = None 

1111 self.sysid_coherence = None 

1112 self.frames = None 

1113 self.frf_pinv = None 

1114 self.interpolated_frf_pinv = None 

1115 self.largest_correction_factors = None 

1116 

1117 # Preshaped drive parameters 

1118 self.preshaped_drive_amplitudes = None 

1119 self.preshaped_drive_phases = None # Radians 

1120 self.preshaped_drive_signals = None 

1121 

1122 # Control parameters 

1123 self.start_index = None 

1124 self.end_index = None 

1125 self.signal_slice = None 

1126 self.control_tones = None 

1127 self.control_ramp_up = None 

1128 self.control_ramp_down = None 

1129 self.target_ramp_up = None 

1130 self.target_ramp_down = None 

1131 self.control_response_signals = None 

1132 self.control_response_amplitudes = None 

1133 self.control_response_phases = None # Radians 

1134 self.control_drive_correction = None 

1135 self.control_sent_complex_excitation = None 

1136 self.control_analysis_index = None 

1137 self.control_write_index = None 

1138 self.max_singular_values = None 

1139 

1140 self.maximum_drive_voltage = None 

1141 self.harddisk_storage = None 

1142 for string in extra_control_parameters.split("\n"): 

1143 if string.strip() == "": 

1144 continue 

1145 try: 

1146 command, value = string.split("=") 

1147 except ValueError: 

1148 print(f'Unable to Parse Extra Parameters Line"{string:}"') 

1149 continue 

1150 command = command.strip() 

1151 value = value.strip() 

1152 if command in [ 

1153 "maximum_drive_voltage", 

1154 "max_drive_voltage", 

1155 "maximum_excitation_voltage", 

1156 "max_excitation_voltage", 

1157 ]: 

1158 self.maximum_drive_voltage = float(value) 

1159 print(f"Set Maximum Drive Voltage to {self.maximum_drive_voltage}") 

1160 elif command in ["harddisk_storage"]: 

1161 self.harddisk_storage = value 

1162 else: 

1163 print(f"Unknown extra parameter {command}") 

1164 

1165 if sysid_transfer_functions is not None: 

1166 self.system_id_update( 

1167 sysid_frequency_spacing, # Frequency Spacing 

1168 sysid_transfer_functions, # Transfer Functions 

1169 sysid_response_noise, # Noise levels and correlation 

1170 sysid_reference_noise, # from the system identification 

1171 sysid_response_cpsd, # Response levels and correlation 

1172 sysid_reference_cpsd, # from the system identification 

1173 sysid_coherence, # Coherence from the system identification 

1174 sysid_frames, # Number of frames in the FRF matrices 

1175 ) 

1176 # finish_time = time.time() 

1177 # print(f'__init__ called in {finish_time - start_time:0.2f}s.') 

1178 

1179 def system_id_update( 

1180 self, 

1181 sysid_frequency_spacing, # Frequency Spacing 

1182 sysid_transfer_functions, # Transfer Functions 

1183 sysid_response_noise, # Noise levels and correlation 

1184 sysid_reference_noise, # from the system identification 

1185 sysid_response_cpsd, # Response levels and correlation 

1186 sysid_reference_cpsd, # from the system identification 

1187 sysid_coherence, # Coherence from the system identification 

1188 sysid_frames, # Number of frames in the FRF matrices 

1189 ): 

1190 """ 

1191 Updates the control law after system identification has finished 

1192 

1193 Parameters 

1194 ---------- 

1195 sysid_frequency_spacing : float 

1196 The frequency spacing in the transfer function. 

1197 sysid_transfer_functions : ndarray 

1198 A 3D ndarray with shape num_freq, num_control, num_drive. 

1199 sysid_response_noise : ndarray 

1200 A 3D ndarray containing CPSDs with values populated from the 

1201 noise floor check. Shape is num_freq, num_control, num_control. 

1202 sysid_reference_noise : ndarray 

1203 A 3D ndarray containing CPSDs with the values populated from the 

1204 noise floor check. Shape is num_freq, num_drive, num_drive. 

1205 sysid_response_cpsd : ndarray 

1206 A 3D ndarray containing CPSDs with values populated from the 

1207 system identification. Shape is num_freq, num_control, num_control. 

1208 sysid_reference_cpsd : ndarray 

1209 A 3D ndarray containing CPSDs with the values populated from the 

1210 system identification. Shape is num_freq, num_drive, num_drive. 

1211 sysid_coherence : ndarray 

1212 Multiple coherence of the control channels from the system 

1213 identification 

1214 sysid_frames : int 

1215 Number of frames used in the system identification. 

1216 

1217 Returns 

1218 ------- 

1219 preshaped_drive_signals : ndarray 

1220 Excitation signals with shape num_tones, num_drives, num_timesteps. 

1221 specified_frequency : ndarray 

1222 The instantaneous frequencies at each of the drive timesteps with 

1223 shape num_tones, num_timesteps 

1224 specified_argument : ndarray 

1225 The instantaneous sine argument at each of the drive timesteps with 

1226 shape num_tones, num_drives, num_timesteps 

1227 preshaped_drive_amplitudes : ndarray 

1228 The instantaneous amplitude at each of the drive timesteps with 

1229 shape num_tones, num_drives, num_timesteps 

1230 preshaped_drive_phases : ndarray 

1231 The instantaneous phase in radians at each of the drive timesteps 

1232 with shape num_tones, num_drives, num_timesteps. 

1233 """ 

1234 # start_time = time.time() 

1235 # print('Updating System ID Information') 

1236 self.frf_frequency_spacing = sysid_frequency_spacing 

1237 self.frfs = sysid_transfer_functions 

1238 self.frf_frequencies = self.frf_frequency_spacing * np.arange(self.frfs.shape[0]) 

1239 # print('Inverting FRFs') 

1240 self.frf_pinv = np.linalg.pinv(self.frfs) 

1241 self.max_singular_values = np.max( 

1242 np.linalg.svd(self.frfs, compute_uv=False, full_matrices=False), axis=-1 

1243 ) 

1244 self.sysid_response_noise = sysid_response_noise 

1245 self.sysid_reference_noise = sysid_reference_noise 

1246 self.sysid_response_cpsd = sysid_response_cpsd 

1247 self.sysid_reference_cpsd = sysid_reference_cpsd 

1248 self.sysid_coherence = sysid_coherence 

1249 self.frames = sysid_frames 

1250 

1251 # Go through and compute the response amplitudes and phases from each of the sine tones 

1252 # print('Preallocating Amplitude, Phase, FRFs, and Correction Factors') 

1253 if self.harddisk_storage is not None: 

1254 filename = os.path.join(self.harddisk_storage, "preshaped_drive_amplitudes.mmap") 

1255 shape = ( 

1256 self.specified_frequency.shape[0], 

1257 self.frfs.shape[-1], 

1258 self.specified_frequency.shape[-1], 

1259 ) 

1260 self.preshaped_drive_amplitudes = np.memmap( 

1261 filename, dtype=float, shape=shape, mode="w+" 

1262 ) 

1263 filename = os.path.join(self.harddisk_storage, "preshaped_drive_phases.mmap") 

1264 shape = ( 

1265 self.specified_frequency.shape[0], 

1266 self.frfs.shape[-1], 

1267 self.specified_frequency.shape[-1], 

1268 ) 

1269 self.preshaped_drive_phases = np.memmap(filename, dtype=float, shape=shape, mode="w+") 

1270 filename = os.path.join(self.harddisk_storage, "largest_correction_factors.mmap") 

1271 shape = self.specified_frequency.shape 

1272 self.largest_correction_factors = np.memmap( 

1273 filename, dtype=float, shape=shape, mode="w+" 

1274 ) 

1275 filename = os.path.join(self.harddisk_storage, "interpolated_frf_pinv.mmap") 

1276 shape = ( 

1277 self.specified_frequency.shape[0], 

1278 self.specified_frequency.shape[-1], 

1279 ) + self.frf_pinv.shape[-2:] 

1280 self.interpolated_frf_pinv = np.memmap(filename, dtype="c16", shape=shape, mode="w+") 

1281 else: 

1282 self.preshaped_drive_amplitudes = np.zeros( 

1283 ( 

1284 self.specified_frequency.shape[0], 

1285 self.frfs.shape[-1], 

1286 self.specified_frequency.shape[-1], 

1287 ) 

1288 ) 

1289 self.preshaped_drive_phases = np.zeros( # Radians 

1290 ( 

1291 self.specified_frequency.shape[0], 

1292 self.frfs.shape[-1], 

1293 self.specified_frequency.shape[-1], 

1294 ) 

1295 ) 

1296 self.largest_correction_factors = np.zeros(self.specified_frequency.shape) 

1297 self.interpolated_frf_pinv = np.zeros( 

1298 ( 

1299 self.specified_frequency.shape[0], 

1300 self.specified_frequency.shape[-1], 

1301 ) 

1302 + self.frf_pinv.shape[-2:], 

1303 dtype="c16", 

1304 ) 

1305 for tone_index, (freq, amp, phs, control_slice) in enumerate( # phs is radians 

1306 zip( 

1307 self.specified_frequency, 

1308 self.specified_amplitude, 

1309 self.specified_phase, # Radians 

1310 self.tone_slices, 

1311 ) 

1312 ): 

1313 # print('Interpolating Tone {:}'.format(tone_index)) 

1314 control_amp = amp[..., control_slice] 

1315 control_phs = phs[..., control_slice] # Radians 

1316 control_freq = freq[..., control_slice] 

1317 # Interpolate the pseudoinverse of the FRF 

1318 for index in np.ndindex(*self.frf_pinv.shape[1:]): 

1319 # print(' Interpolating Response {:}'.format(index)) 

1320 interpolated_index = (tone_index, control_slice) + index 

1321 frf_pinv_index = (Ellipsis,) + index 

1322 self.interpolated_frf_pinv[interpolated_index] = np.interp( 

1323 control_freq, self.frf_frequencies, self.frf_pinv[frf_pinv_index] 

1324 ) 

1325 # print('Computing Largest Correction Factors') 

1326 self.largest_correction_factors[tone_index, control_slice] = ( 

1327 1 / np.interp(control_freq, self.frf_frequencies, self.max_singular_values) ** 2 

1328 ) 

1329 # print('Computing Complex Response') 

1330 complex_response = np.moveaxis( 

1331 control_amp * np.exp(1j * control_phs), -1, 0 

1332 )[ # Radians 

1333 ..., np.newaxis 

1334 ] 

1335 # print('Computing Complex Excitation') 

1336 complex_excitation = np.moveaxis( 

1337 (self.interpolated_frf_pinv[tone_index, control_slice] @ complex_response)[..., 0], 

1338 0, 

1339 1, 

1340 ) 

1341 # print('Extracting Excitation Amplitude and Phase') 

1342 self.preshaped_drive_amplitudes[tone_index, :, control_slice] = np.abs( 

1343 complex_excitation 

1344 ) 

1345 self.preshaped_drive_phases[tone_index, :, control_slice] = np.unwrap( 

1346 np.angle(complex_excitation) 

1347 ) # Radians 

1348 

1349 if self.maximum_drive_voltage is not None: 

1350 # print('Truncating for Maximum Voltage') 

1351 self.preshaped_drive_amplitudes[ 

1352 self.preshaped_drive_amplitudes > self.maximum_drive_voltage 

1353 ] = self.maximum_drive_voltage 

1354 self.preshaped_drive_amplitudes[ 

1355 self.preshaped_drive_amplitudes < -self.maximum_drive_voltage 

1356 ] = -self.maximum_drive_voltage 

1357 # print('Computing Excitation Signal') 

1358 self.preshaped_drive_signals = self.preshaped_drive_amplitudes * np.cos( 

1359 self.specified_argument[:, np.newaxis, :] + self.preshaped_drive_phases # Radians 

1360 ) 

1361 

1362 if self.harddisk_storage is not None: 

1363 # print('Flushing memmaps') 

1364 self.preshaped_drive_amplitudes.flush() 

1365 self.preshaped_drive_phases.flush() # Radians 

1366 self.largest_correction_factors.flush() 

1367 self.interpolated_frf_pinv.flush() 

1368 

1369 if DEBUG: 

1370 print("Writing Sine Debug Pickle") 

1371 with open("debug_data/sine_control_law_debug.pkl", "wb") as f: 

1372 pickle.dump(self, f) 

1373 print("Done!") 

1374 

1375 # finish_time = time.time() 

1376 # print(f'system_id_update called in {finish_time - start_time:0.2f}s.') 

1377 

1378 return ( 

1379 self.preshaped_drive_signals, 

1380 self.specified_frequency, 

1381 self.specified_argument, 

1382 self.preshaped_drive_amplitudes, 

1383 self.preshaped_drive_phases, # Radians for return value 

1384 ) 

1385 

1386 def get_control_targets(self, block_start, block_end): 

1387 """Gets up the control targets for a specified block, adding ramps 

1388 

1389 Parameters 

1390 ---------- 

1391 block_start : int 

1392 The starting index for the block of data that targets are computed from 

1393 block_end : int 

1394 The end index for the block of data that targets are computed from 

1395 

1396 Returns 

1397 ------- 

1398 amplitudes 

1399 The amplitudes over time including ramp up and ramp down 

1400 phases 

1401 The phases over time including ramp up and ramp down portions, in radians 

1402 arguments 

1403 The argument of the cosine wave over time including the ramp 

1404 up and ramp down portions 

1405 """ 

1406 ramp_up_start = block_start - self.ramp_samples 

1407 if ramp_up_start >= 0: 

1408 ramp_up_start = self.ramp_samples 

1409 ramp_up_end = block_end - self.ramp_samples 

1410 if ramp_up_end >= 0: 

1411 ramp_up_end = self.ramp_samples 

1412 ramp_down_start = block_start - self.end_index + self.start_index + self.ramp_samples 

1413 if ramp_down_start < 0: 

1414 ramp_down_start = 0 

1415 ramp_down_end = block_end - self.end_index + self.start_index + self.ramp_samples 

1416 if ramp_down_end < 0: 

1417 ramp_down_end = 0 

1418 middle_start = block_start 

1419 if middle_start < self.ramp_samples: 

1420 middle_start = self.ramp_samples 

1421 if middle_start > self.end_index - self.start_index - self.ramp_samples: 

1422 middle_start = self.end_index - self.start_index - self.ramp_samples 

1423 middle_end = block_end 

1424 if middle_end < self.ramp_samples: 

1425 middle_end = self.ramp_samples 

1426 if middle_end > self.end_index - self.start_index - self.ramp_samples: 

1427 middle_end = self.end_index - self.start_index - self.ramp_samples 

1428 amplitudes = np.concatenate( 

1429 ( 

1430 self.target_ramp_up[..., ramp_up_start:ramp_up_end], 

1431 self.specified_amplitude[ 

1432 self.control_tones, 

1433 ..., 

1434 self.start_index + middle_start : self.start_index + middle_end, 

1435 ], 

1436 self.target_ramp_down[..., ramp_down_start:ramp_down_end], 

1437 ), 

1438 axis=-1, 

1439 ) 

1440 phases = self.specified_phase[ # Radians 

1441 self.control_tones, 

1442 ..., 

1443 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1], 

1444 ] 

1445 arguments = self.specified_argument[ 

1446 self.control_tones, 

1447 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1], 

1448 ] 

1449 return amplitudes, phases, arguments # Radians 

1450 

1451 def get_control_preshaped_excitations(self, block_start, block_end): 

1452 """Gets the initial guess at an excitation signal over a portion of time 

1453 

1454 Parameters 

1455 ---------- 

1456 block_start : int 

1457 The starting index for the block of data that excitations are computed from 

1458 block_end : int 

1459 The end index for the block of data that excitations are computed from 

1460 

1461 Returns 

1462 ------- 

1463 amplitudes 

1464 The amplitudes over time including ramp up and ramp down 

1465 phases 

1466 The phases over time including ramp up and ramp down portions 

1467 arguments 

1468 The argument of the cosine wave over time including the ramp 

1469 up and ramp down portions, in radians 

1470 """ 

1471 ramp_up_start = block_start - self.ramp_samples 

1472 if ramp_up_start >= 0: 

1473 ramp_up_start = self.ramp_samples 

1474 ramp_up_end = block_end - self.ramp_samples 

1475 if ramp_up_end >= 0: 

1476 ramp_up_end = self.ramp_samples 

1477 ramp_down_start = block_start - self.end_index + self.start_index + self.ramp_samples 

1478 if ramp_down_start < 0: 

1479 ramp_down_start = 0 

1480 ramp_down_end = block_end - self.end_index + self.start_index + self.ramp_samples 

1481 if ramp_down_end < 0: 

1482 ramp_down_end = 0 

1483 middle_start = block_start 

1484 if middle_start < self.ramp_samples: 

1485 middle_start = self.ramp_samples 

1486 if middle_start > self.end_index - self.start_index - self.ramp_samples: 

1487 middle_start = self.end_index - self.start_index - self.ramp_samples 

1488 middle_end = block_end 

1489 if middle_end < self.ramp_samples: 

1490 middle_end = self.ramp_samples 

1491 if middle_end > self.end_index - self.start_index - self.ramp_samples: 

1492 middle_end = self.end_index - self.start_index - self.ramp_samples 

1493 amplitudes = np.concatenate( 

1494 ( 

1495 self.control_ramp_up[..., ramp_up_start:ramp_up_end], 

1496 self.preshaped_drive_amplitudes[ 

1497 self.control_tones, 

1498 ..., 

1499 self.start_index + middle_start : self.start_index + middle_end, 

1500 ], 

1501 self.control_ramp_down[..., ramp_down_start:ramp_down_end], 

1502 ), 

1503 axis=-1, 

1504 ) 

1505 phases = self.preshaped_drive_phases[ # Radians 

1506 self.control_tones, 

1507 ..., 

1508 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1], 

1509 ] 

1510 arguments = self.specified_argument[ 

1511 self.control_tones, 

1512 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1], 

1513 ] 

1514 return amplitudes, phases, arguments # Phase in Radians 

1515 

1516 def initialize_control(self, control_tones, start_index, end_index): 

1517 """ 

1518 Initializes the control and creates a preshaped drive signal 

1519 

1520 Aguments are provided to specify which tones and portion of time 

1521 to generate the signal over 

1522 

1523 Parameters 

1524 ---------- 

1525 control_tones : ndarray or slice 

1526 Indicies into the specifications to determine which control tones 

1527 should be used. 

1528 start_index : int 

1529 The starting time step index. 

1530 end_index : int 

1531 The ending time step index. 

1532 

1533 Returns 

1534 ------- 

1535 excitation_signals : ndarray 

1536 The drive signal at each shaker over time 

1537 

1538 """ 

1539 # start_time = time.time() 

1540 # Parse the frequency content to get the portion of the signal we care 

1541 # about 

1542 self.control_tones = control_tones 

1543 self.start_index = start_index 

1544 self.end_index = self.preshaped_drive_signals.shape[-1] if end_index is None else end_index 

1545 

1546 if DEBUG: 

1547 print("Writing Sine Debug Pickle") 

1548 with open("debug_data/sine_control_law_initialize_control_debug.pkl", "wb") as f: 

1549 pickle.dump(self, f) 

1550 print("Done!") 

1551 

1552 # Set up the analysis and write_indices 

1553 self.control_analysis_index = 0 

1554 self.control_write_index = self.ramp_samples + self.buffer_blocks * self.block_size 

1555 

1556 # Set up the ramp-ups and ramp downs for the excitation signal 

1557 self.control_ramp_up = ( 

1558 np.linspace(0, 1, self.ramp_samples) 

1559 * self.preshaped_drive_amplitudes[ 

1560 self.control_tones, 

1561 ..., 

1562 self.start_index + self.ramp_samples, 

1563 np.newaxis, 

1564 ] 

1565 ) 

1566 self.control_ramp_down = ( 

1567 np.linspace(1, 0, self.ramp_samples) 

1568 * self.preshaped_drive_amplitudes[ 

1569 self.control_tones, ..., self.end_index - self.ramp_samples, np.newaxis 

1570 ] 

1571 ) 

1572 self.target_ramp_up = ( 

1573 np.linspace(0, 1, self.ramp_samples) 

1574 * self.specified_amplitude[ 

1575 self.control_tones, 

1576 ..., 

1577 self.start_index + self.ramp_samples, 

1578 np.newaxis, 

1579 ] 

1580 ) 

1581 self.target_ramp_down = ( 

1582 np.linspace(1, 0, self.ramp_samples) 

1583 * self.specified_amplitude[ 

1584 self.control_tones, ..., self.end_index - self.ramp_samples, np.newaxis 

1585 ] 

1586 ) 

1587 

1588 ( 

1589 starting_drive_amplitudes, 

1590 starting_drive_phases, 

1591 starting_arguments, 

1592 ) = self.get_control_preshaped_excitations( # Radians 

1593 0, self.control_write_index 

1594 ) # Radians 

1595 

1596 complex_excitation = starting_drive_amplitudes * ( 

1597 np.exp(1j * starting_drive_phases) 

1598 ) # Radians 

1599 excitation_signals = np.sum( 

1600 starting_drive_amplitudes 

1601 * np.cos(starting_drive_phases + starting_arguments[:, np.newaxis, :]), # Radians 

1602 axis=0, 

1603 ) 

1604 

1605 # Set up control parameters 

1606 self.control_drive_correction = np.zeros(starting_drive_amplitudes.shape[:2], dtype=complex) 

1607 

1608 # Set up the amplitude and phase tracking 

1609 self.control_response_amplitudes = [] 

1610 self.control_response_phases = [] # Radians 

1611 self.control_response_signals = [] 

1612 self.control_sent_complex_excitation = [] 

1613 

1614 if DEBUG: 

1615 print("Writing Sine Debug Pickle") 

1616 with open("debug_data/sine_control_law_debug.pkl", "wb") as f: 

1617 pickle.dump(self, f) 

1618 print("Done!") 

1619 

1620 # Sending ramp and first two blocks to start, so add them to the list of blocks sent. 

1621 self.control_sent_complex_excitation.append(complex_excitation) 

1622 

1623 # finish_time = time.time() 

1624 # print(f'initialize_control called in {finish_time - start_time:0.2f}s.') 

1625 return excitation_signals 

1626 

def update_control(
    self,
    control_signals,
    control_amplitudes,
    control_phases,  # Radians
    control_frequencies,  # pylint: disable=unused-argument
    time_delay,  # pylint: disable=unused-argument
):
    """
    Record the latest measured response and refine the drive correction.

    The measured signals, amplitudes, and phases are appended to the
    tracking lists. When ``self.convergence_factor`` is nonzero, the complex
    error between the target and the achieved response over the current
    analysis block is mapped through the pseudo-inverse FRF into a drive
    correction, scaled, averaged over the block, and accumulated into
    ``self.control_drive_correction``.

    Parameters
    ----------
    control_signals : ndarray
        Time histories acquired by the environment; the last axis is time.
    control_amplitudes : ndarray
        Amplitudes extracted from the time signals with shape num_tones,
        num_channels, num_timesteps.
    control_phases : ndarray
        Phases extracted from the time signals in radians with shape
        num_tones, num_channels, num_timesteps.
    control_frequencies : ndarray
        Instantaneous frequencies at the analyzed timesteps with shape
        num_tones, num_timesteps. Not used by this method.
    time_delay : float
        Time delay computed between the acquisition and output signals.
        Not used by this method.

    Returns
    -------
    drive_correction : ndarray
        The accumulated complex correction on the drive signals with shape
        num_tones, num_drives.
    """
    # Keep a history of everything that has been measured
    self.control_response_signals.append(control_signals)
    self.control_response_amplitudes.append(control_amplitudes)
    self.control_response_phases.append(control_phases)  # Radians

    if DEBUG:
        print("Writing Sine Debug Pickle")
        with open("debug_data/sine_control_law_update_control_debug.pkl", "wb") as debug_file:
            pickle.dump(self, debug_file)
        print("Done!")

    step = self.output_oversample
    first = self.control_analysis_index
    # Acquired data are at the acquisition rate; the matching span of the
    # output signal is `step` times longer.
    last = first + control_signals.shape[-1] * step

    if self.convergence_factor != 0:
        # Output-signal samples that line up with the acquired samples
        output_slice = slice(first + self.start_index, last + self.start_index, step)

        target_amps, target_phases, _ = self.get_control_targets(first, last)  # Radians
        target = target_amps[..., ::step] * np.exp(1j * target_phases[..., ::step])
        achieved = control_amplitudes * np.exp(1j * control_phases)  # Radians
        # num_tones x num_responses x num_timesteps
        error = target - achieved

        # Most conservative allowed correction over the block, num_tones x 1
        gain = self.convergence_factor * np.min(
            self.largest_correction_factors[self.control_tones, ..., output_slice],
            axis=-1,
            keepdims=True,
        )
        # num_tones x num_timesteps x num_drives x num_responses
        frf_pinv = self.interpolated_frf_pinv[self.control_tones, output_slice, ...]
        # num_tones x num_timesteps x num_responses x 1
        error_columns = np.swapaxes(error, 1, 2)[..., np.newaxis]
        # num_tones x num_timesteps x num_drives
        drive_updates = (frf_pinv @ error_columns)[..., 0]
        # Average across the block's timesteps and accumulate
        self.control_drive_correction = self.control_drive_correction + gain * np.mean(
            drive_updates, axis=1
        )
    self.control_analysis_index = last
    return self.control_drive_correction

1717 

def generate_signal(self):
    """
    Produce the next block of the excitation while control is active.

    The preshaped drive for the next ``self.block_size`` samples is combined
    with the accumulated closed-loop correction. If
    ``self.maximum_drive_voltage`` is set, any sample whose complex amplitude
    exceeds it is scaled down to that limit with its phase kept. The tones
    are then summed into one time history per drive.

    Returns
    -------
    excitation_signal : ndarray
        The next portion of the signal to generate, with shape num_drives,
        num_timesteps.
    done_controlling : bool
        True once the end of this block reaches the length of the signal
        (``self.end_index - self.start_index``), meaning no more control
        decisions should be made.
    """
    if DEBUG:
        print("Writing Sine Debug Pickle")
        with open("debug_data/sine_control_law_generate_signal_debug.pkl", "wb") as debug_file:
            pickle.dump(self, debug_file)
        print("Done!")

    block_start = self.control_write_index
    block_stop = block_start + self.block_size
    amps, phases, args = self.get_control_preshaped_excitations(  # Radians
        block_start, block_stop
    )
    # Preshaped drive plus accumulated correction: num tones x num signals x num freqs
    drive = amps * np.exp(1j * phases) + self.control_drive_correction[..., np.newaxis]

    limit = self.maximum_drive_voltage
    if limit is not None:
        magnitude = np.abs(drive)
        too_large = magnitude > limit
        # Clip the magnitude to the limit while preserving the phase
        drive[too_large] = limit * drive[too_large] / magnitude[too_large]

    per_tone = np.abs(drive) * np.cos(args[:, np.newaxis, :] + np.angle(drive))
    # Sum all tones into a single signal per drive
    excitation_signal = per_tone.sum(axis=0)

    # Remember exactly what was sent
    self.control_sent_complex_excitation.append(drive)
    done_controlling = block_stop >= self.end_index - self.start_index
    self.control_write_index = block_stop
    return excitation_signal, done_controlling

1771 

def finalize_control(self):
    """
    Return the stored preshaped drive data at the end of control.

    This method does no computation; it hands back the attributes that
    describe the current preshaped drive.

    Returns
    -------
    preshaped_drive_signals : ndarray
        Excitation signals with shape num_tones, num_drives, num_timesteps.
    specified_frequency : ndarray
        The instantaneous frequencies at each of the drive timesteps with
        shape num_tones, num_timesteps.
    specified_argument : ndarray
        The instantaneous sine argument at each of the drive timesteps with
        shape num_tones, num_drives, num_timesteps.
    preshaped_drive_amplitudes : ndarray
        The instantaneous amplitude at each of the drive timesteps with
        shape num_tones, num_drives, num_timesteps.
    preshaped_drive_phases : ndarray
        The instantaneous phase in radians at each of the drive timesteps
        with shape num_tones, num_drives, num_timesteps.
    ramp_samples : int
        The number of ramp samples used in the signal.
    """
    attribute_names = (
        "preshaped_drive_signals",
        "specified_frequency",
        "specified_argument",
        "preshaped_drive_amplitudes",
        "preshaped_drive_phases",  # Radians
        "ramp_samples",
    )
    return tuple(getattr(self, attribute) for attribute in attribute_names)

1805 

1806 

def vold_kalman_filter(
    sample_rate,
    signal,
    arguments,
    filter_order=None,
    bandwidth=None,
    method=None,
    return_amp_phs=False,
    return_envelope=False,
    return_r=False,
):
    """
    Extract sinusoidal components from a signal using the second generation
    Vold-Kalman filter.

    Parameters
    ----------
    sample_rate : float
        The sample rate of the signal in Hz.
    signal : ndarray
        A 1D signal containing sinusoidal components that need to be extracted.
    arguments : ndarray
        A 2D array consisting of the arguments to the sinusoidal components of
        the form exp(1j*argument). This is the integral over time of the
        angular frequency, which can be approximated as
        2*np.pi*scipy.integrate.cumulative_trapezoid(frequencies,timesteps,initial=0)
        where `frequencies` is the frequency at each time step in Hz and
        `timesteps` is the vector of time steps in seconds. Each row is one
        sinusoidal component to extract, and the number of columns must equal
        the number of time steps in the `signal` argument. A 1D array is
        treated as a single component.
    filter_order : int, optional
        The order of the VK filter, which should be 1, 2, or 3. The default is
        2. The low-pass roll-off of the envelope filter is approximately
        -40 dB per decade times the filter order.
    bandwidth : float or ndarray, optional
        The prescribed bandwidth of the filter in Hz. This is related to the
        filter selectivity parameter `r` in the literature. It is broadcast to
        the same shape as the `arguments` argument, so it may be a scalar, a
        vector with one value per time step, or a full
        (num_components, num_samples) array. A 1D array with one value per
        component (when the number of components differs from the number of
        samples) is applied to every time step of that component. The default
        is the sample rate divided by 1000.
    method : str, optional
        Can be set to either 'single' or 'multi'. In a 'single' solution, each
        sinusoidal component will be solved independently without any coupling.
        This can be more efficient, but will result in errors if the
        frequencies of the sine waves cross. The 'multi' solution will solve
        for all sinusoidal components simultaneously, resulting in a better
        estimate of crossing frequencies. The default is 'multi' when more than
        one component is requested and 'single' otherwise (the two are
        equivalent for a single component).
    return_amp_phs : bool
        Returns the amplitude and phase of the reconstructed signals at each
        time step. Default is False.
    return_envelope : bool
        Returns the complex envelope and phasors at each time step. Default is
        False.
    return_r : bool
        Returns the computed selectivity parameters for the filter. Default is
        False.

    Raises
    ------
    ValueError
        If `signal` is not 1D, if `arguments` does not have one column per
        sample in `signal`, if `bandwidth` cannot be broadcast to the shape of
        `arguments`, or if `method` or `filter_order` is not a valid value.

    Returns
    -------
    reconstructed_signals : ndarray
        A time history the same length as `signal` for each of the sinusoidal
        components solved for, with shape num_components, num_samples.
    reconstructed_amplitudes : ndarray
        The amplitude over time for each of the sinusoidal components solved
        for. Only returned if return_amp_phs is True.
    reconstructed_phases : ndarray
        The phase in radians over time for each of the sinusoidal components
        solved for. Only returned if return_amp_phs is True.
    reconstructed_envelope : ndarray
        The complex envelope `x` over time for each of the sinusoidal
        components solved for. Only returned if return_envelope is True.
    reconstructed_phasor : ndarray
        The phasor `c` over time for each of the sinusoidal components solved
        for. Only returned if return_envelope is True.
    r : ndarray
        The selectivity `r` over time for each of the sinusoidal components
        solved for. Only returned if return_r is True.

    If only `reconstructed_signals` is requested it is returned directly;
    otherwise a list is returned in the order listed above.
    """
    # pylint: disable=invalid-name
    if filter_order is None:
        filter_order = 2

    if bandwidth is None:
        bandwidth = sample_rate / 1000

    # Make sure input data are numpy arrays
    signal = np.array(signal)
    if signal.ndim != 1:
        raise ValueError("`signal` must be a 1D array")
    arguments = np.atleast_2d(arguments)

    # Extract some sizes to make sure everything is correctly sized
    n_samples = signal.shape[-1]

    n_orders_arg, n_arg = arguments.shape
    if n_arg != n_samples:
        raise ValueError(
            "Argument array must have identical number of columns as samples in signal"
        )

    bandwidth = np.asarray(bandwidth)
    if bandwidth.ndim == 1 and bandwidth.size == n_orders_arg and n_orders_arg != n_samples:
        # One bandwidth per component: hold it constant across every time step.
        # (When the sizes are equal, keep the historical per-time-step meaning.)
        bandwidth = bandwidth[:, np.newaxis]
    bandwidth = np.broadcast_to(np.atleast_2d(bandwidth), arguments.shape)
    relative_bandwidth = bandwidth / sample_rate

    if method is None:
        method = "multi" if n_orders_arg > 1 else "single"
    if method.lower() not in ["multi", "single"]:
        raise ValueError('`method` must be either "multi" or "single"')

    # Construct phasors to multiply the signals by
    phasor = np.exp(1j * arguments)

    # Finite-difference coefficients of the structural equation and the
    # selectivity r that puts the -3 dB point at the requested bandwidth
    if filter_order == 1:
        coefs = np.array([1, -1])
        r = np.sqrt((np.sqrt(2) - 1) / (2 * (1 - np.cos(np.pi * relative_bandwidth))))
    elif filter_order == 2:
        coefs = np.array([1, -2, 1])
        r = np.sqrt(
            (np.sqrt(2) - 1)
            / (
                6
                - 8 * np.cos(np.pi * relative_bandwidth)
                + 2 * np.cos(2 * np.pi * relative_bandwidth)
            )
        )
    elif filter_order == 3:
        coefs = np.array([1, -3, 3, -1])
        r = np.sqrt(
            (np.sqrt(2) - 1)
            / (
                20
                - 30 * np.cos(np.pi * relative_bandwidth)
                + 12 * np.cos(2 * np.pi * relative_bandwidth)
                - 2 * np.cos(3 * np.pi * relative_bandwidth)
            )
        )
    else:
        raise ValueError("filter order must be 1, 2, or 3")

    # Construct the solution matrices; A is the (n - p) x n difference operator
    A = sparse.spdiags(
        np.tile(coefs, (n_samples, 1)).T,
        np.arange(filter_order + 1),
        n_samples - filter_order,
        n_samples,
    )
    B = []
    for rvec in r:
        R = sparse.spdiags(rvec, 0, n_samples, n_samples)
        AR = A @ R
        B.append(AR.T @ AR + sparse.eye(n_samples))

    if method.lower() == "multi":
        # Solve all components together: B matrices on the block diagonal and
        # C_i^H C_j coupling blocks off the diagonal. Build the upper triangle
        # and add its conjugate transpose for the lower triangle.
        B_multi_diagonal = sparse.block_diag(B)
        num_off_diags = (n_orders_arg**2 - n_orders_arg) // 2
        row_indices = np.zeros((n_samples, num_off_diags), dtype=int)
        col_indices = np.zeros((n_samples, num_off_diags), dtype=int)
        CHC = np.zeros((n_samples, num_off_diags), dtype="c16")
        off_diagonal_index = 0
        for row_index in range(n_orders_arg):
            # Stay on the upper triangle: columns start after the diagonal
            for col_index in range(row_index + 1, n_orders_arg):
                row_indices[:, off_diagonal_index] = np.arange(
                    row_index * n_samples, (row_index + 1) * n_samples
                )
                col_indices[:, off_diagonal_index] = np.arange(
                    col_index * n_samples, (col_index + 1) * n_samples
                )
                CHC[:, off_diagonal_index] = phasor[row_index].conj() * phasor[col_index]
                off_diagonal_index += 1
        # CSR so the matrix arithmetic below is efficient
        B_multi_utri = sparse.csr_matrix(
            (CHC.flatten(), (row_indices.flatten(), col_indices.flatten())),
            shape=B_multi_diagonal.shape,
        )

        # Conjugate transpose of the upper triangle fills the lower triangle.
        # `.conj().T` works for both sparse matrices and sparse arrays.
        B_multi = B_multi_diagonal + B_multi_utri + B_multi_utri.conj().T

        # Right hand side: phasor^H applied elementwise to the signal
        RHS = phasor.flatten().conj() * np.tile(signal, n_orders_arg)

        x_multi = linalg.spsolve(B_multi, RHS[:, np.newaxis])
        # Multiply by 2 to account for missing negative frequency components
        x = 2 * x_multi.reshape(n_orders_arg, -1)
    else:
        # Solve each component independently of the others
        x = np.zeros((n_orders_arg, n_samples), dtype=np.complex128)
        for i, (phasor_i, B_i) in enumerate(zip(phasor, B)):
            RHS = phasor_i.conj() * signal
            x[i] = 2 * linalg.spsolve(B_i, RHS)

    return_value = [np.real(x * phasor)]
    if return_amp_phs:
        return_value.extend([np.abs(x), np.angle(x)])
    if return_envelope:
        return_value.extend([x, phasor])
    if return_r:
        return_value.extend([r])
    if len(return_value) == 1:
        return return_value[0]
    return return_value

2040 

2041 

def vold_kalman_filter_generator(
    sample_rate,
    num_orders,
    block_size,
    overlap=None,
    filter_order=None,
    bandwidth=None,
    method=None,
    buffer_size_factor=3,
):
    """
    Extracts sinusoidal information using a Vold-Kalman Filter

    This uses a windowed overlap-and-add process to solve for the signal while
    removing start and end effects of the filter. Each time data is sent to
    the generator, it yields a further section of the results up until the
    overlap section.

    Parameters
    ----------
    sample_rate : float
        The sample rate of the signal in Hz.
    num_orders : int
        The number of orders that will be found in the signal.
    block_size : int
        The size of the blocks used in the analysis.
    overlap : float, optional
        Fraction of the block size to overlap when computing the results. If
        not specified, it will default to 0.15. A value of 0 disables the
        overlap-and-add tapering.
    filter_order : int, optional
        The order of the VK filter, which should be 1, 2, or 3. The default is
        2.
    bandwidth : ndarray, optional
        The prescribed bandwidth of the filter, passed through to
        `vold_kalman_filter` for each block. The default is the sample rate
        divided by 1000.
    method : str, optional
        Can be set to either 'single' or 'multi', passed through to
        `vold_kalman_filter`. 'multi' solves all components simultaneously,
        which better handles crossing frequencies.
    buffer_size_factor : int, optional
        Specifies the size of the buffer. buffer_size_factor * block_size is
        the size of the buffer.

    Raises
    ------
    ValueError
        If arguments are not the correct size or values.
    ValueError
        If data is sent after a send with last_signal = True.

    Sends
    -----
    A tuple ``(xi, argsi, last_signal)``:

    xi : ndarray
        The next block of the signal to be filtered. This should be a 1D
        signal containing sinusoidal components that need to be extracted.
    argsi : ndarray
        A 2D array of the arguments to the sinusoidal components of the form
        exp(1j*argsi) for the samples in `xi`, one row per order. This is the
        integral over time of the angular frequency, which can be approximated
        as 2*np.pi*scipy.integrate.cumulative_trapezoid(frequencies,timesteps,initial=0).
    last_signal : bool
        If True, the remainder of the data will be returned and the
        overlap-and-add process will be finished.

    Yields
    -------
    reconstructed_signals : ndarray or None
        The reconstructed time history for each of the sinusoidal components.
    reconstructed_amplitudes : ndarray or None
        The amplitude over time for each of the sinusoidal components.
    reconstructed_phases : ndarray or None
        The phase in radians over time for each of the sinusoidal components.

    All three are None until enough data has been sent to fill a block.
    """
    if overlap is None:
        overlap = 0.15
    previous_envelope = None
    reconstructed_signals = None
    reconstructed_amplitudes = None
    reconstructed_phases = None
    overlap_samples = int(overlap * block_size)
    # Periodic Hann split into rising and falling halves for the tapers
    window = windows.hann(overlap_samples * 2, False)
    start_window = window[:overlap_samples]
    end_window = window[overlap_samples:]
    # Row 0 holds the signal, rows 1..num_orders hold the arguments
    buffer = CircularBufferWithOverlap(
        buffer_size_factor * block_size,
        block_size,
        overlap_samples,
        data_shape=(num_orders + 1,),
    )
    first_output = True
    last_signal = False
    while True:
        xi, argsi, check_last_signal = (
            yield reconstructed_signals,
            reconstructed_amplitudes,
            reconstructed_phases,
        )
        if last_signal:
            # The final block has already been processed; no more data allowed
            raise ValueError("Generator has been exhausted.")
        last_signal = check_last_signal
        argsi = np.atleast_2d(argsi)
        buffer_data = np.concatenate([xi[np.newaxis], argsi])
        buffer_output = buffer.write_get_data(buffer_data, last_signal)
        if buffer_output is None:
            # Not enough data yet for a full block
            reconstructed_signals = None
            reconstructed_amplitudes = None
            reconstructed_phases = None
            continue

        if first_output:
            # The leading overlap of the first read precedes any written data
            buffer_output = buffer_output[..., overlap_samples:]
            first_output = False
            signal = buffer_output[0]
            arguments = buffer_output[1:]
        else:
            signal = buffer_output[0]
            arguments = buffer_output[1:]
            signal[:overlap_samples] = signal[:overlap_samples] * start_window
        # Explicit end index avoids the `[-0:]` pitfall when overlap is zero
        n_keep = signal.shape[-1] - overlap_samples
        if not last_signal and overlap_samples > 0:
            signal[n_keep:] = signal[n_keep:] * end_window

        # Do the VK Filtering
        _, vk_envelope, vk_phasor = vold_kalman_filter(
            sample_rate,
            signal,
            arguments,
            filter_order,
            bandwidth,
            method,
            return_envelope=True,
        )
        # Add the tail of the previous block into the overlapping region
        if previous_envelope is not None and overlap_samples > 0:
            vk_envelope[..., :overlap_samples] = (
                vk_envelope[..., :overlap_samples] + previous_envelope[..., -overlap_samples:]
            )
        if last_signal:
            output_envelope = vk_envelope
            output_phasor = vk_phasor
        else:
            # Hold back the trailing overlap; the next block completes it
            n_out = vk_envelope.shape[-1] - overlap_samples
            output_envelope = vk_envelope[..., :n_out]
            output_phasor = vk_phasor[..., :n_out]
        reconstructed_signals = np.real(output_envelope * output_phasor)
        reconstructed_amplitudes = np.abs(output_envelope)
        reconstructed_phases = np.angle(output_envelope)
        previous_envelope = vk_envelope

2206 

2207 

class CircularBufferWithOverlap:
    """
    A circular buffer that allows data to be added and removed from the
    buffer with overlap.

    Data are written along the last axis in chunks of any length. Each read
    returns up to ``block_size`` new samples preceded by the ``overlap_size``
    samples immediately before them, so consecutive reads share
    ``overlap_size`` samples. A per-sample flag tracks which samples have been
    read, so unread data are not overwritten and new data are not returned
    twice.
    """

    def __init__(self, buffer_size, block_size, overlap_size, dtype="float", data_shape=()):
        """Initialize the circular buffer

        Parameters
        ----------
        buffer_size : int
            The total size of the circular buffer
        block_size : int
            The number of samples written or read in each block
        overlap_size : int
            The number of samples from the previous block to include in the read.
        dtype : dtype, optional
            The type of data in the buffer. The default is "float".
        data_shape : tuple, optional
            The shape of data in the buffer. The default is ().
        """
        self.buffer_size = buffer_size
        self.block_size = block_size
        self.overlap_size = overlap_size
        # Initialize buffer with zeros; time is the last axis
        self.buffer = np.zeros(tuple(data_shape) + (buffer_size,), dtype=dtype)
        # True where a sample has been read (or never written) and may be overwritten
        self.buffer_read = np.ones((buffer_size,), dtype=bool)
        self.write_index = 0  # Index where the next block will be written
        self.read_index = 0  # Index where the next block will be read from
        self.debug = False
        if self.debug:
            self.report_buffer_state()

    def report_buffer_state(self):
        """Prints the current buffer state"""
        read_samples = np.sum(self.buffer_read)
        write_samples = self.buffer_size - read_samples
        print(f"{read_samples} of {self.buffer_size} have been read")
        print(f"{write_samples} of {self.buffer_size} have been written but not read")

    def write_get_data(self, data, read_remaining=False):
        """
        Writes a block of data and then returns a block if available

        Parameters
        ----------
        data : ndarray
            Array to write to the buffer; the last axis is time.
        read_remaining : bool, optional
            If True, read everything that has not yet been read instead of a
            single block. The default is False.

        Returns
        -------
        return_data : ndarray or None
            The data read from the buffer, or None if a full block is not yet
            available.

        Raises
        ------
        ValueError
            If writing would overwrite data that has not yet been read.
        """
        self.write(data)
        try:
            return self.read(read_remaining)
        except ValueError:
            return None

    def write(self, data):
        """
        Write a block of data to the circular buffer.

        Parameters
        ----------
        data : ndarray
            Array to write to the buffer; the last axis is time.

        Raises
        ------
        ValueError
            If the write (plus the overlap region after it) would overwrite
            data that has not yet been read.
        """
        n_new = data.shape[-1]
        # Positions to be written plus the following overlap region, which must
        # also be free so a later read's overlap is not clobbered
        indices = (
            np.arange(self.write_index, self.write_index + n_new + self.overlap_size)
            % self.buffer_size
        )

        if np.any(~self.buffer_read[indices]):
            raise ValueError(
                "Overwriting data on buffer that has not been read. "
                "Read data before writing again."
            )

        target = indices[:n_new]
        self.buffer[..., target] = data
        self.buffer_read[target] = False

        # Update the write index
        self.write_index = (self.write_index + n_new) % self.buffer_size

        if self.debug:
            print("Wrote Data to Buffer")
            self.report_buffer_state()

    def read(self, read_remaining=False):
        """
        Reads data from the buffer

        Parameters
        ----------
        read_remaining : bool, optional
            If true, read everything left on the buffer that hasn't yet been
            read, even if that is more or less than one block. The default is
            False.

        Raises
        ------
        ValueError
            If there is not a block of data on the buffer and data would be
            read a second time.

        Returns
        -------
        return_data : ndarray
            The overlap samples followed by the newly read samples.
        """
        if read_remaining:
            # Unread data run contiguously from read_index, so take all of them
            n_new = int(np.count_nonzero(~self.buffer_read))
        else:
            n_new = self.block_size
        indices = (
            np.arange(self.read_index - self.overlap_size, self.read_index + n_new)
            % self.buffer_size
        )
        if np.any(self.buffer_read[indices[self.overlap_size :]]):
            raise ValueError("Data would be read multiple times. Write data before reading again.")
        return_data = self.buffer[..., indices]
        self.buffer_read[indices[self.overlap_size :]] = True
        self.read_index = (
            self.read_index + (return_data.shape[-1] - self.overlap_size)
        ) % self.buffer_size
        if self.debug:
            print("Read Data from Buffer")
            self.report_buffer_state()
        return return_data

2344 

2345 

2346class SineSpecification: 

2347 """A class representing a sine specification""" 

2348 

2349 def __init__( 

2350 self, 

2351 name, 

2352 start_time, 

2353 num_control, 

2354 num_breakpoints=None, 

2355 frequency_breakpoints=None, 

2356 amplitude_breakpoints=None, 

2357 phase_breakpoints=None, 

2358 sweep_type_breakpoints=None, 

2359 sweep_rate_breakpoints=None, 

2360 warning_breakpoints=None, 

2361 abort_breakpoints=None, 

2362 ): 

2363 """ 

2364 Initializes the sine specification 

2365 

2366 Parameters 

2367 ---------- 

2368 name : str 

2369 Name of the sine tone. 

2370 start_time : float 

2371 The starting time of the sine tone. 

2372 num_control : int 

2373 The number of control channels in the specification. 

2374 num_breakpoints : int, optional 

2375 The number of frequency breakpoints in the specification. Either 

2376 this or frequency_breakpoints must be specified. 

2377 frequency_breakpoints : ndarray, optional 

2378 The frequency breakpoints in the specification. Either this or 

2379 num_breakpoints must be specified. 

2380 amplitude_breakpoints : ndarray, optional 

2381 The amplitude breakpoints of the specification, with shape num_freq, 

2382 num_channels. If not specified, amplitudes will be 0. 

2383 phase_breakpoints : ndarray, optional 

2384 The phase breakpoints of the specification, with shape num_freq, 

2385 num_channels. If not specified, phase will be 0. Phases should be 

2386 in radians. 

2387 sweep_type_breakpoints : ndarray, optional 

2388 Should be a 0 if linear sweep or 1 if logarithmic sweep. Linear if 

2389 not specified. 

2390 sweep_rate_breakpoints : ndarray, optional 

2391 Sweep rate at each breakpoint. Hz/s if linear and oct/min if 

2392 logarithmic sweep 

2393 warning_breakpoints : ndarray, optional 

2394 A 4D array of warning amplitudes with shape num_freq, 2, 2, 

2395 num_channels. The second dimension uses the 0 index for the lower 

2396 warning limit and the 1 index for the upper warning limit. The 

2397 third dimension uses the 0 index for the "left" or "lower frequency" 

2398 side of the breakpoint and 1 for the "right" or "higher frequency" 

2399 side of the breakpoing. If not specified, no warnings will be 

2400 specified. 

2401 abort_breakpoints : ndarray, optional 

2402 A 4D array of abort amplitudes with shape num_freq, 2, 2, 

2403 num_channels. The second dimension uses the 0 index for the lower 

2404 warning limit and the 1 index for the upper warning limit. The 

2405 third dimension uses the 0 index for the "left" or "lower frequency" 

2406 side of the breakpoint and 1 for the "right" or "higher frequency" 

2407 side of the breakpoing. If not specified, no aborts will be 

2408 specified. 

2409 

2410 Raises 

2411 ------ 

2412 ValueError 

2413 if not one of frequency_breakpoints or num_breakpoints is specified 

2414 """ 

2415 spec_dtype = [ 

2416 ("frequency", "f8"), 

2417 ("amplitude", "f8", (num_control,)), 

2418 ("phase", "f8", (num_control,)), 

2419 ("sweep_type", "u1"), 

2420 ("sweep_rate", "f8"), 

2421 ("warning", "f8", (2, 2, num_control)), 

2422 ("abort", "f8", (2, 2, num_control)), 

2423 ] 

2424 if frequency_breakpoints is None and num_breakpoints is None: 

2425 raise ValueError("Must specify either number of breakpoints or breakpoint frequencies.") 

2426 if frequency_breakpoints is None: 

2427 self.breakpoint_table = np.zeros(num_breakpoints, dtype=spec_dtype) 

2428 else: 

2429 self.breakpoint_table = np.zeros(frequency_breakpoints.shape[0], dtype=spec_dtype) 

2430 self.breakpoint_table["frequency"] = frequency_breakpoints 

2431 if amplitude_breakpoints is not None: 

2432 self.breakpoint_table["amplitude"] = amplitude_breakpoints 

2433 if phase_breakpoints is not None: 

2434 self.breakpoint_table["phase"] = phase_breakpoints # Radians 

2435 if sweep_type_breakpoints is not None: 

2436 self.breakpoint_table["sweep_type"][:-1] = sweep_type_breakpoints 

2437 if sweep_rate_breakpoints is not None: 

2438 self.breakpoint_table["sweep_rate"][:-1] = sweep_rate_breakpoints 

2439 if warning_breakpoints is not None: 

2440 self.breakpoint_table["warning"] = warning_breakpoints 

2441 else: 

2442 self.breakpoint_table["warning"] = np.nan 

2443 if abort_breakpoints is not None: 

2444 self.breakpoint_table["abort"] = abort_breakpoints 

2445 else: 

2446 self.breakpoint_table["abort"] = np.nan 

2447 self.start_time = start_time 

2448 self.name = name 

2449 

def copy(self):
    """
    Returns a new SineSpecification built from copies of this one's data.

    The name and start time are passed through. Every breakpoint array is
    copied, so editing the returned specification's arrays does not alter
    this object.
    """
    table = self.breakpoint_table
    num_control = table["amplitude"].shape[-1]
    # The constructor writes sweep type and sweep rate into [:-1], so only
    # those entries are forwarded.
    return SineSpecification(
        self.name,
        self.start_time,
        num_control,
        frequency_breakpoints=table["frequency"].copy(),
        amplitude_breakpoints=table["amplitude"].copy(),
        phase_breakpoints=table["phase"].copy(),
        sweep_type_breakpoints=table["sweep_type"][:-1].copy(),
        sweep_rate_breakpoints=table["sweep_rate"][:-1].copy(),
        warning_breakpoints=table["warning"].copy(),
        abort_breakpoints=table["abort"].copy(),
    )

2464 

def create_signal(
    self,
    sample_rate,
    ramp_samples=0,
    control_index=None,
    ignore_start_time=False,
    only_breakpoints=False,
):
    """
    Creates a signal from the sine specification

    Parameters
    ----------
    sample_rate : float
        The sample rate of the signal generated
    ramp_samples : int, optional
        The number of samples to add to the start and end due to ramp up
        or ramp down. The default is 0.
    control_index : int, optional
        The channel index to generate a signal for. If not specified,
        all channels will be generated (one row per control channel).
    ignore_start_time : bool, optional
        If True, ignore the start time and have the sine sweep start
        immediately in the signal. The default is False.
    only_breakpoints : bool, optional
        If True, only generate values at the breakpoints. The default is
        False.

    Returns
    -------
    ordinate : ndarray
        The generated signals
    frequency : ndarray
        The instantaneous frequency at each time step
    argument : ndarray
        The instantaneous argument at each time step.
    amplitude : ndarray
        The instantaneous amplitude at each timestep.
    phase : ndarray
        The instantaneous phase at each timestep in radians.
    abscissa : ndarray
        The time at each time step. Samples prepended for the start delay
        and ramp up have negative times; samples appended for the ramp down
        continue on from the last time of the sweep.
    start_index : int
        The sample at which the specification starts taking into account
        the start time and the ramp samples.
    end_index : int
        The sample at which the specification ends taking into account the
        ramp samples.

    """
    table = self.breakpoint_table

    # Convert octave per min to octave per second for logarithmic segments
    sweep_rates = table["sweep_rate"].copy()
    log_segments = table["sweep_type"] == 1
    sweep_rates[log_segments] = sweep_rates[log_segments] / 60
    # The last entry is excluded, matching how the constructor stores sweep
    # types in [:-1].
    sweep_types = ["lin" if sweep_type == 0 else "log" for sweep_type in table["sweep_type"][:-1]]

    def _sweep(channel):
        # Generates the sweep for a single control channel.
        return sine_sweep(
            1 / sample_rate,
            table["frequency"],
            sweep_rates,
            sweep_types,
            table["amplitude"][:, channel],
            table["phase"][:, channel],
            return_frequency=True,
            return_argument=True,
            return_amplitude=True,
            return_phase=True,
            return_abscissa=True,
            only_breakpoints=only_breakpoints,
        )

    if control_index is None:
        ordinate = []
        amplitude = []
        phase = []
        for channel in range(table["amplitude"].shape[-1]):
            (
                this_ordinate,
                argument,
                frequency,
                this_amplitude,
                this_phase,
                abscissa,
            ) = _sweep(channel)
            ordinate.append(this_ordinate)
            amplitude.append(this_amplitude)
            phase.append(this_phase)
        # argument/frequency/abscissa are kept from the last channel; they are
        # presumably identical across channels since the frequency inputs are
        # shared (TODO confirm against sine_sweep).
        ordinate = np.array(ordinate)
        amplitude = np.array(amplitude)
        phase = np.array(phase)
    else:
        ordinate, argument, frequency, amplitude, phase, abscissa = _sweep(control_index)

    delay_samples = 0 if ignore_start_time else int(sample_rate * self.start_time)
    start_index = ramp_samples + delay_samples
    if ramp_samples > 0 or delay_samples > 0:
        # Create the pre-signal delay (zero amplitude) followed by the ramp up
        begin_abscissa = np.arange(-ramp_samples - delay_samples, 0) / sample_rate
        begin_arguments = 2 * np.pi * frequency[0] * begin_abscissa + argument[0]
        begin_frequencies = np.ones(ramp_samples + delay_samples) * frequency[0]
        begin_amplitudes = (
            np.concatenate((np.zeros(delay_samples), np.linspace(0, 1, ramp_samples)))
            * amplitude[..., [0]]
        )
        begin_phases = np.ones(ramp_samples + delay_samples) * phase[..., [0]]
        begin_signal = begin_amplitudes * np.cos(begin_arguments + begin_phases)

        abscissa = np.concatenate((begin_abscissa, abscissa), axis=-1)
        ordinate = np.concatenate((begin_signal, ordinate), axis=-1)
        frequency = np.concatenate((begin_frequencies, frequency), axis=-1)
        amplitude = np.concatenate((begin_amplitudes, amplitude), axis=-1)
        phase = np.concatenate((begin_phases, phase), axis=-1)
        argument = np.concatenate((begin_arguments, argument), axis=-1)
    if ramp_samples > 0:
        # Time offsets of the ramp-down samples past the end of the sweep
        end_offsets = np.arange(1, ramp_samples + 1) / sample_rate
        end_arguments = 2 * np.pi * frequency[-1] * end_offsets + argument[-1]
        end_frequencies = np.ones(ramp_samples) * frequency[-1]
        end_amplitudes = np.linspace(1, 0, ramp_samples) * amplitude[..., [-1]]
        end_phases = np.ones(ramp_samples) * phase[..., [-1]]
        end_signal = end_amplitudes * np.cos(end_arguments + end_phases)
        # Continue the time axis from the last sweep time rather than
        # restarting near zero.
        end_abscissa = abscissa[-1] + end_offsets

        abscissa = np.concatenate((abscissa, end_abscissa), axis=-1)
        ordinate = np.concatenate((ordinate, end_signal), axis=-1)
        frequency = np.concatenate((frequency, end_frequencies), axis=-1)
        argument = np.concatenate((argument, end_arguments), axis=-1)
        amplitude = np.concatenate((amplitude, end_amplitudes), axis=-1)
        phase = np.concatenate((phase, end_phases), axis=-1)
    end_index = abscissa.shape[-1] - ramp_samples

    return (
        ordinate,
        frequency,
        argument,
        amplitude,
        phase,
        abscissa,
        start_index,
        end_index,
    )

2622 

def interpolate_warning(self, channel_index, frequencies):
    """
    Linearly interpolates one channel's warning limits at given frequencies.

    Each breakpoint stores a limit on its left (lower frequency) side and
    its right (higher frequency) side. The breakpoint frequencies are
    therefore repeated twice and paired with the flattened (left, right)
    limits before interpolating.

    Parameters
    ----------
    channel_index : int
        Index of the control channel whose warning limits are used.
    frequencies : ndarray
        Frequencies at which the warning limits are evaluated.

    Returns
    -------
    warning_levels : ndarray
        A 2 x num_frequencies array. Row 0 holds the lower warning limit and
        row 1 holds the upper warning limit. Limits that were not specified
        are NaN.

    """
    limits = self.breakpoint_table["warning"]
    doubled_frequencies = np.repeat(self.breakpoint_table["frequency"], 2)
    rows = []
    for bound in (0, 1):  # 0 = lower limit, 1 = upper limit
        side_values = limits[:, bound, :, channel_index].flatten()
        rows.append(np.interp(frequencies, doubled_frequencies, side_values))
    return np.array(rows)

2652 

def interpolate_abort(self, channel_index, frequencies):
    """
    Linearly interpolates one channel's abort limits at given frequencies.

    Each breakpoint stores a limit on its left (lower frequency) side and
    its right (higher frequency) side. The breakpoint frequencies are
    therefore repeated twice and paired with the flattened (left, right)
    limits before interpolating.

    Parameters
    ----------
    channel_index : int
        Index of the control channel whose abort limits are used.
    frequencies : ndarray
        Frequencies at which the abort limits are evaluated.

    Returns
    -------
    abort_levels : ndarray
        A 2 x num_frequencies array. Row 0 holds the lower abort limit and
        row 1 holds the upper abort limit. Limits that were not specified
        are NaN.

    """
    limits = self.breakpoint_table["abort"]
    doubled_frequencies = np.repeat(self.breakpoint_table["frequency"], 2)
    rows = []
    for bound in (0, 1):  # 0 = lower limit, 1 = upper limit
        side_values = limits[:, bound, :, channel_index].flatten()
        rows.append(np.interp(frequencies, doubled_frequencies, side_values))
    return np.array(rows)

2682 

2683 @staticmethod 

2684 def structured_array_equal(arr1, arr2): 

2685 """ 

2686 A method to check if two sine specification breakpoint tables are equal 

2687 

2688 Parameters 

2689 ---------- 

2690 arr1 : ndarray 

2691 A structured array representing a breakpoint table. 

2692 arr2 : ndarray 

2693 A structured array representing a breakpoint table. 

2694 

2695 Returns 

2696 ------- 

2697 bool 

2698 True if the two arrays are equal. False otherwise. 

2699 

2700 """ 

2701 if arr1.dtype != arr2.dtype: 

2702 # print('DTypes Not Equal') 

2703 return False 

2704 for field in arr1.dtype.names: 

2705 field1 = arr1[field] 

2706 field2 = arr2[field] 

2707 if not np.array_equal(field1, field2, equal_nan=True): 

2708 # print(f'Field {field} Not Equal') 

2709 # print(field1) 

2710 # print(field2) 

2711 return False 

2712 return True 

2713 

def __eq__(self, other):
    """
    A method to check if two sine specifications are equal.

    Two specifications are equal when their breakpoint tables match (NaN
    equal to NaN) and their start times are equal. The name is not
    compared.

    Parameters
    ----------
    other : SineSpecification
        The SineSpecification object to compare against.

    Returns
    -------
    bool
        True if the two SineSpecification objects are equal. If other is
        not a SineSpecification, NotImplemented is returned so Python
        falls back to its default comparison (False for ==) instead of
        raising AttributeError.

    """
    if not isinstance(other, SineSpecification):
        return NotImplemented
    return bool(
        SineSpecification.structured_array_equal(self.breakpoint_table, other.breakpoint_table)
        and self.start_time == other.start_time
    )

2736 

2737 @staticmethod 

2738 def create_combined_signals(specifications, sample_rate, ramp_samples, control_index=None): 

2739 """ 

2740 Creates a combined signal from many specifications 

2741 

2742 Parameters 

2743 ---------- 

2744 specifications : list of SineSpecification objects 

2745 The various specification objects to combine together to creat the 

2746 combined signals. 

2747 sample_rate : float 

2748 The sample rate at which the signals should be generated. 

2749 ramp_samples : int 

2750 The number of samples to use in the ramp portion of the signals. 

2751 control_index : int, optional 

2752 The control channel index at which the signals should be generated. 

2753 The default is to generate all channels. 

2754 

2755 Returns 

2756 ------- 

2757 signal : ndarray 

2758 A number of channels by number of timesteps array of time history 

2759 values. 

2760 order_signals : ndarray 

2761 Separate signals for each of the specification tones, in a 

2762 num_tones, num_channels, num_timesteps array. 

2763 order_frequencies : ndarray 

2764 The frequencies associated with each sine tone at each time step 

2765 in a num_tones, num_timesteps array. 

2766 order_arguments : ndarray 

2767 The arguments associated with each sine tone at each time step in a 

2768 num_tones, num_timestep array 

2769 order_amplitudes : ndarray 

2770 The instantaneous amplitudes associated with each sine tone and 

2771 channel at each time step in a num_tones, num_channels, 

2772 num_timesteps shaped array 

2773 order_phases : ndarray 

2774 The instantaneous phase associated with each sine tone and 

2775 channel at each time step in a num_tones, num_channels, 

2776 num_timesteps shaped array. Phases in Radians. 

2777 order_start_samples : ndarray 

2778 The starting sample for each tone taking into account the start time 

2779 and the ramp samples 

2780 order_end_samples : ndarray 

2781 The end sample for each tone taking into account the ramp samples. 

2782 

2783 """ 

2784 order_signals = [] 

2785 order_arguments = [] 

2786 order_frequencies = [] 

2787 order_amplitudes = [] 

2788 order_phases = [] 

2789 order_start_samples = [] 

2790 order_end_samples = [] 

2791 longest_signal = 0 

2792 for spec in specifications: 

2793 ( 

2794 ordinate, 

2795 frequency, 

2796 argument, 

2797 amplitude, 

2798 phase, 

2799 _, 

2800 start_index, 

2801 end_index, 

2802 ) = spec.create_signal(sample_rate, ramp_samples, control_index) 

2803 order_signals.append(ordinate) 

2804 order_frequencies.append(frequency) 

2805 order_amplitudes.append(amplitude) 

2806 order_phases.append(phase) 

2807 order_arguments.append(argument) 

2808 order_start_samples.append(start_index) 

2809 order_end_samples.append(end_index) 

2810 

2811 if order_signals[-1].shape[-1] > longest_signal: 

2812 longest_signal = order_signals[-1].shape[-1] 

2813 

2814 # Now that we know the longest signals, we know how much we need to pad 

2815 # to make all signals the same length 

2816 for i, signal in enumerate(order_signals): 

2817 extra_samples = longest_signal - signal.shape[-1] 

2818 end_abscissa = np.arange(1, extra_samples + 1) / sample_rate 

2819 end_arguments = ( 

2820 2 * np.pi * order_frequencies[i][-1] * end_abscissa + order_arguments[i][-1] 

2821 ) 

2822 end_frequencies = np.ones(extra_samples) * order_frequencies[i][-1] 

2823 end_amplitudes = np.zeros((extra_samples)) * order_amplitudes[i][..., [-1]] 

2824 end_phases = np.ones(extra_samples) * order_phases[i][..., [-1]] 

2825 end_signal = np.zeros(extra_samples) * signal[..., [-1]] 

2826 order_signals[i] = np.concatenate((order_signals[i], end_signal), axis=-1) 

2827 order_frequencies[i] = np.concatenate((order_frequencies[i], end_frequencies), axis=-1) 

2828 order_arguments[i] = np.concatenate((order_arguments[i], end_arguments), axis=-1) 

2829 order_amplitudes[i] = np.concatenate((order_amplitudes[i], end_amplitudes), axis=-1) 

2830 order_phases[i] = np.concatenate((order_phases[i], end_phases), axis=-1) 

2831 

2832 order_signals = np.array(order_signals) 

2833 order_frequencies = np.array(order_frequencies) 

2834 order_arguments = np.array(order_arguments) 

2835 order_amplitudes = np.array(order_amplitudes) 

2836 order_phases = np.array(order_phases) 

2837 order_start_samples = np.array(order_start_samples) 

2838 order_end_samples = np.array(order_end_samples) 

2839 signal = np.sum(order_signals, axis=0) 

2840 

2841 return ( 

2842 signal, 

2843 order_signals, 

2844 order_frequencies, 

2845 order_arguments, 

2846 order_amplitudes, 

2847 order_phases, 

2848 order_start_samples, 

2849 order_end_samples, 

2850 ) 

2851 

2852 

class FilterExplorer(QtWidgets.QDialog):
    """Dialog box for exploring the sine tracking filter settings.

    Signals are synthesized from the sine specifications for the selected
    channel. Random noise scaled by the noise selector is optionally added,
    and the result is passed through either the digital tracking filter
    (DTF, filter type 0) or the Vold-Kalman filter (VK, filter type 1). The
    reconstructed tones are then plotted against the exact ones.
    """

    @staticmethod
    def explore_filter_settings(
        channel_names,
        order_names,
        specifications,
        current_filter_type,
        current_tracking_filter_cutoff,
        current_tracking_filter_order,
        current_filter_order,
        current_bandwidth,
        current_block_size,
        current_overlap,
        sample_rate,
        ramp_time,
        acquire_size,
        parent=None,
    ):
        """
        Brings up the explore filter dialog box

        Parameters
        ----------
        channel_names : list of str
            Channel names to use in the dialog box
        order_names : list of str
            Tone names to use in the dialog box
        specifications : list of SineSpecification
            Sine specifications used to compute the signals
        current_filter_type : int
            Choose the starting filter type, 0-DTF, 1-VK
        current_tracking_filter_cutoff : float
            The cutoff for the tracking filter
        current_tracking_filter_order : int
            The filter order for the tracking filter.
        current_filter_order : int
            The filter order for the Vold Kalman filter
        current_bandwidth : float
            The bandwidth for the Vold Kalman filter.
        current_block_size : int
            The number of samples to use when computing the Vold Kalman filter.
        current_overlap : float
            The percentage overlap of the frame size (in percent, so 15, not
            0.15).
        sample_rate : float
            The sample rate of the signals to generate.
        ramp_time : float
            The ramp time added to the signal to ramp to full level.
        acquire_size : int
            The acquisition size in number of samples.
        parent : QWidget, optional
            A parent widget to the dialog box. The default is None.

        Returns
        -------
        result : bool
            True if the dialog box was accepted, false if not.
        filter_type : int
            0 if DTF, 1 if VK.
        filter_cutoff : float
            The cutoff value for the DTF.
        tracking_filter_order : int
            The filter order for the DTF.
        filter_order : int
            The filter order for the VK filter.
        filter_bandwidth : float
            The bandwidth for the VK filter.
        filter_blocksize : int
            The number of samples in the analysis block for the VK filter.
        filter_overlap : float
            The overlap percentage (in percent not fraction) of the VK filter.

        """
        dialog = FilterExplorer(
            parent,
            channel_names,
            order_names,
            specifications,
            current_filter_type,
            current_tracking_filter_cutoff,
            current_tracking_filter_order,
            current_filter_order,
            current_bandwidth,
            current_block_size,
            current_overlap,
            sample_rate,
            ramp_time,
            acquire_size,
        )
        result = dialog.exec_() == QtWidgets.QDialog.Accepted
        filter_type = dialog.filter_type_selector.currentIndex()
        # The VK order combo box is 0-based; filter orders start at 1.
        filter_order = dialog.filter_order_selector.currentIndex() + 1
        filter_bandwidth = dialog.filter_bandwidth_selector.value()
        filter_blocksize = dialog.filter_block_size_selector.value()
        filter_overlap = dialog.filter_block_overlap_selector.value()
        filter_cutoff = dialog.tracking_filter_cutoff_selector.value()
        tracking_filter_order = dialog.tracking_filter_order_selector.value()
        return (
            result,
            filter_type,
            filter_cutoff,
            tracking_filter_order,
            filter_order,
            filter_bandwidth,
            filter_blocksize,
            filter_overlap,
        )

    def __init__(
        self,
        parent,
        channel_names,
        order_names,
        specifications,
        current_filter_type,
        current_tracking_filter_cutoff,
        current_tracking_filter_order,
        current_filter_order,
        current_bandwidth,
        current_block_size,
        current_overlap,
        sample_rate,
        ramp_time,
        acquire_size,
    ):
        super().__init__(parent)
        uic.loadUi(filter_explorer_ui_path, self)

        for channel_name in channel_names:
            self.channel_selector.addItem(channel_name)

        self.order_selector.setSelectionMode(QtWidgets.QListWidget.SingleSelection)
        for order_name in order_names:
            self.order_selector.addItem(order_name)

        self.full_time_history_plotter = VaryingNumberOfLinePlot(
            self.full_time_history_plot.getPlotItem()
        )
        self.order_time_history_plotter = VaryingNumberOfLinePlot(
            self.order_time_history_plot.getPlotItem()
        )
        self.order_phase_plotter = VaryingNumberOfLinePlot(self.order_phase_plot.getPlotItem())
        self.order_amplitude_plotter = VaryingNumberOfLinePlot(
            self.order_amplitude_plot.getPlotItem()
        )

        self.filter_type_selector.setCurrentIndex(current_filter_type)
        self.tracking_filter_order_selector.setValue(current_tracking_filter_order)
        self.tracking_filter_cutoff_selector.setValue(current_tracking_filter_cutoff)
        self.filter_order_selector.setCurrentIndex(current_filter_order - 1)
        self.filter_bandwidth_selector.setValue(current_bandwidth)
        self.filter_block_overlap_selector.setValue(current_overlap)
        self.filter_block_size_selector.setValue(current_block_size)

        self.ramp_time = ramp_time
        self.sample_rate = sample_rate
        self.specifications = specifications
        self.acquire_size = acquire_size

        # Exact signals synthesized from the specifications
        self.signal = None
        self.order_signals = None
        self.order_frequencies = None
        self.order_arguments = None
        self.order_amplitudes = None
        self.order_phases = None
        # Filter output: lists with one entry per processed block, or None
        self._clear_filter_data()
        self.setWindowTitle("Sine Filter Explorer")

        self.change_filter_setting_visibility()

        self.create_signals()

        self.update_plots()

        # Connected last so that populating the widgets above does not
        # trigger callbacks before the signals exist.
        self.connect_callbacks()

    def connect_callbacks(self):
        """
        Connects callback functions to the filter widgets
        """
        self.accept_button.clicked.connect(self.accept)
        self.reject_button.clicked.connect(self.reject)
        self.order_selector.itemSelectionChanged.connect(self.update_plots)
        self.channel_selector.currentIndexChanged.connect(self.create_and_plot_signals)
        self.filter_type_selector.currentIndexChanged.connect(self.remove_filter_data_and_replot)
        self.filter_order_selector.currentIndexChanged.connect(self.remove_filter_data_and_replot)
        self.filter_bandwidth_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.filter_block_size_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.filter_block_overlap_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.tracking_filter_cutoff_selector.valueChanged.connect(
            self.remove_filter_data_and_replot
        )
        self.tracking_filter_order_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.noise_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.compute_button.clicked.connect(self.compute_filter)
        self.filter_type_selector.currentIndexChanged.connect(self.change_filter_setting_visibility)
        # update_plots reads this widget's checked state, so toggling it must
        # trigger a redraw.
        self.plot_separate_frames_selector.toggled.connect(self.update_plots)

    @property
    def ramp_samples(self):
        """Number of ramp samples computed from sample rate and ramp time"""
        return int(self.ramp_time * self.sample_rate)

    @property
    def channel_index(self):
        """Currently selected channel index"""
        return self.channel_selector.currentIndex()

    def change_filter_setting_visibility(self):
        """Shows only the settings that apply to the selected filter type"""
        isdtf = self.filter_type_selector.currentIndex() == 0
        for widget in [
            self.filter_order_label,
            self.filter_order_selector,
            self.filter_block_overlap_label,
            self.filter_block_overlap_selector,
            self.filter_bandwidth_label,
            self.filter_bandwidth_selector,
            self.filter_block_size_label,
            self.filter_block_size_selector,
        ]:
            widget.setVisible(not isdtf)
        for widget in [
            self.tracking_filter_cutoff_label,
            self.tracking_filter_cutoff_selector,
            self.tracking_filter_order_label,
            self.tracking_filter_order_selector,
        ]:
            widget.setVisible(isdtf)

    def _clear_filter_data(self):
        """Discards any previously computed filter reconstruction"""
        self.reconstructed_signal = None
        self.reconstructed_order_signals = None
        self.reconstructed_order_amplitudes = None
        self.reconstructed_order_phases = None

    def create_signals(self):
        """
        Creates signals for the selected channel from the specifications and
        discards any stale filter output.
        """
        (
            self.signal,
            self.order_signals,
            self.order_frequencies,
            self.order_arguments,
            self.order_amplitudes,
            self.order_phases,
            _,
            _,
        ) = SineSpecification.create_combined_signals(
            self.specifications, self.sample_rate, self.ramp_samples, self.channel_index
        )

        self._clear_filter_data()

    def remove_filter_data_and_replot(self):
        """Removes existing filter data and updates the plots"""
        self._clear_filter_data()
        self.update_plots()

    def compute_filter(self):
        """
        Runs the selected filter over the signal and stores the reconstruction.

        The combined signal, plus random noise scaled by the noise selector,
        is fed to the filter in blocks. The DTF uses blocks of acquire_size
        samples with one generator per tone. The VK filter uses the selected
        block size with a single generator for all tones.
        """
        is_dtf = self.filter_type_selector.currentIndex() == 0
        num_orders = self.order_signals.shape[0]
        if is_dtf:
            block_size = self.acquire_size
            generator = [
                digital_tracking_filter_generator(
                    dt=1 / self.sample_rate,
                    cutoff_frequency_ratio=self.tracking_filter_cutoff_selector.value() / 100,
                    filter_order=self.tracking_filter_order_selector.value(),
                )
                for _ in range(num_orders)
            ]
            for gen in generator:
                gen.send(None)  # Prime the generator to its first yield
        else:
            block_size = self.filter_block_size_selector.value()
            generator = vold_kalman_filter_generator(
                sample_rate=self.sample_rate,
                num_orders=num_orders,
                block_size=block_size,
                overlap=self.filter_block_overlap_selector.value(),
                bandwidth=self.filter_bandwidth_selector.value(),
                filter_order=self.filter_order_selector.currentIndex() + 1,
            )
            generator.send(None)  # Prime the generator to its first yield

        start_index = 0
        reconstructed_signals = []
        reconstructed_amplitudes = []
        reconstructed_phases = []

        last_data = False
        while not last_data:
            end_index = start_index + block_size
            block = self.signal[start_index:end_index]
            block = block + self.noise_selector.value() * np.random.randn(block.size)
            block_arguments = self.order_arguments[:, start_index:end_index]
            block_frequencies = self.order_frequencies[:, start_index:end_index]
            last_data = end_index >= self.signal.size
            if is_dtf:
                amps = []
                phss = []
                for arg, freq, gen in zip(block_arguments, block_frequencies, generator):
                    amp, phs = gen.send((block, freq, arg))
                    amps.append(amp)
                    phss.append(phs)
                amps = np.array(amps)
                phss = np.array(phss)
                reconstructed_amplitudes.append(amps)
                reconstructed_phases.append(phss)
                reconstructed_signals.append(amps * np.cos(block_arguments + phss))
            else:
                vk_signals, vk_amplitudes, vk_phases = generator.send(
                    (block, block_arguments, last_data)
                )
                # The VK generator may yield None while it buffers data.
                if vk_signals is not None:
                    reconstructed_signals.append(vk_signals)
                    reconstructed_amplitudes.append(vk_amplitudes)
                    reconstructed_phases.append(vk_phases)
            start_index += block_size

        self.reconstructed_order_signals = reconstructed_signals
        self.reconstructed_order_amplitudes = reconstructed_amplitudes
        self.reconstructed_order_phases = reconstructed_phases
        self.reconstructed_signal = [
            np.sum(value, axis=0) for value in self.reconstructed_order_signals
        ]

        self.update_plots()

    def update_plots(self):
        """
        Updates the plots with the exact and (if computed) reconstructed data.

        The first curve in each plot is the exact data computed from the
        specification. If filter results exist, they are added either as one
        curve per processed block or as one concatenated curve, depending on
        plot_separate_frames_selector. Phases are plotted in degrees.
        """
        abscissa = np.arange(self.signal.size) / self.sample_rate

        selected_rows = self.order_selector.selectionModel().selectedIndexes()
        selected_index = selected_rows[0].row() if selected_rows else 0

        abscissa_plot_vals = [abscissa]
        full_signal_plot_vals = [self.signal]
        signal_plot_vals = [self.order_signals[selected_index]]
        frequency_plot_vals = [self.order_frequencies[selected_index]]
        amplitude_plot_vals = [self.order_amplitudes[selected_index]]
        phase_plot_vals = [self.order_phases[selected_index] * 180 / np.pi]

        # An empty list means the filter produced no output; skip it rather
        # than letting np.concatenate fail on an empty sequence.
        if self.reconstructed_order_signals:
            if self.plot_separate_frames_selector.isChecked():
                start_index = 0
                for (
                    block_order_signal,
                    block_order_amplitude,
                    block_order_phase,
                    block_signal,
                ) in zip(
                    self.reconstructed_order_signals,
                    self.reconstructed_order_amplitudes,
                    self.reconstructed_order_phases,
                    self.reconstructed_signal,
                ):
                    end_index = start_index + block_signal.shape[-1]
                    abscissa_plot_vals.append(abscissa[start_index:end_index])
                    full_signal_plot_vals.append(block_signal)
                    signal_plot_vals.append(block_order_signal[selected_index])
                    frequency_plot_vals.append(
                        self.order_frequencies[selected_index, start_index:end_index]
                    )
                    amplitude_plot_vals.append(block_order_amplitude[selected_index])
                    phase_plot_vals.append(block_order_phase[selected_index] * 180 / np.pi)
                    start_index = end_index
            else:
                abscissa_plot_vals.append(abscissa)
                full_signal_plot_vals.append(np.concatenate(self.reconstructed_signal, axis=-1))
                signal_plot_vals.append(
                    np.concatenate(self.reconstructed_order_signals, axis=-1)[selected_index]
                )
                amplitude_plot_vals.append(
                    np.concatenate(self.reconstructed_order_amplitudes, axis=-1)[selected_index]
                )
                phase_plot_vals.append(
                    np.concatenate(self.reconstructed_order_phases, axis=-1)[selected_index]
                    * 180
                    / np.pi
                )
                frequency_plot_vals.append(self.order_frequencies[selected_index])

        self.full_time_history_plotter.set_data(abscissa_plot_vals, full_signal_plot_vals)
        self.order_time_history_plotter.set_data(abscissa_plot_vals, signal_plot_vals)
        self.order_amplitude_plotter.set_data(frequency_plot_vals, amplitude_plot_vals)
        self.order_phase_plotter.set_data(frequency_plot_vals, phase_plot_vals)

    def create_and_plot_signals(self):
        """Creates signals then plots the signals"""
        self.create_signals()
        self.update_plots()

3262 

3263 

class PlotSineWindow(QtWidgets.QDialog):
    """Dialog plotting one tone's amplitude and phase on one control channel.

    The upper plot shows the specified amplitude in blue, together with dashed
    warning (yellow) and abort (dark red) limit lines. The lower plot shows
    the wrapped specified phase in blue. The achieved response is overlaid in
    red on both plots and can be refreshed with :meth:`update_plot`.
    """

    def __init__(self, parent, ui, tone_index, channel_index):
        """
        Creates a window showing amplitude and phase information.

        Parameters
        ----------
        parent : QWidget
            Parent of the window.
        ui : SineUI
            The User Interface of the Sine Controller
        tone_index : int
            Index specifying the tone to plot
        channel_index : int
            Index specifying the channel to plot
        """
        # Zero-argument super() starts the MRO lookup at this class, so
        # QDialog.__init__ is not skipped.
        super().__init__(parent)
        # NOTE(review): ``&`` keeps only the flag bits shared with Qt.Tool,
        # which strips the other window hints rather than adding the Tool
        # window type; ``|`` may have been intended. Behavior left unchanged
        # -- confirm before changing.
        self.setWindowFlags(self.windowFlags() & Qt.Tool)
        # Stored before plotting because _achieved_data reads these.
        self.ui = ui
        self.tone_index = tone_index
        self.channel_index = channel_index

        spec_frequency = ui.specification_frequencies[tone_index]
        spec_amplitude = ui.specification_amplitudes[tone_index, channel_index]
        spec_phase = wrap(ui.specification_phases[tone_index, channel_index])
        spec = ui.environment_parameters.specifications[tone_index]
        breakpoint_table = spec.breakpoint_table
        # Each breakpoint frequency is repeated twice so it pairs with the two
        # values per breakpoint in the flattened limit arrays below
        # (presumably the left/right levels at that breakpoint).
        limit_frequency = np.repeat(breakpoint_table["frequency"], 2)
        warning = breakpoint_table["warning"]
        abort = breakpoint_table["abort"]
        warn_low = warning[:, 0, :, channel_index].flatten()
        warn_high = warning[:, 1, :, channel_index].flatten()
        abort_low = abort[:, 0, :, channel_index].flatten()
        abort_high = abort[:, 1, :, channel_index].flatten()
        tone_name = spec.name
        channel_name = ui.initialized_control_names[channel_index]

        # Two stacked plots: amplitude on top, phase below.
        layout = QtWidgets.QVBoxLayout()
        amp_plotwidget = pqtg.PlotWidget()
        layout.addWidget(amp_plotwidget)
        phs_plotwidget = pqtg.PlotWidget()
        layout.addWidget(phs_plotwidget)
        self.setLayout(layout)
        amp_plot_item = amp_plotwidget.getPlotItem()
        phs_plot_item = phs_plotwidget.getPlotItem()
        for plot_item in (amp_plot_item, phs_plot_item):
            plot_item.showGrid(True, True, 0.25)
            plot_item.enableAutoRange()
            plot_item.getViewBox().enableAutoRange(enable=True)

        # Specification curves.
        amp_plot_item.plot(spec_frequency, spec_amplitude, pen={"color": "b", "width": 1})
        phs_plot_item.plot(spec_frequency, spec_phase, pen={"color": "b", "width": 1})

        # Dashed limit lines, drawn in the same order as before so that
        # stacking is unchanged.
        warning_color = (255, 204, 0)
        abort_color = (153, 0, 0)
        for limit_values, color in (
            (warn_low, warning_color),
            (warn_high, warning_color),
            (abort_low, abort_color),
            (abort_high, abort_color),
        ):
            amp_plot_item.plot(
                limit_frequency,
                limit_values,
                pen={"color": color, "width": 1, "style": Qt.DashLine},
            )

        # Achieved-response curves; update_plot refreshes these later.
        achieved_frequency, achieved_amplitude, achieved_phase = self._achieved_data()
        self.amp_curve = amp_plot_item.plot(
            achieved_frequency, achieved_amplitude, pen={"color": "r", "width": 1}
        )
        self.phs_curve = phs_plot_item.plot(
            achieved_frequency, achieved_phase, pen={"color": "r", "width": 1}
        )
        self.setWindowTitle(f"{tone_name} {channel_name}")
        self.show()

    def _achieved_data(self):
        """Collect the achieved response for this window's tone and channel.

        Returns
        -------
        frequency : ndarray
            Achieved excitation frequencies for the tone, concatenated across
            all stored blocks.
        amplitude : ndarray
            Achieved response amplitudes for the tone and channel,
            concatenated across all stored blocks.
        phase : ndarray
            Achieved response phases for the tone and channel, concatenated
            across all stored blocks.

        When no achieved data is available (``None`` or an empty
        collection), a two-point placeholder whose amplitude and phase are
        all NaN is returned. The curve objects can then still be created and
        later updated through ``setData``.
        """
        ui = self.ui
        frequency_blocks = ui.achieved_excitation_frequencies
        # An empty collection would make np.concatenate raise ValueError.
        if frequency_blocks is None or len(frequency_blocks) == 0:
            return np.array([0, 1]), np.full(2, np.nan), np.full(2, np.nan)
        tone, channel = self.tone_index, self.channel_index
        frequency = np.concatenate([block[tone] for block in frequency_blocks])
        amplitude = np.concatenate(
            [block[tone, channel] for block in ui.achieved_response_amplitudes]
        )
        phase = np.concatenate(
            [block[tone, channel] for block in ui.achieved_response_phases]
        )
        return frequency, amplitude, phase

    def update_plot(self):
        """Updates the plots with new data.

        Re-reads the achieved response from the UI and pushes it into the
        existing red amplitude and phase curves.
        """
        achieved_frequency, achieved_amplitude, achieved_phase = self._achieved_data()
        self.amp_curve.setData(achieved_frequency, achieved_amplitude)
        self.phs_curve.setData(achieved_frequency, achieved_phase)