1# -*- coding: utf-8 -*-
2"""
3Created on Mon Mar 31 10:19:37 2025
4
5@author: dprohe
6"""
7
8import os
9
10# import time
11
12import numpy as np
13from scipy.io import loadmat
14from scipy.signal import windows
15from scipy.sparse import linalg
16from scipy import sparse
17from scipy.signal import lfilter, lfiltic, butter
18import pyqtgraph as pqtg
19from qtpy import QtWidgets, uic
20from qtpy.QtCore import Qt, QLocale # pylint: disable=no-name-in-module
21
22from .environments import (
23 sine_sweep_table_ui_path,
24 filter_explorer_ui_path,
25)
26from .ui_utilities import VaryingNumberOfLinePlot
27from .utilities import wrap
28
29
# Module-level debug switch. When True, ``pickle`` is imported at module load;
# presumably it is used for debug dumps further down in this module (the usage
# is not visible in this section -- verify before relying on it).
DEBUG = False

if DEBUG:
    import pickle
34
35
def load_specification(spec_path):
    """Loads a sine specification from a .mat or .npz file

    Assumes the phases are represented in degrees. The file is closed before
    returning (``np.load`` keeps ``.npz`` archives open until closed).

    Parameters
    ----------
    spec_path : str
        The path to the specification to load. Files ending in ``.mat``
        (case-insensitive) are read with ``scipy.io.loadmat``; anything else
        is read with ``numpy.load``.

    Returns
    -------
    frequency : ndarray
        The frequency breakpoints of the specification (flattened to 1D).
    amplitude : ndarray
        The amplitudes at the breakpoints for each channel.
    phase : ndarray or None
        The phase in degrees of the specification.
    sweep_type : ndarray or None
        The sweep types at each breakpoint (flattened). String entries have
        surrounding whitespace removed; MATLAB pads rows of a char matrix
        with spaces.
    sweep_rate : ndarray or None
        The sweep rate at each breakpoint (flattened).
    warning : ndarray or None
        Upper, lower, left, and right warning levels at each
        breakpoint
    abort : ndarray or None
        Upper, lower, left, and right abort levels at each breakpoint.
    start_time : float or None
        The start time for the sine sweep.
    name : str or None
        The name of the sine sweep.

    """
    _, extension = os.path.splitext(spec_path)
    if extension.lower() == ".mat":
        data = loadmat(spec_path)
    else:
        data = np.load(spec_path)
    try:

        def optional(key):
            """Return ``data[key]`` or None when the key is absent."""
            return data[key] if key in data else None

        frequency = data["frequency"].flatten()
        amplitude = data["amplitude"]
        phase = optional("phase")  # Degrees
        sweep_rate = optional("sweep_rate")
        if sweep_rate is not None:
            sweep_rate = sweep_rate.flatten()
        sweep_type = optional("sweep_type")
        if sweep_type is not None:
            sweep_type = sweep_type.flatten()
            if sweep_type.dtype.kind in "US":
                # MATLAB char matrices pad shorter rows with spaces ('log   ')
                sweep_type = np.char.strip(sweep_type)
        warning = optional("warning")
        abort = optional("abort")
        start_time = optional("start_time")
        if start_time is not None:
            # .mat scalars load as (1, 1) arrays and .npz scalars as 0-d arrays
            start_time = float(np.squeeze(start_time))
        name = optional("name")
        if name is not None:
            # .mat strings load as shape (1,) arrays, .npz strings as 0-d arrays
            name = np.squeeze(name)[()]
            if isinstance(name, bytes):
                name = name.decode()
            name = str(name)
    finally:
        # NpzFile holds an open file handle; dicts and plain arrays have no close()
        close = getattr(data, "close", None)
        if close is not None:
            close()
    return (
        frequency,
        amplitude,
        phase,
        sweep_type,
        sweep_rate,
        warning,
        abort,
        start_time,
        name,
    )
115
116
def sine_sweep(
    dt,
    frequencies,
    sweep_rates,
    sweep_types,
    amplitudes=1,
    phases=0,
    return_argument=False,
    return_frequency=False,
    return_amplitude=False,
    return_phase=False,
    return_abscissa=False,
    only_breakpoints=False,
):
    """
    Generates a sweeping sine wave with linear or logarithmic sweep rate

    Parameters
    ----------
    dt : float
        The time step of the output signal
    frequencies : iterable
        A list of at least two frequency breakpoints for the sweep. Can be
        ascending or descending or both. Frequencies are specified in Hz,
        not rad/s.
    sweep_rates : iterable or float
        The sweep rates between the breakpoints. Either a single value used
        for every segment, or an array with one fewer element than the
        `frequencies` array, where the ith element specifies the sweep rate
        between `frequencies[i]` and `frequencies[i+1]`. For a linear sweep,
        the rate is in Hz/s. For a logarithmic sweep, the rate is in octave/s.
        A descending segment needs a negative rate.
    sweep_types : iterable or str
        The type of sweep to perform between each frequency breakpoint. Can be
        'lin'/'linear' or 'log'/'logarithmic' (case-insensitive). If a string
        is specified, it will be used for all segments. Otherwise it should be
        an array containing strings with one fewer element than that of the
        `frequencies` array.
    amplitudes : iterable or float, optional
        Amplitude of the cosine wave at each of the frequency breakpoints. Can
        be specified as a single value (Python or NumPy scalar), or as an
        array with a value specified for each breakpoint. The default is 1.
    phases : iterable or float, optional
        Phases in radians of the cosine wave at each of the frequency
        breakpoints. Can be specified as a single value (Python or NumPy
        scalar), or as an array with a value specified for each breakpoint.
        Be aware that modifying the phase between breakpoints will effectively
        change the frequency of the signal, because the phase will change over
        time. The default is 0.
    return_argument : bool
        If True, return cosine argument over time
    return_frequency : bool
        If True, return the instantaneous frequency over time
    return_amplitude : bool
        If True, return the instantaneous amplitude over time
    return_phase : bool
        If True, return the instantaneous phase over time
    return_abscissa : bool
        If True, return the instantaneous abscissa over time
    only_breakpoints : bool
        If True, only returns data at breakpoints. Default is False

    Raises
    ------
    ValueError
        If fewer than two breakpoints are given, if a sweep type is not
        recognized, if a sweep rate is zero, if a logarithmic segment has a
        non-positive frequency, or if the sweep rate and start and end
        frequency would result in a negative sweep time (for example if the
        start frequency is above the end frequency and a positive sweep rate
        is specified).

    Returns
    -------
    ordinate : np.ndarray
        A numpy array consisting of the generated sine sweep signal. The length
        of the signal will be determined by the frequency breakpoints and sweep
        rates.
    arg_over_time : np.ndarray
        A numpy array consisting of the argument to the cosine wave over time.
    freq_over_time : np.ndarray
        A numpy array consisting of the frequency of the cosine wave over time.
    amp_over_time : np.ndarray
        A numpy array consisting of the amplitude of the cosine wave over time.
    phs_over_time : np.ndarray
        A numpy array consisting of the added phase in radians of the cosine
        wave over time.
    abscissa : np.ndarray
        A numpy array consisting of the time value at each time step returned

    Only ``ordinate`` is returned (not in a tuple) when no ``return_*`` flag
    is set; otherwise a tuple in the order listed above, containing only the
    requested arrays.
    """
    if len(frequencies) < 2:
        raise ValueError("At least two frequency breakpoints are required to define a sweep")

    # Decide once whether each quantity is a scalar shared by all breakpoints
    # or an indexable sequence. np.ndim also treats NumPy scalars and 0-d
    # arrays as scalars (indexing those raises IndexError, not TypeError).
    rates_are_scalar = np.ndim(sweep_rates) == 0
    amplitudes_are_scalar = np.ndim(amplitudes) == 0
    phases_are_scalar = np.ndim(phases) == 0

    last_phase = 0  # Cosine argument accumulated over the previous segments
    last_abscissa = 0  # Time accumulated over the previous segments
    abscissa = []
    ordinate = []
    arg_over_time = []
    freq_over_time = []
    amp_over_time = []
    phs_over_time = []

    num_segments = len(frequencies) - 1
    for i in range(num_segments):
        # Extract the terms for this segment
        start_frequency = frequencies[i]
        end_frequency = frequencies[i + 1]
        omega_start = start_frequency * 2 * np.pi
        sweep_rate = sweep_rates if rates_are_scalar else sweep_rates[i]
        sweep_type = sweep_types if isinstance(sweep_types, str) else sweep_types[i]
        if amplitudes_are_scalar:
            start_amplitude = end_amplitude = amplitudes
        else:
            start_amplitude, end_amplitude = amplitudes[i], amplitudes[i + 1]
        if phases_are_scalar:
            start_phase = end_phase = phases  # Radians
        else:
            start_phase, end_phase = phases[i], phases[i + 1]  # Radians

        sweep_type = sweep_type.lower()
        if sweep_type in ("lin", "linear"):
            linear = True
        elif sweep_type in ("log", "logarithmic"):
            linear = False
        else:
            raise ValueError("Sweep type should be one of lin, linear, log, or logarithmic")
        if sweep_rate == 0:
            raise ValueError(f"Sweep rate for segment index {i} is zero.")

        # Compute the length of this portion of the signal
        if linear:
            sweep_time = (end_frequency - start_frequency) / sweep_rate
        else:
            if start_frequency <= 0 or end_frequency <= 0:
                raise ValueError(
                    f"Logarithmic sweep for segment index {i} requires positive frequencies."
                )
            sweep_time = np.log(end_frequency / start_frequency) / (sweep_rate * np.log(2))
        if sweep_time < 0:
            raise ValueError(f"Sweep time for segment index {i} is negative. Check sweep rate.")
        sweep_samples = int(np.floor(sweep_time / dt))

        # Construct the abscissa (time relative to the start of this segment)
        if only_breakpoints:
            this_abscissa = np.array([0, sweep_samples * dt])
        else:
            this_abscissa = np.arange(sweep_samples + 1) * dt

        # Cosine argument = 2*pi * integral of the instantaneous frequency
        if linear:
            # f(t) = f0 + r*t
            this_argument = (1 / 2) * (
                sweep_rate * 2 * np.pi
            ) * this_abscissa**2 + omega_start * this_abscissa
            this_frequency = (sweep_rate) * this_abscissa + omega_start / (2 * np.pi)
        else:
            # f(t) = f0 * 2**(r*t), r in octave/s
            this_argument = 2 ** (sweep_rate * this_abscissa) * omega_start / (
                sweep_rate * np.log(2)
            ) - omega_start / (sweep_rate * np.log(2))
            this_frequency = 2 ** (sweep_rate * this_abscissa) * omega_start / (2 * np.pi)

        # np.interp needs increasing sample points, so order the breakpoint
        # pair by frequency before interpolating phase and amplitude
        if end_frequency > start_frequency:
            freq_interp = [start_frequency, end_frequency]
            phase_interp = [start_phase, end_phase]
            amp_interp = [start_amplitude, end_amplitude]
        else:
            freq_interp = [end_frequency, start_frequency]
            phase_interp = [end_phase, start_phase]
            amp_interp = [end_amplitude, start_amplitude]
        this_phases = np.interp(this_frequency, freq_interp, phase_interp)
        this_amplitudes = np.interp(this_frequency, freq_interp, amp_interp)
        this_ordinate = this_amplitudes * np.cos(this_argument + this_phases + last_phase)

        # Every segment but the last drops its final sample, because the first
        # sample of the next segment represents the same instant
        last_index = None if i == num_segments - 1 else -1
        arg_over_time.append(this_argument[:last_index] + last_phase)
        last_phase += this_argument[-1]
        abscissa.append(this_abscissa[:last_index] + last_abscissa)
        last_abscissa += this_abscissa[-1]
        ordinate.append(this_ordinate[:last_index])
        freq_over_time.append(this_frequency[:last_index])
        amp_over_time.append(this_amplitudes[:last_index])
        phs_over_time.append(this_phases[:last_index])

    return_vals = [np.concatenate(ordinate)]
    if return_argument:
        return_vals.append(np.concatenate(arg_over_time))
    if return_frequency:
        return_vals.append(np.concatenate(freq_over_time))
    if return_amplitude:
        return_vals.append(np.concatenate(amp_over_time))
    if return_phase:
        return_vals.append(np.concatenate(phs_over_time))
    if return_abscissa:
        return_vals.append(np.concatenate(abscissa))
    if len(return_vals) == 1:
        return return_vals[0]
    return tuple(return_vals)
307
308
class NoWheelSpinBox(QtWidgets.QDoubleSpinBox):
    """QDoubleSpinBox variant that does not react to the mouse wheel.

    Scrolling the wheel over the widget therefore cannot change its value.
    """

    def wheelEvent(self, event):  # pylint: disable=invalid-name
        """Reject every wheel event instead of stepping the value."""
        # The base-class handler (which would step the value) is deliberately
        # not called; the event is only marked as ignored.
        event.ignore()
315
316
class AdaptiveNoWheelSpinBox(NoWheelSpinBox):
    """A wheel-ignoring spinbox whose displayed text adapts to the value.

    Values are rendered with the ``'g'`` format using ``decimals()`` as the
    precision (maximum significant digits), so values are shown compactly
    rather than with a fixed number of decimal places. ``decimals()`` starts
    at 10.

    Text is always produced with a fixed en_US locale. The widget's own locale
    is set to the same locale so that the text Qt parses back (when the user
    edits the field or it loses focus) uses the same decimal point and group
    separator as the text that was displayed, independent of the system
    locale.

    NOTE(review): for very large/small values ``'g'`` may produce exponent
    notation, which the default QDoubleSpinBox validator may not accept when
    typed -- confirm if such values are expected.
    """

    localization = QLocale(QLocale.English, QLocale.UnitedStates)

    def __init__(self, parent=None):
        super().__init__(parent)
        # Parse with the same locale used by textFromValue
        self.setLocale(AdaptiveNoWheelSpinBox.localization)
        self.setDecimals(10)

    def textFromValue(self, value):  # pylint: disable=invalid-name
        """Gets the text to show in the spinbox based on the value stored in the spinbox"""
        return AdaptiveNoWheelSpinBox.localization.toString(value, "g", self.decimals())
330
331
class NoWheelComboBox(QtWidgets.QComboBox):
    """QComboBox variant whose selection cannot be changed with the mouse wheel."""

    def wheelEvent(self, event):  # pylint: disable=invalid-name
        """Reject wheel input so the current index stays as it is."""
        # The base-class handler (which would step through the items) is
        # deliberately not called; the event is only marked as ignored.
        event.ignore()
338
339
class SineSweepTable:
    """A tab holding the breakpoint tables that define one sine tone.

    The tab widget (loaded from ``sine_sweep_table_ui_path``) contains three
    tables whose rows are kept aligned, one row per frequency breakpoint:

    * ``breakpoint_table``: ``Frequency``, ``Sweep Type``, ``Sweep Rate``,
      then ``<channel> Amplitude`` / ``<channel> Phase`` (degrees) for each
      control channel. The last row has no sweep type or rate because no
      segment starts at the final breakpoint.
    * ``warning_table`` and ``abort_table``: a read-only ``Frequency`` column,
      then ``Lower Left``, ``Lower Right``, ``Upper Left``, ``Upper Right``
      limits for each control channel. The first breakpoint has no left-side
      limits and the last breakpoint has no right-side limits; those cells
      hold non-editable items instead of spin boxes. A limit of 0 is shown as
      "Disabled" and is reported as NaN by ``get_specification``.
    """

    def __init__(
        self,
        parent_tabwidget: QtWidgets.QTabWidget,
        update_specification_function,
        remove_function,
        control_names,
        data_acquisition_parameters,
    ):
        """Initializes a sine sweep table to represent the breakpoints of a sine tone

        Parameters
        ----------
        parent_tabwidget : QtWidgets.QTabWidget
            The parent tabwidget in the sine ui class, which is needed to
            propagate changes in this widget back up to the main UI class
        update_specification_function : function
            The function to call to update the specification when the values
            in this table have changed
        remove_function : function
            The function to call when we remove a table from the tab widget.
            It is called with this table's tab index.
        control_names : list of str
            A list of strings to be used as the control channel names in the
            table
        data_acquisition_parameters : DataAcquisitionParameters
            The data acquisition parameters, including sample rate
        """
        self.parent_tabwidget = parent_tabwidget
        self.update_specification_function = update_specification_function
        self.remove_function = remove_function
        self.control_names = control_names
        self.data_acquisition_parameters = data_acquisition_parameters
        self.widget = QtWidgets.QWidget()
        uic.loadUi(sine_sweep_table_ui_path, self.widget)
        # The new tab is inserted before the current last tab (presumably an
        # "add tone" tab -- confirm in the parent UI).
        self.index = self.parent_tabwidget.count() - 1
        self.parent_tabwidget.insertTab(self.index, self.widget, f"Sine {self.index+1}")
        self.widget.name_editor.setText(f"Sine {self.index+1}")
        self.parent_tabwidget.setCurrentIndex(self.index)
        self.connect_callbacks()
        self.clear_and_update_specification_table()

    # ------------------------------------------------------------------
    # Widget construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _limit_column(channel, k):
        """Warning/abort table column of limit ``k`` for ``channel``.

        ``k`` is 0..3 for Lower Left, Lower Right, Upper Left, Upper Right.
        """
        return 1 + k + 4 * channel

    @staticmethod
    def _rate_suffix(sweep_type_index):
        """Unit suffix for the sweep rate given the sweep type combo index."""
        return " Hz/s" if sweep_type_index == 0 else " oct/min"

    @staticmethod
    def _limit_value(limits, index):
        """Value to show for a limit: 0 ("Disabled") when absent or NaN."""
        if limits is None:
            return 0
        value = limits[index]
        return 0 if np.isnan(value) else value

    @staticmethod
    def _disable_cell(table, row, column):
        """Replace any cell widget at ``(row, column)`` with a non-editable item."""
        widget = table.cellWidget(row, column)
        if widget is not None:
            table.removeCellWidget(row, column)
            widget.deleteLater()
        item = table.item(row, column)
        if item is None:
            item = QtWidgets.QTableWidgetItem()
            table.setItem(row, column, item)
        item.setFlags(item.flags() & ~Qt.ItemIsEditable)

    def _new_spinbox(
        self,
        minimum,
        maximum,
        value,
        *,
        decimals=None,
        suffix=None,
        special_value_text=None,
        display_only=False,
    ):
        """Create and configure an AdaptiveNoWheelSpinBox.

        Editable spin boxes are connected to ``update_specification_function``
        only after they are fully configured, so building them does not
        trigger updates. ``display_only`` spin boxes are read-only, have no
        step buttons and are not connected.
        """
        spinbox = AdaptiveNoWheelSpinBox()
        spinbox.setRange(minimum, maximum)
        spinbox.setSingleStep(1)
        spinbox.setValue(value)
        if decimals is not None:
            spinbox.setDecimals(decimals)
        if suffix is not None:
            spinbox.setSuffix(suffix)
        if special_value_text is not None:
            spinbox.setSpecialValueText(special_value_text)
        spinbox.setKeyboardTracking(False)
        if display_only:
            spinbox.setReadOnly(True)
            spinbox.setButtonSymbols(AdaptiveNoWheelSpinBox.NoButtons)
        else:
            spinbox.valueChanged.connect(self.update_specification_function)
        return spinbox

    def _set_frequency_widgets(self, row, value, decimals=None):
        """Editable breakpoint frequency plus read-only displays in the limit tables."""
        nyquist = self.data_acquisition_parameters.sample_rate / 2
        self.widget.breakpoint_table.setCellWidget(
            row, 0, self._new_spinbox(0, nyquist, value, decimals=decimals)
        )
        for table in (self.widget.warning_table, self.widget.abort_table):
            table.setCellWidget(row, 0, self._new_spinbox(0, nyquist, value, display_only=True))

    def _set_sweep_widgets(self, row, sweep_type_index=0, rate=1):
        """Sweep type combo box and sweep rate spin box for a segment-starting row."""
        combobox = NoWheelComboBox()
        combobox.addItems(["Linear", "Logarithmic"])
        combobox.setCurrentIndex(sweep_type_index)
        rate_spinbox = self._new_spinbox(
            -1000000, 1000000, rate, suffix=self._rate_suffix(sweep_type_index)
        )
        # Keep the displayed rate units in sync with the selected sweep type
        combobox.currentIndexChanged.connect(
            lambda index: rate_spinbox.setSuffix(self._rate_suffix(index))
        )
        combobox.currentIndexChanged.connect(self.update_specification_function)
        self.widget.breakpoint_table.setCellWidget(row, 1, combobox)
        self.widget.breakpoint_table.setCellWidget(row, 2, rate_spinbox)

    def _set_channel_widgets(self, row, channel, amplitude=1, phase=0):
        """Amplitude and phase (degrees) spin boxes for one channel of a row."""
        self.widget.breakpoint_table.setCellWidget(
            row, 3 + 2 * channel, self._new_spinbox(0, 1000000, amplitude)
        )
        self.widget.breakpoint_table.setCellWidget(
            row, 4 + 2 * channel, self._new_spinbox(-1000000, 1000000, phase)
        )

    def _set_limit_spinboxes(self, row, column, warning_value=0, abort_value=0):
        """Editable limit spin boxes at ``(row, column)`` of the warning and abort tables."""
        self.widget.warning_table.setCellWidget(
            row,
            column,
            self._new_spinbox(0, 1000000, warning_value, special_value_text="Disabled"),
        )
        self.widget.abort_table.setCellWidget(
            row,
            column,
            self._new_spinbox(0, 1000000, abort_value, special_value_text="Disabled"),
        )

    def _disable_limit_cells(self, row, column):
        """Make ``(row, column)`` non-editable in both limit tables."""
        for table in (self.widget.warning_table, self.widget.abort_table):
            self._disable_cell(table, row, column)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def connect_callbacks(self):
        """Connects the widgets in the UI to methods of the object"""
        self.widget.add_breakpoint_button.clicked.connect(self.add_breakpoint)
        self.widget.remove_breakpoint_button.clicked.connect(self.remove_breakpoint)
        self.widget.load_breakpoints_button.clicked.connect(self.load_specification)
        self.widget.name_editor.editingFinished.connect(self.update_name)
        self.widget.start_time_selector.valueChanged.connect(self.update_specification_function)
        self.widget.remove_tone_button.clicked.connect(self.remove_tone)

    def add_breakpoint(self):
        """Inserts a breakpoint above the selected row (or at the top).

        The new row gets frequency 0, a linear sweep at rate 1, amplitude 1,
        phase 0 and disabled limits. Because a row is always inserted *before*
        an existing row, the new row never becomes the last breakpoint. When
        it becomes the first breakpoint, its left-side limits are disabled and
        the previous first row (now row 1) receives left-side spin boxes.
        """
        selected_indices = self.widget.breakpoint_table.selectedIndexes()
        # If no row is selected, add the row at the start
        row = selected_indices[0].row() if selected_indices else 0
        for table in (
            self.widget.breakpoint_table,
            self.widget.warning_table,
            self.widget.abort_table,
        ):
            table.insertRow(row)
        self._set_frequency_widgets(row, 0)
        self._set_sweep_widgets(row)
        for channel in range(len(self.control_names)):
            self._set_channel_widgets(row, channel)
            for k in range(4):
                column = self._limit_column(channel, k)
                if row == 0 and k in (0, 2):
                    # New first row: no left-side limits. The former first row
                    # (now row 1) gains the left-side limits it lacked.
                    self._disable_limit_cells(row, column)
                    self._set_limit_spinboxes(row + 1, column)
                else:
                    self._set_limit_spinboxes(row, column)
        self.update_specification_function()

    def remove_breakpoint(self):
        """Removes the selected breakpoint row (or the last row if none is selected).

        The row is removed from all three tables. If the removed row was the
        last one, the new last row loses its sweep type/rate and right-side
        limits; if it was the first one, the new first row loses its left-side
        limits.
        """
        breakpoint_table = self.widget.breakpoint_table
        row_count = breakpoint_table.rowCount()
        if row_count == 0:
            return
        selected_indices = breakpoint_table.selectedIndexes()
        row = selected_indices[0].row() if selected_indices else row_count - 1
        was_last_row = row == row_count - 1
        was_first_row = row == 0
        limit_tables = (self.widget.warning_table, self.widget.abort_table)
        for table in (breakpoint_table, *limit_tables):
            table.removeRow(row)
        remaining_rows = breakpoint_table.rowCount()
        num_limit_columns = self.widget.warning_table.columnCount()
        if was_last_row and remaining_rows > 0:
            new_last_row = remaining_rows - 1
            for column in (1, 2):
                self._disable_cell(breakpoint_table, new_last_row, column)
            # Right-side limits (Lower Right, Upper Right) sit in even columns
            for column in range(2, num_limit_columns, 2):
                for table in limit_tables:
                    self._disable_cell(table, new_last_row, column)
        if was_first_row and remaining_rows > 0:
            # Left-side limits (Lower Left, Upper Left) sit in odd columns
            for column in range(1, num_limit_columns, 2):
                for table in limit_tables:
                    self._disable_cell(table, 0, column)
        self.update_specification_function()

    def load_specification(self, clicked, filename=None):  # pylint: disable=unused-argument
        """Loads a breakpoint table using a dialog or the specified filename

        Parameters
        ----------
        clicked :
            The clicked event that triggered the callback.
        filename :
            File name defining the specification for bypassing the callback when
            loading from a file (Default value = None).

        """
        if filename is None:
            filename, _ = QtWidgets.QFileDialog.getOpenFileName(
                self.widget,
                "Select Specification File",
                filter="Numpy or Mat (*.npy *.npz *.mat)",
            )
            if filename == "":
                return
        (
            frequencies,
            amplitudes,
            phases,  # Degrees
            sweep_types,
            sweep_rates,
            warnings,
            aborts,
            start_time,
            name,
        ) = load_specification(filename)
        self.clear_and_update_specification_table(
            frequencies,
            amplitudes,
            phases,  # Degrees
            sweep_types,
            sweep_rates,
            warnings,
            aborts,
            start_time,
            name,
        )
        self.update_specification_function()

    def clear_and_update_specification_table(
        self,
        frequencies=None,
        amplitudes=None,
        phases=None,
        sweep_types=None,
        sweep_rates=None,
        warning_amplitudes=None,
        abort_amplitudes=None,
        start_time=None,
        sine_name=None,
        control_names=None,
    ):
        """Clears the table and updates it with the optional parameters supplied.

        Parameters
        ----------
        frequencies : ndarray, optional
            A 1D array containing the frequencies to use as the breakpoints. By
            default, a table consisting of two breakpoints at 0 Hz will be
            specified.
        amplitudes : ndarray, optional
            A 2D array consisting of amplitudes for (channel, frequency) pairs.
            If not specified, an amplitude of one will be used.
        phases : ndarray, optional
            A 2D array consisting of phases for (channel, frequency) pairs.
            If not specified, a phase of zero will be used. Phases
            are specified in degrees.
        sweep_types : ndarray or list of strings, optional
            A 1D array of strings to use as the sweep type. Entries equal to
            lin or linear (case-insensitive) select a linear sweep; any other
            entry selects a logarithmic sweep. Linear is used if not
            specified.
        sweep_rates : ndarray, optional
            A 1D array of values to use as the sweep rate. They should
            be in Hz/s for linear sweeps or octave per minute for
            logarithmic sweeps. A rate of 1 is used if not specified.
        warning_amplitudes : ndarray, optional
            A 4D ndarray with shape 2, 2, num_channels, num_frequencies,
            indexed [lower/upper, left/right, channel, breakpoint]
            (index 0 = lower/left, 1 = upper/right), matching the
            Lower Left, Lower Right, Upper Left, Upper Right columns. If a
            value is not desired, it should be set to nan.
        abort_amplitudes : ndarray, optional
            Same layout as `warning_amplitudes`, for the abort limits.
        start_time : float, optional
            The starting time for the specified sine tone. If not specified,
            the current start time is left unchanged.
        sine_name : str, optional
            The name of the sine tone used in the software.
        control_names : array of str, optional
            Channel names to use in the table. If not specified, the
            existing channel names will be used.
        """
        if control_names is not None:
            self.control_names = control_names
        num_channels = len(self.control_names)
        num_rows = 2 if frequencies is None else frequencies.size
        breakpoint_table = self.widget.breakpoint_table
        limit_tables = (self.widget.warning_table, self.widget.abort_table)

        breakpoint_table.clear()
        # Drop all existing rows first so cell widgets and items from the
        # previous layout cannot survive (e.g. a limit spin box left in a cell
        # that is now a first/last-row edge cell and must be disabled).
        for table in (breakpoint_table, *limit_tables):
            table.setRowCount(0)
        breakpoint_table.setRowCount(num_rows)
        breakpoint_table.setColumnCount(3 + 2 * num_channels)
        for table in limit_tables:
            table.setRowCount(num_rows)
            table.setColumnCount(1 + 4 * num_channels)

        breakpoint_header_labels = ["Frequency", "Sweep Type", "Sweep Rate"]
        limit_header_labels = ["Frequency"]
        for name in self.control_names:
            breakpoint_header_labels += [name + " Amplitude", name + " Phase"]
            limit_header_labels += [
                name + " Lower Left",
                name + " Lower Right",
                name + " Upper Left",
                name + " Upper Right",
            ]
        breakpoint_table.setHorizontalHeaderLabels(breakpoint_header_labels)
        for table in limit_tables:
            table.setHorizontalHeaderLabels(limit_header_labels)

        for row in range(num_rows):
            is_first_row = row == 0
            is_last_row = row == num_rows - 1
            self._set_frequency_widgets(
                row, 0 if frequencies is None else frequencies[row], decimals=4
            )
            if is_last_row:
                # No segment starts at the final breakpoint
                for column in (1, 2):
                    self._disable_cell(breakpoint_table, row, column)
            else:
                if sweep_types is None or str(sweep_types[row]).lower() in ["lin", "linear"]:
                    sweep_type_index = 0
                else:
                    sweep_type_index = 1
                self._set_sweep_widgets(
                    row,
                    sweep_type_index,
                    1 if sweep_rates is None else sweep_rates[row],
                )
            for channel in range(num_channels):
                self._set_channel_widgets(
                    row,
                    channel,
                    1 if amplitudes is None else amplitudes[channel, row],
                    0 if phases is None else phases[channel, row],
                )
                for k in range(4):
                    column = self._limit_column(channel, k)
                    # First breakpoint has no left side, last has no right side
                    if (is_first_row and k in (0, 2)) or (is_last_row and k in (1, 3)):
                        self._disable_limit_cells(row, column)
                        continue
                    index = np.unravel_index(k, (2, 2)) + (channel, row)
                    self._set_limit_spinboxes(
                        row,
                        column,
                        self._limit_value(warning_amplitudes, index),
                        self._limit_value(abort_amplitudes, index),
                    )
        if sine_name is not None:
            self.widget.name_editor.setText(sine_name)
            self.update_name()
        if start_time is not None:
            self.widget.start_time_selector.setValue(start_time)

    def update_name(self):
        """Called when the name of the sine tone is changed"""
        self.parent_tabwidget.setTabText(self.index, self.widget.name_editor.text())

    def remove_tone(self):
        """Called when the remove button is pressed"""
        self.remove_function(self.index)

    def get_specification(self):
        """Computes a sine sweep specification from the table

        Phases are converted from degrees (as shown in the table) to radians.
        Limits of 0 ("Disabled") and limits that do not exist for the first or
        last breakpoint are stored as NaN. The sweep type is stored as the
        combo box index (0 = linear, 1 = logarithmic).
        """
        breakpoint_table = self.widget.breakpoint_table
        num_control = (breakpoint_table.columnCount() - 3) // 2
        spec = SineSpecification(
            self.widget.name_editor.text(),
            self.widget.start_time_selector.value(),
            num_control,
            breakpoint_table.rowCount(),
        )
        last_row = len(spec.breakpoint_table) - 1
        for row, spec_row in enumerate(spec.breakpoint_table):
            spec_row["frequency"] = breakpoint_table.cellWidget(row, 0).value()
            if row < last_row:
                spec_row["sweep_type"] = breakpoint_table.cellWidget(row, 1).currentIndex()
                spec_row["sweep_rate"] = breakpoint_table.cellWidget(row, 2).value()
            for i in range(num_control):
                spec_row["amplitude"][i] = breakpoint_table.cellWidget(row, 3 + 2 * i).value()
                spec_row["phase"][i] = (
                    breakpoint_table.cellWidget(row, 4 + 2 * i).value() * np.pi / 180
                )  # Convert degrees to radians for all calculations
                for k in range(4):
                    ind = np.unravel_index(k, (2, 2))
                    if (row == 0 and k in (0, 2)) or (row == last_row and k in (1, 3)):
                        spec_row["warning"][ind + (i,)] = np.nan
                        spec_row["abort"][ind + (i,)] = np.nan
                    else:
                        column = self._limit_column(i, k)
                        val = self.widget.warning_table.cellWidget(row, column).value()
                        spec_row["warning"][ind + (i,)] = np.nan if val == 0 else val
                        val = self.widget.abort_table.cellWidget(row, column).value()
                        spec_row["abort"][ind + (i,)] = np.nan if val == 0 else val
        return spec
873
874
def digital_tracking_filter_generator(
    dt,
    cutoff_frequency_ratio=0.15,
    filter_order=2,
    phase_estimate=None,
    amplitude_estimate=None,
):
    """
    Computes amplitudes and phases using a digital tracking filter

    Each block of the signal is multiplied by ``cos(argsi)`` and
    ``-sin(argsi)``. Both products are low-pass filtered, and the amplitude
    and phase are recovered from the filtered pair, so that
    ``xi ~= amplitude * cos(argsi + phase)``.

    This is a generator coroutine. Prime it with ``next()``, which yields
    ``(None, None)``, then ``send((xi, fi, argsi))`` once per block. The
    low-pass cutoff is recomputed from each block's lowest frequency, so the
    filter initial conditions are rebuilt every block with ``lfiltic`` from
    the previous block's filter inputs and outputs.

    Parameters
    ----------
    dt : float
        The time step of the signal.
    cutoff_frequency_ratio : float
        The cutoff frequency of the low-pass filter compared to the lowest
        frequency sine tone in each block. Default is 0.15.
    filter_order : int
        The order of the low-pass Butterworth filter. Default is 2.
    phase_estimate : float, optional
        An estimate of the initial phase, in radians, used to seed the
        low-pass filter. Default is 0.
    amplitude_estimate : float, optional
        An estimate of the initial amplitude used to seed the low-pass
        filter. Default is 0.

    Sends
    ------
    xi : iterable
        The next block of the signal to be filtered.
    fi : iterable
        The frequencies at the time steps in xi.
    argsi : iterable
        The argument to a cosine function at the time steps in xi.

    Yields
    -------
    amplitude : np.ndarray
        The amplitude at each time step of the block just sent.
    phase : np.ndarray
        The phase in radians at each time step of the block just sent.
    """
    if phase_estimate is None:
        phase_estimate = 0
    if amplitude_estimate is None:
        amplitude_estimate = 0

    # Previous block's demodulated inputs and filtered outputs; used to build
    # the filter initial conditions for the next block.
    xi_0_filt = None
    xi_90_filt = None
    xi_0 = None
    xi_90 = None
    amplitude = None
    phase = None
    while True:
        xi, fi, argsi = yield amplitude, phase
        xi = np.array(xi)
        fi = np.array(fi)
        argsi = np.array(argsi)
        # The cutoff tracks the lowest frequency present in this block.
        b, a = butter(filter_order, cutoff_frequency_ratio * np.min(fi), fs=1 / dt)
        if xi_0_filt is None:
            # First block: synthesize a short history from the estimates so the
            # filter starts near the expected value instead of at zero.
            past_ts = np.arange(-filter_order * 2 - 1, 0) * dt
            past_xs = amplitude_estimate * np.cos(2 * np.pi * fi[0] * past_ts + phase_estimate)
            xi_0 = np.cos(2 * np.pi * fi[0] * past_ts) * past_xs
            xi_90 = -np.sin(2 * np.pi * fi[0] * past_ts) * past_xs
            # Low-frequency (DC) part of the demodulated estimate.
            xi_0_filt = 0.5 * amplitude_estimate * np.cos(phase_estimate) * np.ones(xi_0.shape)
            xi_90_filt = 0.5 * amplitude_estimate * np.sin(phase_estimate) * np.ones(xi_90.shape)
        # lfiltic expects the most recent sample first, hence the reversal. The
        # state is rebuilt each block because (b, a) may have changed.
        z0i = lfiltic(b, a, xi_0_filt[::-1], xi_0[::-1])
        z90i = lfiltic(b, a, xi_90_filt[::-1], xi_90[::-1])
        # Demodulate against the in-phase and quadrature references.
        cola0 = np.cos(argsi)
        cola90 = -np.sin(argsi)
        xi_0 = cola0 * xi
        xi_90 = cola90 * xi
        # Final filter states are not needed; they are rebuilt from history above.
        xi_0_filt, _ = lfilter(b, a, xi_0, zi=z0i)
        xi_90_filt, _ = lfilter(b, a, xi_90, zi=z90i)
        phase = np.arctan2(xi_90_filt, xi_0_filt)
        amplitude = 2 * np.sqrt(xi_0_filt**2 + xi_90_filt**2)
978
979
980class DefaultSineControlLaw:
981 """A default control law for the sine environment"""
982
983 def __unpickleable_fields__(self):
984 """Defines fields that can't be pickled in the case of an error"""
985 return ["tracking_filters"]
986
987 def __getstate__(self):
988 """Defines how the object is pickled"""
989 state = self.__dict__.copy()
990 for field in self.__unpickleable_fields__():
991 if field in state:
992 del state[field]
993 return state
994
995 def __setstate__(self, state):
996 """Defines how the object is restored from a pickle"""
997 self.__dict__.update(state)
998 for field in self.__unpickleable_fields__():
999 setattr(self, field, None)
1000
1001 def __init__(
1002 self,
1003 sample_rate, # Sample Rate of the data acquisition
1004 specifications, # Specification structured array
1005 output_oversample, # Oversampling required for output
1006 ramp_time, # Length of the ramp on the start and end of the signal
1007 convergence_factor, # Scale factor on the convergence of the closed-loop control
1008 block_size, # Size of writing blocks
1009 buffer_blocks, # Number of write blocks to keep in the buffer
1010 extra_control_parameters, # Required parameters
1011 sysid_frequency_spacing, # Frequency Spacing
1012 sysid_transfer_functions, # Transfer Functions
1013 sysid_response_noise, # Noise levels and correlation
1014 sysid_reference_noise, # from the system identification
1015 sysid_response_cpsd, # Response levels and correlation
1016 sysid_reference_cpsd, # from the system identification
1017 sysid_coherence, # Coherence from the system identification
1018 sysid_frames, # Number of frames in the FRF matrices
1019 ):
1020 """
1021 Initialize the sine control law
1022
1023 Parameters
1024 ----------
1025 sample_rate : float
1026 The sample rate of the data acquisition system.
1027 specifications : list of SineSpecification
1028 One SineSpecification object for each tone in the environment.
1029 output_oversample : int
1030 The oversample factor on the output.
1031 ramp_time : float
1032 The time to ramp up in level to start or ramp down to end.
1033 convergence_factor : float
1034 A value between 0 and 1 specifying how quickly the error correction
1035 should take place.
1036 block_size : int
1037 The number of samples generated for each analysis and error
1038 correction step
1039 buffer_blocks : int
1040 The number of blocks to keep in a buffer to ensure we don't run
1041 out of data to generate.
1042 extra_control_parameters : str
1043 A string containing any extr information the control law needs
1044 to know about.
1045 sysid_frequency_spacing : float
1046 The frequency spacing in the transfer function.
1047 sysid_transfer_functions : ndarray
1048 A 3D ndarray with shape num_freq, num_control, num_drive.
1049 sysid_response_noise : ndarray
1050 A 3D ndarray containing CPSDs with values populated from the
1051 noise floor check. Shape is num_freq, num_control, num_control.
1052 sysid_reference_noise : ndarray
1053 A 3D ndarray containing CPSDs with the values populated from the
1054 noise floor check. Shape is num_freq, num_drive, num_drive.
1055 sysid_response_cpsd : ndarray
1056 A 3D ndarray containing CPSDs with values populated from the
1057 system identification. Shape is num_freq, num_control, num_control.
1058 sysid_reference_cpsd : ndarray
1059 A 3D ndarray containing CPSDs with the values populated from the
1060 system identification. Shape is num_freq, num_drive, num_drive.
1061 sysid_coherence : ndarray
1062 Multiple coherence of the control channels from the system
1063 identification
1064 sysid_frames : int
1065 NNumber of frames used in the system identification.
1066
1067 """
1068 # start_time = time.time()
1069 self.block_size = block_size
1070 self.sample_rate = sample_rate
1071 self.buffer_blocks = buffer_blocks
1072 self.specifications = [spec.copy() for spec in specifications]
1073 self.output_oversample = output_oversample
1074 self.extra_control_parameters = extra_control_parameters
1075 self.ramp_samples = int(ramp_time * sample_rate) * output_oversample
1076 self.convergence_factor = convergence_factor
1077 # Loop through the different specifications and perform the initial control
1078 (
1079 self.specified_response,
1080 self.specified_tone_response,
1081 self.specified_frequency,
1082 self.specified_argument,
1083 self.specified_amplitude,
1084 self.specified_phase, # Radians
1085 _,
1086 _,
1087 ) = SineSpecification.create_combined_signals(
1088 self.specifications,
1089 self.sample_rate * self.output_oversample,
1090 self.ramp_samples,
1091 )
1092 self.specified_phase = self.specified_phase # Radians
1093 self.tone_slices = []
1094 for amp in self.specified_amplitude:
1095 nonzero_indices = np.any(amp != 0, axis=0)
1096 self.tone_slices.append(
1097 slice(
1098 np.argmax(nonzero_indices),
1099 nonzero_indices.size - np.argmax(nonzero_indices[::-1]),
1100 )
1101 )
1102
1103 # System ID Parameters
1104 self.frfs = None
1105 self.frf_frequency_spacing = None
1106 self.frf_frequencies = None
1107 self.sysid_response_noise = None
1108 self.sysid_reference_noise = None
1109 self.sysid_response_cpsd = None
1110 self.sysid_reference_cpsd = None
1111 self.sysid_coherence = None
1112 self.frames = None
1113 self.frf_pinv = None
1114 self.interpolated_frf_pinv = None
1115 self.largest_correction_factors = None
1116
1117 # Preshaped drive parameters
1118 self.preshaped_drive_amplitudes = None
1119 self.preshaped_drive_phases = None # Radians
1120 self.preshaped_drive_signals = None
1121
1122 # Control parameters
1123 self.start_index = None
1124 self.end_index = None
1125 self.signal_slice = None
1126 self.control_tones = None
1127 self.control_ramp_up = None
1128 self.control_ramp_down = None
1129 self.target_ramp_up = None
1130 self.target_ramp_down = None
1131 self.control_response_signals = None
1132 self.control_response_amplitudes = None
1133 self.control_response_phases = None # Radians
1134 self.control_drive_correction = None
1135 self.control_sent_complex_excitation = None
1136 self.control_analysis_index = None
1137 self.control_write_index = None
1138 self.max_singular_values = None
1139
1140 self.maximum_drive_voltage = None
1141 self.harddisk_storage = None
1142 for string in extra_control_parameters.split("\n"):
1143 if string.strip() == "":
1144 continue
1145 try:
1146 command, value = string.split("=")
1147 except ValueError:
1148 print(f'Unable to Parse Extra Parameters Line"{string:}"')
1149 continue
1150 command = command.strip()
1151 value = value.strip()
1152 if command in [
1153 "maximum_drive_voltage",
1154 "max_drive_voltage",
1155 "maximum_excitation_voltage",
1156 "max_excitation_voltage",
1157 ]:
1158 self.maximum_drive_voltage = float(value)
1159 print(f"Set Maximum Drive Voltage to {self.maximum_drive_voltage}")
1160 elif command in ["harddisk_storage"]:
1161 self.harddisk_storage = value
1162 else:
1163 print(f"Unknown extra parameter {command}")
1164
1165 if sysid_transfer_functions is not None:
1166 self.system_id_update(
1167 sysid_frequency_spacing, # Frequency Spacing
1168 sysid_transfer_functions, # Transfer Functions
1169 sysid_response_noise, # Noise levels and correlation
1170 sysid_reference_noise, # from the system identification
1171 sysid_response_cpsd, # Response levels and correlation
1172 sysid_reference_cpsd, # from the system identification
1173 sysid_coherence, # Coherence from the system identification
1174 sysid_frames, # Number of frames in the FRF matrices
1175 )
1176 # finish_time = time.time()
1177 # print(f'__init__ called in {finish_time - start_time:0.2f}s.')
1178
1179 def system_id_update(
1180 self,
1181 sysid_frequency_spacing, # Frequency Spacing
1182 sysid_transfer_functions, # Transfer Functions
1183 sysid_response_noise, # Noise levels and correlation
1184 sysid_reference_noise, # from the system identification
1185 sysid_response_cpsd, # Response levels and correlation
1186 sysid_reference_cpsd, # from the system identification
1187 sysid_coherence, # Coherence from the system identification
1188 sysid_frames, # Number of frames in the FRF matrices
1189 ):
1190 """
1191 Updates the control law after system identification has finished
1192
1193 Parameters
1194 ----------
1195 sysid_frequency_spacing : float
1196 The frequency spacing in the transfer function.
1197 sysid_transfer_functions : ndarray
1198 A 3D ndarray with shape num_freq, num_control, num_drive.
1199 sysid_response_noise : ndarray
1200 A 3D ndarray containing CPSDs with values populated from the
1201 noise floor check. Shape is num_freq, num_control, num_control.
1202 sysid_reference_noise : ndarray
1203 A 3D ndarray containing CPSDs with the values populated from the
1204 noise floor check. Shape is num_freq, num_drive, num_drive.
1205 sysid_response_cpsd : ndarray
1206 A 3D ndarray containing CPSDs with values populated from the
1207 system identification. Shape is num_freq, num_control, num_control.
1208 sysid_reference_cpsd : ndarray
1209 A 3D ndarray containing CPSDs with the values populated from the
1210 system identification. Shape is num_freq, num_drive, num_drive.
1211 sysid_coherence : ndarray
1212 Multiple coherence of the control channels from the system
1213 identification
1214 sysid_frames : int
1215 Number of frames used in the system identification.
1216
1217 Returns
1218 -------
1219 preshaped_drive_signals : ndarray
1220 Excitation signals with shape num_tones, num_drives, num_timesteps.
1221 specified_frequency : ndarray
1222 The instantaneous frequencies at each of the drive timesteps with
1223 shape num_tones, num_timesteps
1224 specified_argument : ndarray
1225 The instantaneous sine argument at each of the drive timesteps with
1226 shape num_tones, num_drives, num_timesteps
1227 preshaped_drive_amplitudes : ndarray
1228 The instantaneous amplitude at each of the drive timesteps with
1229 shape num_tones, num_drives, num_timesteps
1230 preshaped_drive_phases : ndarray
1231 The instantaneous phase in radians at each of the drive timesteps
1232 with shape num_tones, num_drives, num_timesteps.
1233 """
1234 # start_time = time.time()
1235 # print('Updating System ID Information')
1236 self.frf_frequency_spacing = sysid_frequency_spacing
1237 self.frfs = sysid_transfer_functions
1238 self.frf_frequencies = self.frf_frequency_spacing * np.arange(self.frfs.shape[0])
1239 # print('Inverting FRFs')
1240 self.frf_pinv = np.linalg.pinv(self.frfs)
1241 self.max_singular_values = np.max(
1242 np.linalg.svd(self.frfs, compute_uv=False, full_matrices=False), axis=-1
1243 )
1244 self.sysid_response_noise = sysid_response_noise
1245 self.sysid_reference_noise = sysid_reference_noise
1246 self.sysid_response_cpsd = sysid_response_cpsd
1247 self.sysid_reference_cpsd = sysid_reference_cpsd
1248 self.sysid_coherence = sysid_coherence
1249 self.frames = sysid_frames
1250
1251 # Go through and compute the response amplitudes and phases from each of the sine tones
1252 # print('Preallocating Amplitude, Phase, FRFs, and Correction Factors')
1253 if self.harddisk_storage is not None:
1254 filename = os.path.join(self.harddisk_storage, "preshaped_drive_amplitudes.mmap")
1255 shape = (
1256 self.specified_frequency.shape[0],
1257 self.frfs.shape[-1],
1258 self.specified_frequency.shape[-1],
1259 )
1260 self.preshaped_drive_amplitudes = np.memmap(
1261 filename, dtype=float, shape=shape, mode="w+"
1262 )
1263 filename = os.path.join(self.harddisk_storage, "preshaped_drive_phases.mmap")
1264 shape = (
1265 self.specified_frequency.shape[0],
1266 self.frfs.shape[-1],
1267 self.specified_frequency.shape[-1],
1268 )
1269 self.preshaped_drive_phases = np.memmap(filename, dtype=float, shape=shape, mode="w+")
1270 filename = os.path.join(self.harddisk_storage, "largest_correction_factors.mmap")
1271 shape = self.specified_frequency.shape
1272 self.largest_correction_factors = np.memmap(
1273 filename, dtype=float, shape=shape, mode="w+"
1274 )
1275 filename = os.path.join(self.harddisk_storage, "interpolated_frf_pinv.mmap")
1276 shape = (
1277 self.specified_frequency.shape[0],
1278 self.specified_frequency.shape[-1],
1279 ) + self.frf_pinv.shape[-2:]
1280 self.interpolated_frf_pinv = np.memmap(filename, dtype="c16", shape=shape, mode="w+")
1281 else:
1282 self.preshaped_drive_amplitudes = np.zeros(
1283 (
1284 self.specified_frequency.shape[0],
1285 self.frfs.shape[-1],
1286 self.specified_frequency.shape[-1],
1287 )
1288 )
1289 self.preshaped_drive_phases = np.zeros( # Radians
1290 (
1291 self.specified_frequency.shape[0],
1292 self.frfs.shape[-1],
1293 self.specified_frequency.shape[-1],
1294 )
1295 )
1296 self.largest_correction_factors = np.zeros(self.specified_frequency.shape)
1297 self.interpolated_frf_pinv = np.zeros(
1298 (
1299 self.specified_frequency.shape[0],
1300 self.specified_frequency.shape[-1],
1301 )
1302 + self.frf_pinv.shape[-2:],
1303 dtype="c16",
1304 )
1305 for tone_index, (freq, amp, phs, control_slice) in enumerate( # phs is radians
1306 zip(
1307 self.specified_frequency,
1308 self.specified_amplitude,
1309 self.specified_phase, # Radians
1310 self.tone_slices,
1311 )
1312 ):
1313 # print('Interpolating Tone {:}'.format(tone_index))
1314 control_amp = amp[..., control_slice]
1315 control_phs = phs[..., control_slice] # Radians
1316 control_freq = freq[..., control_slice]
1317 # Interpolate the pseudoinverse of the FRF
1318 for index in np.ndindex(*self.frf_pinv.shape[1:]):
1319 # print(' Interpolating Response {:}'.format(index))
1320 interpolated_index = (tone_index, control_slice) + index
1321 frf_pinv_index = (Ellipsis,) + index
1322 self.interpolated_frf_pinv[interpolated_index] = np.interp(
1323 control_freq, self.frf_frequencies, self.frf_pinv[frf_pinv_index]
1324 )
1325 # print('Computing Largest Correction Factors')
1326 self.largest_correction_factors[tone_index, control_slice] = (
1327 1 / np.interp(control_freq, self.frf_frequencies, self.max_singular_values) ** 2
1328 )
1329 # print('Computing Complex Response')
1330 complex_response = np.moveaxis(
1331 control_amp * np.exp(1j * control_phs), -1, 0
1332 )[ # Radians
1333 ..., np.newaxis
1334 ]
1335 # print('Computing Complex Excitation')
1336 complex_excitation = np.moveaxis(
1337 (self.interpolated_frf_pinv[tone_index, control_slice] @ complex_response)[..., 0],
1338 0,
1339 1,
1340 )
1341 # print('Extracting Excitation Amplitude and Phase')
1342 self.preshaped_drive_amplitudes[tone_index, :, control_slice] = np.abs(
1343 complex_excitation
1344 )
1345 self.preshaped_drive_phases[tone_index, :, control_slice] = np.unwrap(
1346 np.angle(complex_excitation)
1347 ) # Radians
1348
1349 if self.maximum_drive_voltage is not None:
1350 # print('Truncating for Maximum Voltage')
1351 self.preshaped_drive_amplitudes[
1352 self.preshaped_drive_amplitudes > self.maximum_drive_voltage
1353 ] = self.maximum_drive_voltage
1354 self.preshaped_drive_amplitudes[
1355 self.preshaped_drive_amplitudes < -self.maximum_drive_voltage
1356 ] = -self.maximum_drive_voltage
1357 # print('Computing Excitation Signal')
1358 self.preshaped_drive_signals = self.preshaped_drive_amplitudes * np.cos(
1359 self.specified_argument[:, np.newaxis, :] + self.preshaped_drive_phases # Radians
1360 )
1361
1362 if self.harddisk_storage is not None:
1363 # print('Flushing memmaps')
1364 self.preshaped_drive_amplitudes.flush()
1365 self.preshaped_drive_phases.flush() # Radians
1366 self.largest_correction_factors.flush()
1367 self.interpolated_frf_pinv.flush()
1368
1369 if DEBUG:
1370 print("Writing Sine Debug Pickle")
1371 with open("debug_data/sine_control_law_debug.pkl", "wb") as f:
1372 pickle.dump(self, f)
1373 print("Done!")
1374
1375 # finish_time = time.time()
1376 # print(f'system_id_update called in {finish_time - start_time:0.2f}s.')
1377
1378 return (
1379 self.preshaped_drive_signals,
1380 self.specified_frequency,
1381 self.specified_argument,
1382 self.preshaped_drive_amplitudes,
1383 self.preshaped_drive_phases, # Radians for return value
1384 )
1385
1386 def get_control_targets(self, block_start, block_end):
1387 """Gets up the control targets for a specified block, adding ramps
1388
1389 Parameters
1390 ----------
1391 block_start : int
1392 The starting index for the block of data that targets are computed from
1393 block_end : int
1394 The end index for the block of data that targets are computed from
1395
1396 Returns
1397 -------
1398 amplitudes
1399 The amplitudes over time including ramp up and ramp down
1400 phases
1401 The phases over time including ramp up and ramp down portions, in radians
1402 arguments
1403 The argument of the cosine wave over time including the ramp
1404 up and ramp down portions
1405 """
1406 ramp_up_start = block_start - self.ramp_samples
1407 if ramp_up_start >= 0:
1408 ramp_up_start = self.ramp_samples
1409 ramp_up_end = block_end - self.ramp_samples
1410 if ramp_up_end >= 0:
1411 ramp_up_end = self.ramp_samples
1412 ramp_down_start = block_start - self.end_index + self.start_index + self.ramp_samples
1413 if ramp_down_start < 0:
1414 ramp_down_start = 0
1415 ramp_down_end = block_end - self.end_index + self.start_index + self.ramp_samples
1416 if ramp_down_end < 0:
1417 ramp_down_end = 0
1418 middle_start = block_start
1419 if middle_start < self.ramp_samples:
1420 middle_start = self.ramp_samples
1421 if middle_start > self.end_index - self.start_index - self.ramp_samples:
1422 middle_start = self.end_index - self.start_index - self.ramp_samples
1423 middle_end = block_end
1424 if middle_end < self.ramp_samples:
1425 middle_end = self.ramp_samples
1426 if middle_end > self.end_index - self.start_index - self.ramp_samples:
1427 middle_end = self.end_index - self.start_index - self.ramp_samples
1428 amplitudes = np.concatenate(
1429 (
1430 self.target_ramp_up[..., ramp_up_start:ramp_up_end],
1431 self.specified_amplitude[
1432 self.control_tones,
1433 ...,
1434 self.start_index + middle_start : self.start_index + middle_end,
1435 ],
1436 self.target_ramp_down[..., ramp_down_start:ramp_down_end],
1437 ),
1438 axis=-1,
1439 )
1440 phases = self.specified_phase[ # Radians
1441 self.control_tones,
1442 ...,
1443 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1],
1444 ]
1445 arguments = self.specified_argument[
1446 self.control_tones,
1447 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1],
1448 ]
1449 return amplitudes, phases, arguments # Radians
1450
1451 def get_control_preshaped_excitations(self, block_start, block_end):
1452 """Gets the initial guess at an excitation signal over a portion of time
1453
1454 Parameters
1455 ----------
1456 block_start : int
1457 The starting index for the block of data that excitations are computed from
1458 block_end : int
1459 The end index for the block of data that excitations are computed from
1460
1461 Returns
1462 -------
1463 amplitudes
1464 The amplitudes over time including ramp up and ramp down
1465 phases
1466 The phases over time including ramp up and ramp down portions
1467 arguments
1468 The argument of the cosine wave over time including the ramp
1469 up and ramp down portions, in radians
1470 """
1471 ramp_up_start = block_start - self.ramp_samples
1472 if ramp_up_start >= 0:
1473 ramp_up_start = self.ramp_samples
1474 ramp_up_end = block_end - self.ramp_samples
1475 if ramp_up_end >= 0:
1476 ramp_up_end = self.ramp_samples
1477 ramp_down_start = block_start - self.end_index + self.start_index + self.ramp_samples
1478 if ramp_down_start < 0:
1479 ramp_down_start = 0
1480 ramp_down_end = block_end - self.end_index + self.start_index + self.ramp_samples
1481 if ramp_down_end < 0:
1482 ramp_down_end = 0
1483 middle_start = block_start
1484 if middle_start < self.ramp_samples:
1485 middle_start = self.ramp_samples
1486 if middle_start > self.end_index - self.start_index - self.ramp_samples:
1487 middle_start = self.end_index - self.start_index - self.ramp_samples
1488 middle_end = block_end
1489 if middle_end < self.ramp_samples:
1490 middle_end = self.ramp_samples
1491 if middle_end > self.end_index - self.start_index - self.ramp_samples:
1492 middle_end = self.end_index - self.start_index - self.ramp_samples
1493 amplitudes = np.concatenate(
1494 (
1495 self.control_ramp_up[..., ramp_up_start:ramp_up_end],
1496 self.preshaped_drive_amplitudes[
1497 self.control_tones,
1498 ...,
1499 self.start_index + middle_start : self.start_index + middle_end,
1500 ],
1501 self.control_ramp_down[..., ramp_down_start:ramp_down_end],
1502 ),
1503 axis=-1,
1504 )
1505 phases = self.preshaped_drive_phases[ # Radians
1506 self.control_tones,
1507 ...,
1508 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1],
1509 ]
1510 arguments = self.specified_argument[
1511 self.control_tones,
1512 self.start_index + block_start : self.start_index + block_start + amplitudes.shape[-1],
1513 ]
1514 return amplitudes, phases, arguments # Phase in Radians
1515
1516 def initialize_control(self, control_tones, start_index, end_index):
1517 """
1518 Initializes the control and creates a preshaped drive signal
1519
1520 Aguments are provided to specify which tones and portion of time
1521 to generate the signal over
1522
1523 Parameters
1524 ----------
1525 control_tones : ndarray or slice
1526 Indicies into the specifications to determine which control tones
1527 should be used.
1528 start_index : int
1529 The starting time step index.
1530 end_index : int
1531 The ending time step index.
1532
1533 Returns
1534 -------
1535 excitation_signals : ndarray
1536 The drive signal at each shaker over time
1537
1538 """
1539 # start_time = time.time()
1540 # Parse the frequency content to get the portion of the signal we care
1541 # about
1542 self.control_tones = control_tones
1543 self.start_index = start_index
1544 self.end_index = self.preshaped_drive_signals.shape[-1] if end_index is None else end_index
1545
1546 if DEBUG:
1547 print("Writing Sine Debug Pickle")
1548 with open("debug_data/sine_control_law_initialize_control_debug.pkl", "wb") as f:
1549 pickle.dump(self, f)
1550 print("Done!")
1551
1552 # Set up the analysis and write_indices
1553 self.control_analysis_index = 0
1554 self.control_write_index = self.ramp_samples + self.buffer_blocks * self.block_size
1555
1556 # Set up the ramp-ups and ramp downs for the excitation signal
1557 self.control_ramp_up = (
1558 np.linspace(0, 1, self.ramp_samples)
1559 * self.preshaped_drive_amplitudes[
1560 self.control_tones,
1561 ...,
1562 self.start_index + self.ramp_samples,
1563 np.newaxis,
1564 ]
1565 )
1566 self.control_ramp_down = (
1567 np.linspace(1, 0, self.ramp_samples)
1568 * self.preshaped_drive_amplitudes[
1569 self.control_tones, ..., self.end_index - self.ramp_samples, np.newaxis
1570 ]
1571 )
1572 self.target_ramp_up = (
1573 np.linspace(0, 1, self.ramp_samples)
1574 * self.specified_amplitude[
1575 self.control_tones,
1576 ...,
1577 self.start_index + self.ramp_samples,
1578 np.newaxis,
1579 ]
1580 )
1581 self.target_ramp_down = (
1582 np.linspace(1, 0, self.ramp_samples)
1583 * self.specified_amplitude[
1584 self.control_tones, ..., self.end_index - self.ramp_samples, np.newaxis
1585 ]
1586 )
1587
1588 (
1589 starting_drive_amplitudes,
1590 starting_drive_phases,
1591 starting_arguments,
1592 ) = self.get_control_preshaped_excitations( # Radians
1593 0, self.control_write_index
1594 ) # Radians
1595
1596 complex_excitation = starting_drive_amplitudes * (
1597 np.exp(1j * starting_drive_phases)
1598 ) # Radians
1599 excitation_signals = np.sum(
1600 starting_drive_amplitudes
1601 * np.cos(starting_drive_phases + starting_arguments[:, np.newaxis, :]), # Radians
1602 axis=0,
1603 )
1604
1605 # Set up control parameters
1606 self.control_drive_correction = np.zeros(starting_drive_amplitudes.shape[:2], dtype=complex)
1607
1608 # Set up the amplitude and phase tracking
1609 self.control_response_amplitudes = []
1610 self.control_response_phases = [] # Radians
1611 self.control_response_signals = []
1612 self.control_sent_complex_excitation = []
1613
1614 if DEBUG:
1615 print("Writing Sine Debug Pickle")
1616 with open("debug_data/sine_control_law_debug.pkl", "wb") as f:
1617 pickle.dump(self, f)
1618 print("Done!")
1619
1620 # Sending ramp and first two blocks to start, so add them to the list of blocks sent.
1621 self.control_sent_complex_excitation.append(complex_excitation)
1622
1623 # finish_time = time.time()
1624 # print(f'initialize_control called in {finish_time - start_time:0.2f}s.')
1625 return excitation_signals
1626
1627 def update_control(
1628 self,
1629 control_signals,
1630 control_amplitudes,
1631 control_phases, # Radians
1632 control_frequencies, # pylint: disable=unused-argument
1633 time_delay, # pylint: disable=unused-argument
1634 ):
1635 """
1636 Updates the control parameters based on previous responses
1637
1638 Parameters
1639 ----------
1640 control_signals : ndarray
1641 Time histories acquired by the environment.
1642 control_amplitudes : ndarray
1643 Amplitudes extracted from the time signals with shape num_tones,
1644 num_channels, num_timesteps.
1645 control_phases : ndarray
1646 Phases extracted from the time signals in radians with shape
1647 num_tones, num_channels, num_timesteps
1648 control_frequencies : ndarray
1649 Instantaneous frequencies at the timesteps analyzed with shape
1650 num_tones, num_timesteps.
1651 time_delay : float
1652 Time delay computed between the acquisition and output signals,
1653 which can be used to adjust for phase drifts due to delays.
1654
1655 Returns
1656 -------
1657 drive_correction : ndarray
1658 A correction factor on the drive signals with shape num_tones,
1659 num_drives.
1660
1661 """
1662
1663 # start_time = time.time()
1664 self.control_response_signals.append(control_signals)
1665 self.control_response_amplitudes.append(control_amplitudes)
1666 self.control_response_phases.append(control_phases) # Radians
1667
1668 if DEBUG:
1669 print("Writing Sine Debug Pickle")
1670 with open("debug_data/sine_control_law_update_control_debug.pkl", "wb") as f:
1671 pickle.dump(self, f)
1672 print("Done!")
1673
1674 # Find the equivalent block in the signal
1675 block_start_index = self.control_analysis_index
1676 block_end_index = (
1677 control_signals.shape[-1] * self.output_oversample + self.control_analysis_index
1678 )
1679 if self.convergence_factor != 0:
1680 reduction_slice = slice(
1681 block_start_index + self.start_index,
1682 block_end_index + self.start_index,
1683 self.output_oversample,
1684 )
1685 # Compute the target of the current block
1686 target_response_amplitudes, target_response_phases, _ = ( # Radians
1687 self.get_control_targets(block_start_index, block_end_index)
1688 )
1689 complex_targets = target_response_amplitudes[..., :: self.output_oversample] * np.exp(
1690 1j * target_response_phases[..., :: self.output_oversample]
1691 )
1692 complex_achieved = control_amplitudes * np.exp(1j * control_phases) # Radians
1693 complex_error = (
1694 complex_targets - complex_achieved
1695 ) # Number of Tones x Num Responses x Num Freqs
1696 block_correction_factor = self.convergence_factor * np.min(
1697 self.largest_correction_factors[self.control_tones, ..., reduction_slice],
1698 axis=-1,
1699 keepdims=True,
1700 ) # Num Tones x 1
1701 block_frf = self.interpolated_frf_pinv[
1702 self.control_tones, reduction_slice, ...
1703 ] # Number of Tones x Num Freqs x Num Excitations x Num Responses
1704 block_drive_correction = ( # Number of Tones x Num Freqs x Num Excitations x 1
1705 block_frf # Number of Tones x Num Freqs x Num Excitations x Num Responses
1706 @ complex_error.transpose(0, 2, 1)[..., np.newaxis]
1707 ) # Number of Tones x Number Freqs x Num Responses x 1
1708 self.control_drive_correction = (
1709 self.control_drive_correction # Number of Tones x Number of Excitation Signals
1710 + block_correction_factor # Number of Tones x 1
1711 * np.mean(block_drive_correction[..., 0], axis=1)
1712 ) # Mean across frequency lines, Number of Tones x Num Excitations
1713 self.control_analysis_index = block_end_index
1714 # finish_time = time.time()
1715 # print(f'update_control called in {finish_time - start_time:0.2f}s.')
1716 return self.control_drive_correction
1717
1718 def generate_signal(self):
1719 """
1720 Generates the next portion of the signal during the control
1721 calculations
1722
1723 Returns
1724 -------
1725 excitation_signal : ndarray
1726 The next portion of the signal to generate, with shape num_drives,
1727 num_timesteps.
1728 done_controlling : bool
1729 A flag specifying that the entire signal has been generated, so no
1730 more control decisions should be made.
1731
1732 """
1733 if DEBUG:
1734 print("Writing Sine Debug Pickle")
1735 with open("debug_data/sine_control_law_generate_signal_debug.pkl", "wb") as f:
1736 pickle.dump(self, f)
1737 print("Done!")
1738
1739 # start_time = time.time()
1740 start_index = self.control_write_index
1741 end_index = self.control_write_index + self.block_size
1742 excitation_amplitudes, excitation_phases, excitation_arguments = ( # Radians
1743 self.get_control_preshaped_excitations(start_index, end_index)
1744 )
1745 complex_excitation = (
1746 excitation_amplitudes * np.exp(1j * excitation_phases)
1747 + self.control_drive_correction[..., np.newaxis]
1748 ) # Num tones x num signals x num freqs
1749 amplitudes = np.abs(complex_excitation)
1750 if self.maximum_drive_voltage is not None:
1751 over_indices = amplitudes > self.maximum_drive_voltage
1752 complex_excitation[over_indices] = (
1753 self.maximum_drive_voltage
1754 * complex_excitation[over_indices]
1755 / amplitudes[over_indices]
1756 )
1757 excitation_signals = np.abs(complex_excitation) * np.cos(
1758 excitation_arguments[:, np.newaxis, :] + np.angle(complex_excitation)
1759 )
1760 # Combine all tones into one signal
1761 excitation_signal = np.sum(excitation_signals, axis=0)
1762 # Store this value so we know what was output
1763 self.control_sent_complex_excitation.append(complex_excitation)
1764 # Check if we've exhausted all of our data
1765 done_controlling = end_index >= (self.end_index - self.start_index)
1766 self.control_write_index = end_index
1767
1768 # finish_time = time.time()
1769 # print(f'generate_signal called in {finish_time - start_time:0.2f}s.')
1770 return excitation_signal, done_controlling
1771
1772 def finalize_control(self):
1773 """
1774 A method to update the control based on previous results, generating a
1775 new preshaped drive signal
1776
1777 Returns
1778 -------
1779 preshaped_drive_signals : ndarray
1780 Excitation signals with shape num_tones, num_drives, num_timesteps.
1781 specified_frequency : ndarray
1782 The instantaneous frequencies at each of the drive timesteps with
1783 shape num_tones, num_timesteps
1784 specified_argument : ndarray
1785 The instantaneous sine argument at each of the drive timesteps with
1786 shape num_tones, num_drives, num_timesteps
1787 preshaped_drive_amplitudes : ndarray
1788 The instantaneous amplitude at each of the drive timesteps with
1789 shape num_tones, num_drives, num_timesteps
1790 preshaped_drive_phases : ndarray
1791 The instantaneous phase in radians at each of the drive timesteps
1792 with shape num_tones, num_drives, num_timesteps.
1793 ramp_samples : ndarray
1794 The number of ramp samples used in the signal
1795
1796 """
1797 return (
1798 self.preshaped_drive_signals,
1799 self.specified_frequency,
1800 self.specified_argument,
1801 self.preshaped_drive_amplitudes,
1802 self.preshaped_drive_phases, # Radians
1803 self.ramp_samples,
1804 )
1805
1806
def vold_kalman_filter(
    sample_rate,
    signal,
    arguments,
    filter_order=None,
    bandwidth=None,
    method=None,
    return_amp_phs=False,
    return_envelope=False,
    return_r=False,
):
    """
    Extract sinusoidal components from a signal using the second generation
    Vold-Kalman filter.

    Parameters
    ----------
    sample_rate : float
        The sample rate of the signal in Hz.
    signal : ndarray
        A 1D signal containing sinusoidal components that need to be
        extracted.
    arguments : ndarray
        A 2D array consisting of the arguments to the sinusoidal components of
        the form exp(1j*argument). This is the integral over time of the
        angular frequency, which can be approximated as
        2*np.pi*scipy.integrate.cumulative_trapezoid(frequencies,timesteps,initial=0)
        if `frequencies` is the frequency at each time step in Hz and
        `timesteps` is the vector of time steps in seconds. The number of
        rows is the number of different sinusoidal components to extract and
        the number of columns must equal the number of samples in `signal`.
        A 1D array is treated as a single component.
    filter_order : int, optional
        The order of the VK filter, which must be 1, 2, or 3. The default is
        2. Higher orders give a steeper low-pass roll-off.
    bandwidth : float or ndarray, optional
        The prescribed bandwidth of the filter in Hz. This is related to the
        filter selectivity parameter `r` in the literature. It is broadcast to
        the same shape as `arguments`, so it may vary per component and/or per
        time step. The default is the sample rate divided by 1000.
    method : str, optional
        Either 'single' or 'multi' (case-insensitive). In a 'single'
        solution, each sinusoidal component is solved independently without
        any coupling. This can be more efficient, but will result in errors if
        the frequencies of the sine waves cross. The 'multi' solution solves
        for all sinusoidal components simultaneously, resulting in a better
        estimate of crossing frequencies. The default is 'multi' when more
        than one component is requested and 'single' otherwise.
    return_amp_phs : bool, optional
        Returns the amplitude and phase of the reconstructed signals at each
        time step. Default is False
    return_envelope : bool, optional
        Returns the complex envelope and phasors at each time step. Default is
        False
    return_r : bool, optional
        Returns the computed selectivity parameters for the filter. Default is
        False

    Raises
    ------
    ValueError
        If `signal` is not 1D, if the number of columns in `arguments` does
        not match the number of samples in `signal`, if `method` is not
        'single' or 'multi', or if `filter_order` is not 1, 2, or 3.

    Returns
    -------
    reconstructed_signals : ndarray
        Returns a time history the same size as `signal` for each of the
        sinusoidal components solved for.
    reconstructed_amplitudes : ndarray
        Returns the amplitude over time for each of the sinusoidal components
        solved for. Only returned if return_amp_phs is True.
    reconstructed_phases : ndarray
        Returns the phase over time for each of the sinusoidal components
        solved for. Only returned if return_amp_phs is True.
    reconstructed_envelope : ndarray
        Returns the complex envelope `x` over time for each of the sinusoidal
        components solved for. Only returned if return_envelope is True.
    reconstructed_phasor : ndarray
        Returns the phasor `c` over time for each of the sinusoidal components
        solved for. Only returned if return_envelope is True.
    r : ndarray
        Returns the selectivity `r` over time for each of the sinusoidal
        components solved for. Only returned if return_r is True.

    When no optional output is requested, the reconstructed signals are
    returned directly; otherwise a list in the order above is returned.

    """
    # pylint: disable=invalid-name
    if filter_order is None:
        filter_order = 2

    if bandwidth is None:
        bandwidth = sample_rate / 1000

    # Make sure input data are numpy arrays (signal is only read, never
    # modified, so no copy is needed)
    signal = np.asarray(signal)
    if signal.ndim != 1:
        raise ValueError("`signal` must be a 1D array")
    arguments = np.atleast_2d(arguments)
    bandwidth = np.atleast_2d(bandwidth)
    bandwidth = np.broadcast_to(bandwidth, arguments.shape)
    relative_bandwidth = bandwidth / sample_rate

    # Extract some sizes to make sure everything is correctly sized
    n_samples = signal.shape[-1]

    n_orders_arg, n_arg = arguments.shape
    if n_arg != n_samples:
        raise ValueError(
            "Argument array must have identical number of columns as samples in signal"
        )

    if method is None:
        method = "multi" if n_orders_arg > 1 else "single"
    method = method.lower()
    if method not in ("multi", "single"):
        raise ValueError('`method` must be either "multi" or "single"')

    # Difference-operator stencil and selectivity r for the requested order
    if filter_order == 1:
        coefs = np.array([1, -1])
        r = np.sqrt((np.sqrt(2) - 1) / (2 * (1 - np.cos(np.pi * relative_bandwidth))))
    elif filter_order == 2:
        coefs = np.array([1, -2, 1])
        r = np.sqrt(
            (np.sqrt(2) - 1)
            / (
                6
                - 8 * np.cos(np.pi * relative_bandwidth)
                + 2 * np.cos(2 * np.pi * relative_bandwidth)
            )
        )
    elif filter_order == 3:
        coefs = np.array([1, -3, 3, -1])
        r = np.sqrt(
            (np.sqrt(2) - 1)
            / (
                20
                - 30 * np.cos(np.pi * relative_bandwidth)
                + 12 * np.cos(2 * np.pi * relative_bandwidth)
                - 2 * np.cos(3 * np.pi * relative_bandwidth)
            )
        )
    else:
        raise ValueError("filter order must be 1, 2, or 3")

    # Construct phasors to multiply the signals by
    phasor = np.exp(1j * arguments)

    # A is the (n_samples - filter_order) x n_samples finite-difference
    # operator: row i holds `coefs` in columns i .. i + filter_order.
    A = sparse.spdiags(
        np.tile(coefs, (n_samples, 1)).T,
        np.arange(filter_order + 1),
        n_samples - filter_order,
        n_samples,
    )
    identity = sparse.eye(n_samples)
    B = []
    for rvec in r:
        R = sparse.spdiags(rvec, 0, n_samples, n_samples)
        AR = A @ R
        B.append(AR.T @ AR + identity)

    if method == "multi":
        # This solves the multiple order approach, constructing a big matrix of
        # Bs on the diagonal and CHCs on the off-diagonals. We set up the
        # diagonal and upper-triangular blocks, then add the conjugate
        # transpose of the upper triangle to get the lower triangle.
        B_multi_diagonal = sparse.block_diag(B)
        # There are n_orders**2 blocks, n_orders of them on the diagonal, so
        # n_orders**2 - n_orders off-diagonal blocks, half of them in the
        # upper triangle. Each is diagonal, with one entry per time step.
        num_off_diags = (n_orders_arg**2 - n_orders_arg) // 2
        row_indices = np.zeros((n_samples, num_off_diags), dtype=int)
        col_indices = np.zeros((n_samples, num_off_diags), dtype=int)
        CHC = np.zeros((n_samples, num_off_diags), dtype="c16")
        # Column of the arrays above that the next off-diagonal block fills
        off_diagonal_index = 0
        for row_index in range(n_orders_arg):
            # Stay in the upper triangle: columns start after the diagonal
            for col_index in range(row_index + 1, n_orders_arg):
                row_indices[:, off_diagonal_index] = np.arange(
                    row_index * n_samples, (row_index + 1) * n_samples
                )
                col_indices[:, off_diagonal_index] = np.arange(
                    col_index * n_samples, (col_index + 1) * n_samples
                )
                CHC[:, off_diagonal_index] = phasor[row_index].conj() * phasor[col_index]
                off_diagonal_index += 1
        # Flatten the per-block arrays into COO triplets for a CSR matrix
        B_multi_utri = sparse.csr_matrix(
            (CHC.flatten(), (row_indices.flatten(), col_indices.flatten())),
            shape=B_multi_diagonal.shape,
        )

        # Hermitian completion; .conj().T works for both sparse matrices and
        # sparse arrays, unlike the matrix-only getH().
        B_multi = B_multi_diagonal + B_multi_utri + B_multi_utri.conj().T

        # Right hand side: phasor^H elementwise with the signal, stacked per order
        RHS = phasor.flatten().conj() * np.tile(signal, n_orders_arg)

        x_multi = linalg.spsolve(B_multi, RHS[:, np.newaxis])
        # Multiply by 2 to account for missing negative frequency components
        x = 2 * x_multi.reshape(n_orders_arg, -1)
    else:
        # This solves the single order approach. If the user has put in multiple
        # orders, it will solve them all independently instead of combining them
        # into a single larger solve.
        x = np.zeros((n_orders_arg, n_samples), dtype=np.complex128)
        for i, (phasor_i, B_i) in enumerate(zip(phasor, B)):
            # Right side: the conjugate phasor times the signal elementwise
            RHS = phasor_i.conj() * signal
            x[i] = 2 * linalg.spsolve(B_i, RHS)

    return_value = [np.real(x * phasor)]
    if return_amp_phs:
        return_value.extend([np.abs(x), np.angle(x)])
    if return_envelope:
        return_value.extend([x, phasor])
    if return_r:
        return_value.append(r)
    if len(return_value) == 1:
        return return_value[0]
    return return_value
2040
2041
def vold_kalman_filter_generator(
    sample_rate,
    num_orders,
    block_size,
    overlap=0.15,
    filter_order=None,
    bandwidth=None,
    method=None,
    buffer_size_factor=3,
):
    """
    Extracts sinusoidal information using a Vold-Kalman Filter

    This uses a windowed overlap-and-add process to solve for the signal while
    removing start and end effects of the filter. Prime the generator with
    ``next()`` (which yields ``(None, None, None)``), then drive it with
    ``send((xi, argsi, last_signal))``. Each send yields the next section of
    the results up until the overlap section. That section is held back and
    combined with the start of the following block.

    Parameters
    ----------
    sample_rate : float
        The sample rate of the signal in Hz.
    num_orders : int
        The number of orders that will be found in the signal
    block_size : int
        The size of the blocks used in the analysis.
    overlap : float, optional
        Fraction of the block size to overlap when computing the results. The
        default is 0.15. A value giving zero overlap samples is allowed.
    filter_order : int, optional
        The order of the VK filter, which should be 1, 2, or 3. The default is
        2. Higher orders give a steeper low-pass roll-off.
    bandwidth : ndarray, optional
        The prescribed bandwidth of the filter. This is related to the filter
        selectivity parameter `r` in the literature. This will be broadcast to
        the same shape as the arguments of each block. The default is the
        sample rate divided by 1000.
    method : str, optional
        Can be set to either 'single' or 'multi'. In a 'single' solution, each
        sinusoidal component will be solved independently without any coupling.
        This can be more efficient, but will result in errors if the
        frequencies of the sine waves cross. The 'multi' solution will solve
        for all sinusoidal components simultaneously, resulting in a better
        estimate of crossing frequencies. The default is 'multi' when more
        than one order is present and 'single' otherwise.
    buffer_size_factor : int, optional
        Specifies the size of the buffer. buffer_size_factor * block_size is
        the size of the buffer.

    Raises
    ------
    ValueError
        If arguments are not the correct size or values.
    ValueError
        If data is sent after a send that specified last_signal = True.

    Sends
    -----
    xi : array_like
        The next block of the signal to be filtered. This should be a 1D
        signal containing sinusoidal components that need to be extracted.
    argsi : array_like
        A 2D array consisting of the arguments to the sinusoidal components of
        the form exp(1j*argsi). This is the integral over time of the
        angular frequency, which can be approximated as
        2*np.pi*scipy.integrate.cumulative_trapezoid(frequencies,timesteps,initial=0)
        if `frequencies` is the frequency at each time step in Hz and
        `timesteps` is the vector of time steps in seconds. There is one row
        per order and one column per sample in `xi`.
    last_signal : bool
        If True, the remainder of the data will be returned and the
        overlap-and-add process will be finished.

    Yields
    ------
    reconstructed_signals : ndarray or None
        The reconstructed time history for each of the sinusoidal components.
    reconstructed_amplitudes : ndarray or None
        The amplitude over time for each of the sinusoidal components.
    reconstructed_phases : ndarray or None
        The phase over time for each of the sinusoidal components.
        All three are None when not enough data has been sent yet to produce
        a block.

    """
    previous_envelope = None
    reconstructed_signals = None
    reconstructed_amplitudes = None
    reconstructed_phases = None
    overlap_samples = int(overlap * block_size)
    # Halves of a periodic Hann window are complementary
    # (start_window + end_window == 1), so the tapered ends of adjacent blocks
    # sum back to unit weight after overlap-and-add.
    window = windows.hann(overlap_samples * 2, False)
    start_window = window[:overlap_samples]
    end_window = window[overlap_samples:]
    buffer = CircularBufferWithOverlap(
        buffer_size_factor * block_size,
        block_size,
        overlap_samples,
        data_shape=(num_orders + 1,),
    )
    first_output = True
    last_signal = False
    while True:
        xi, argsi, check_last_signal = (
            yield reconstructed_signals,
            reconstructed_amplitudes,
            reconstructed_phases,
        )
        # Any data after the final block would be overlap-added onto output
        # that has already been emitted, so refuse it.
        if last_signal:
            raise ValueError("Generator has been exhausted.")
        last_signal = check_last_signal
        argsi = np.atleast_2d(argsi)
        # Row 0 is the signal, the remaining rows are the per-order arguments
        buffer_data = np.concatenate([np.asarray(xi)[np.newaxis], argsi])
        buffer_output = buffer.write_get_data(buffer_data, last_signal)
        if buffer_output is None:
            # Not enough data for a block yet
            reconstructed_signals = None
            reconstructed_amplitudes = None
            reconstructed_phases = None
            continue

        if first_output:
            # The first read is preceded by overlap_samples of never-written
            # samples; drop them. There is no previous block to blend into, so
            # no start taper is applied.
            buffer_output = buffer_output[..., overlap_samples:]
            first_output = False
            signal = buffer_output[0]
        else:
            signal = buffer_output[0]
            signal[:overlap_samples] = signal[:overlap_samples] * start_window
        arguments = buffer_output[1:]
        if not last_signal and overlap_samples > 0:
            signal[-overlap_samples:] = signal[-overlap_samples:] * end_window

        # Do the VK Filtering
        _, vk_envelope, vk_phasor = vold_kalman_filter(
            sample_rate,
            signal,
            arguments,
            filter_order,
            bandwidth,
            method,
            return_envelope=True,
        )
        # Overlap-add the held-back tail of the previous block's envelope
        if previous_envelope is not None and overlap_samples > 0:
            vk_envelope[..., :overlap_samples] = (
                vk_envelope[..., :overlap_samples] + previous_envelope[..., -overlap_samples:]
            )
        # Hold back the tapered tail unless this is the final block. An explicit
        # length avoids `[:-0]` producing an empty slice when there is no overlap.
        n_keep = vk_envelope.shape[-1]
        if not last_signal:
            n_keep -= overlap_samples
        kept_envelope = vk_envelope[..., :n_keep]
        reconstructed_signals = np.real(kept_envelope * vk_phasor[..., :n_keep])
        reconstructed_amplitudes = np.abs(kept_envelope)
        reconstructed_phases = np.angle(kept_envelope)
        previous_envelope = vk_envelope
2206
2207
class CircularBufferWithOverlap:
    """
    A circular buffer that allows data to be added and removed from the
    buffer with overlap.

    Data are written at ``write_index`` and read from ``read_index``. Each
    read also returns the ``overlap_size`` samples just before ``read_index``,
    so consecutive reads share that many samples. Writes always start at
    ``write_index`` and reads at ``read_index``, so the samples written but
    not yet read form the contiguous run ``[read_index, write_index)``
    (modulo ``buffer_size``).
    """

    def __init__(self, buffer_size, block_size, overlap_size, dtype="float", data_shape=()):
        """Initialize the circular buffer

        Parameters
        ----------
        buffer_size : int
            The total size of the circular buffer
        block_size : int
            The number of samples written or read in each block
        overlap_size : int
            The number of samples from the previous block to include in the read.
        dtype : dtype, optional
            The type of data in the buffer. The default is "float".
        data_shape : tuple, optional
            The shape of data in the buffer. The default is ().
        """
        self.buffer_size = buffer_size
        self.block_size = block_size
        self.overlap_size = overlap_size
        # Leading axes follow data_shape; the last axis is the sample axis
        self.buffer = np.zeros(tuple(data_shape) + (buffer_size,), dtype=dtype)
        # True for slots whose data has been read (or never written)
        self.buffer_read = np.ones((buffer_size,), dtype=bool)
        self.write_index = 0  # Index where the next block will be written
        self.read_index = 0  # Index where the next block will be read from
        self.debug = False
        if self.debug:
            self.report_buffer_state()

    def report_buffer_state(self):
        """Prints the current buffer state"""
        read_samples = np.sum(self.buffer_read)
        write_samples = self.buffer_size - read_samples
        print(f"{read_samples} of {self.buffer_size} have been read")
        print(f"{write_samples} of {self.buffer_size} have been written but not read")

    def write_get_data(self, data, read_remaining=False):
        """
        Writes a block of data and then returns a block if available

        Parameters
        ----------
        data : ndarray
            Array to write to the buffer; the last axis is the sample axis.
        read_remaining : bool, optional
            Passed to `read`: if True, return everything not yet read instead
            of a single block. The default is False.

        Returns
        -------
        return_data : ndarray or None
            The data read from the buffer, or None if a full block is not yet
            available.

        Raises
        ------
        ValueError
            If the write itself fails (see `write`).
        """
        self.write(data)
        try:
            return self.read(read_remaining)
        except ValueError:
            # Not enough unread data for a block yet
            return None

    def write(self, data):
        """
        Write a block of data to the circular buffer.

        Parameters
        ----------
        data : ndarray
            Array to write to the buffer; the last axis is the sample axis.

        Raises
        ------
        ValueError
            If the data are longer than the buffer, or if the write would
            overwrite unread data or the overlap samples the next read re-uses.
        """
        n_samples = data.shape[-1]
        if n_samples > self.buffer_size:
            # Indices would wrap onto themselves and silently corrupt the data
            raise ValueError("Data block is larger than the buffer.")
        # The write span plus overlap_size extra slots must all be free. The
        # extra slots stop the write from reaching the overlap samples that
        # sit just before read_index.
        indices = (
            np.arange(self.write_index, self.write_index + n_samples + self.overlap_size)
            % self.buffer_size
        )

        if np.any(~self.buffer_read[indices]):
            raise ValueError(
                "Overwriting data on buffer that has not been read. "
                "Read data before writing again."
            )

        write_indices = indices[:n_samples]
        self.buffer[..., write_indices] = data
        self.buffer_read[write_indices] = False

        # Update the write index
        self.write_index = (self.write_index + n_samples) % self.buffer_size

        if self.debug:
            print("Wrote Data to Buffer")
            self.report_buffer_state()

    def read(self, read_remaining=False):
        """
        Reads data from the buffer

        Parameters
        ----------
        read_remaining : bool, optional
            If true, read everything left on the buffer that hasn't yet been
            read, however many samples that is. The default is False, which
            reads exactly one block.

        Raises
        ------
        ValueError
            If there is not a block of data on the buffer and data would be
            read a second time.

        Returns
        -------
        return_data : ndarray
            The overlap_size samples preceding the read position followed by
            the newly read samples.

        """
        if read_remaining:
            # All unread samples, i.e. the contiguous run starting at read_index
            read_length = int(np.count_nonzero(~self.buffer_read))
        else:
            read_length = self.block_size
        indices = (
            np.arange(self.read_index - self.overlap_size, self.read_index + read_length)
            % self.buffer_size
        )
        new_indices = indices[self.overlap_size :]
        if np.any(self.buffer_read[new_indices]):
            raise ValueError("Data would be read multiple times. Write data before reading again.")
        return_data = self.buffer[..., indices]
        self.buffer_read[new_indices] = True
        self.read_index = (self.read_index + read_length) % self.buffer_size
        if self.debug:
            print("Read Data from Buffer")
            self.report_buffer_state()
        return return_data
2344
2345
2346class SineSpecification:
2347 """A class representing a sine specification"""
2348
2349 def __init__(
2350 self,
2351 name,
2352 start_time,
2353 num_control,
2354 num_breakpoints=None,
2355 frequency_breakpoints=None,
2356 amplitude_breakpoints=None,
2357 phase_breakpoints=None,
2358 sweep_type_breakpoints=None,
2359 sweep_rate_breakpoints=None,
2360 warning_breakpoints=None,
2361 abort_breakpoints=None,
2362 ):
2363 """
2364 Initializes the sine specification
2365
2366 Parameters
2367 ----------
2368 name : str
2369 Name of the sine tone.
2370 start_time : float
2371 The starting time of the sine tone.
2372 num_control : int
2373 The number of control channels in the specification.
2374 num_breakpoints : int, optional
2375 The number of frequency breakpoints in the specification. Either
2376 this or frequency_breakpoints must be specified.
2377 frequency_breakpoints : ndarray, optional
2378 The frequency breakpoints in the specification. Either this or
2379 num_breakpoints must be specified.
2380 amplitude_breakpoints : ndarray, optional
2381 The amplitude breakpoints of the specification, with shape num_freq,
2382 num_channels. If not specified, amplitudes will be 0.
2383 phase_breakpoints : ndarray, optional
2384 The phase breakpoints of the specification, with shape num_freq,
2385 num_channels. If not specified, phase will be 0. Phases should be
2386 in radians.
2387 sweep_type_breakpoints : ndarray, optional
2388 Should be a 0 if linear sweep or 1 if logarithmic sweep. Linear if
2389 not specified.
2390 sweep_rate_breakpoints : ndarray, optional
2391 Sweep rate at each breakpoint. Hz/s if linear and oct/min if
2392 logarithmic sweep
2393 warning_breakpoints : ndarray, optional
2394 A 4D array of warning amplitudes with shape num_freq, 2, 2,
2395 num_channels. The second dimension uses the 0 index for the lower
2396 warning limit and the 1 index for the upper warning limit. The
2397 third dimension uses the 0 index for the "left" or "lower frequency"
2398 side of the breakpoint and 1 for the "right" or "higher frequency"
2399 side of the breakpoing. If not specified, no warnings will be
2400 specified.
2401 abort_breakpoints : ndarray, optional
2402 A 4D array of abort amplitudes with shape num_freq, 2, 2,
2403 num_channels. The second dimension uses the 0 index for the lower
2404 warning limit and the 1 index for the upper warning limit. The
2405 third dimension uses the 0 index for the "left" or "lower frequency"
2406 side of the breakpoint and 1 for the "right" or "higher frequency"
2407 side of the breakpoing. If not specified, no aborts will be
2408 specified.
2409
2410 Raises
2411 ------
2412 ValueError
2413 if not one of frequency_breakpoints or num_breakpoints is specified
2414 """
2415 spec_dtype = [
2416 ("frequency", "f8"),
2417 ("amplitude", "f8", (num_control,)),
2418 ("phase", "f8", (num_control,)),
2419 ("sweep_type", "u1"),
2420 ("sweep_rate", "f8"),
2421 ("warning", "f8", (2, 2, num_control)),
2422 ("abort", "f8", (2, 2, num_control)),
2423 ]
2424 if frequency_breakpoints is None and num_breakpoints is None:
2425 raise ValueError("Must specify either number of breakpoints or breakpoint frequencies.")
2426 if frequency_breakpoints is None:
2427 self.breakpoint_table = np.zeros(num_breakpoints, dtype=spec_dtype)
2428 else:
2429 self.breakpoint_table = np.zeros(frequency_breakpoints.shape[0], dtype=spec_dtype)
2430 self.breakpoint_table["frequency"] = frequency_breakpoints
2431 if amplitude_breakpoints is not None:
2432 self.breakpoint_table["amplitude"] = amplitude_breakpoints
2433 if phase_breakpoints is not None:
2434 self.breakpoint_table["phase"] = phase_breakpoints # Radians
2435 if sweep_type_breakpoints is not None:
2436 self.breakpoint_table["sweep_type"][:-1] = sweep_type_breakpoints
2437 if sweep_rate_breakpoints is not None:
2438 self.breakpoint_table["sweep_rate"][:-1] = sweep_rate_breakpoints
2439 if warning_breakpoints is not None:
2440 self.breakpoint_table["warning"] = warning_breakpoints
2441 else:
2442 self.breakpoint_table["warning"] = np.nan
2443 if abort_breakpoints is not None:
2444 self.breakpoint_table["abort"] = abort_breakpoints
2445 else:
2446 self.breakpoint_table["abort"] = np.nan
2447 self.start_time = start_time
2448 self.name = name
2449
2450 def copy(self):
2451 """Creates a copy of the sine specification"""
2452 return SineSpecification(
2453 self.name,
2454 self.start_time,
2455 self.breakpoint_table["amplitude"].shape[-1],
2456 frequency_breakpoints=self.breakpoint_table["frequency"].copy(),
2457 amplitude_breakpoints=self.breakpoint_table["amplitude"].copy(),
2458 phase_breakpoints=self.breakpoint_table["phase"].copy(),
2459 sweep_type_breakpoints=self.breakpoint_table["sweep_type"][:-1].copy(),
2460 sweep_rate_breakpoints=self.breakpoint_table["sweep_rate"][:-1].copy(),
2461 warning_breakpoints=self.breakpoint_table["warning"].copy(),
2462 abort_breakpoints=self.breakpoint_table["abort"].copy(),
2463 )
2464
2465 def create_signal(
2466 self,
2467 sample_rate,
2468 ramp_samples=0,
2469 control_index=None,
2470 ignore_start_time=False,
2471 only_breakpoints=False,
2472 ):
2473 """
2474 Creates a signal from the sine specification
2475
2476 Parameters
2477 ----------
2478 sample_rate : float
2479 The sample rate of the signal generated
2480 ramp_samples : int, optional
2481 The number of samples to add to the start and end due to ramp up
2482 or ramp down. The default is 0.
2483 control_index : int, optional
2484 The channel index to generate a signal for. If not specified,
2485 all channels will be generated.
2486 ignore_start_time : bool, optional
2487 If True, ignore the start time and have the sine sweep start
2488 immediately in the signal. The default is False.
2489 only_breakpoints : bool, optional
2490 If True, only generate values at the breakpoints. The default is
2491 False.
2492
2493 Returns
2494 -------
2495 ordinate : ndarray
2496 The generated signals
2497 frequency : ndarray
2498 The instantaneous frequency at each time step
2499 argument : ndarray
2500 The instantaneous argument at each time step.
2501 amplitude : ndarray
2502 The instantaneous amplitude at each timestep.
2503 phase : ndarray
2504 The instantaneous phase at each timestep in radians.
2505 abscissa : ndarray
2506 The abscissa at each time step.
2507 start_index : int
2508 The sample at which the specification starts taking into account
2509 the start time and the ramp samples.
2510 end_index : int
2511 The sample at which the specification ends taking into accound the
2512 ramp samples.
2513
2514 """
2515 # Convert octave per min to octave per second
2516 sweep_rates = self.breakpoint_table["sweep_rate"].copy()
2517 sweep_rates[self.breakpoint_table["sweep_type"] == 1] = (
2518 sweep_rates[self.breakpoint_table["sweep_type"] == 1] / 60
2519 )
2520 # Create the sweep types array
2521 sweep_types = [
2522 "lin" if sweep_type == 0 else "log"
2523 for sweep_type in self.breakpoint_table["sweep_type"][:-1]
2524 ]
2525 if control_index is None:
2526 ordinate = []
2527 amplitude = []
2528 phase = []
2529 for control_index in range(self.breakpoint_table["amplitude"].shape[-1]):
2530 (
2531 this_ordinate,
2532 argument,
2533 frequency,
2534 this_amplitude,
2535 this_phase,
2536 abscissa,
2537 ) = sine_sweep(
2538 1 / sample_rate,
2539 self.breakpoint_table["frequency"],
2540 sweep_rates,
2541 sweep_types,
2542 self.breakpoint_table["amplitude"][:, control_index],
2543 self.breakpoint_table["phase"][:, control_index],
2544 return_frequency=True,
2545 return_argument=True,
2546 return_amplitude=True,
2547 return_phase=True,
2548 return_abscissa=True,
2549 only_breakpoints=only_breakpoints,
2550 )
2551 ordinate.append(this_ordinate)
2552 amplitude.append(this_amplitude)
2553 phase.append(this_phase)
2554 ordinate = np.array(ordinate)
2555 amplitude = np.array(amplitude)
2556 phase = np.array(phase)
2557 else:
2558 ordinate, argument, frequency, amplitude, phase, abscissa = sine_sweep(
2559 1 / sample_rate,
2560 self.breakpoint_table["frequency"],
2561 sweep_rates,
2562 sweep_types,
2563 self.breakpoint_table["amplitude"][:, control_index],
2564 self.breakpoint_table["phase"][:, control_index],
2565 return_frequency=True,
2566 return_argument=True,
2567 return_amplitude=True,
2568 return_phase=True,
2569 return_abscissa=True,
2570 only_breakpoints=only_breakpoints,
2571 )
2572
2573 if ignore_start_time:
2574 delay_samples = 0
2575 else:
2576 delay_samples = int(sample_rate * self.start_time)
2577 start_index = ramp_samples + delay_samples
2578 if ramp_samples > 0 or delay_samples > 0:
2579 # Create the pre-signal ramp
2580 begin_abscissa = np.arange(-ramp_samples - delay_samples, 0) / sample_rate
2581 begin_arguments = 2 * np.pi * frequency[0] * begin_abscissa + argument[0]
2582 begin_frequencies = np.ones(ramp_samples + delay_samples) * frequency[0]
2583 begin_amplitudes = (
2584 np.concatenate((np.zeros(delay_samples), np.linspace(0, 1, ramp_samples)))
2585 * amplitude[..., [0]]
2586 )
2587 begin_phases = np.ones(ramp_samples + delay_samples) * phase[..., [0]]
2588 begin_signal = begin_amplitudes * np.cos(begin_arguments + begin_phases)
2589
2590 abscissa = np.concatenate((begin_abscissa, abscissa), axis=-1)
2591 ordinate = np.concatenate((begin_signal, ordinate), axis=-1)
2592 frequency = np.concatenate((begin_frequencies, frequency), axis=-1)
2593 amplitude = np.concatenate((begin_amplitudes, amplitude), axis=-1)
2594 phase = np.concatenate((begin_phases, phase), axis=-1)
2595 argument = np.concatenate((begin_arguments, argument), axis=-1)
2596 if ramp_samples > 0:
2597 end_abscissa = np.arange(1, ramp_samples + 1) / sample_rate
2598 end_arguments = 2 * np.pi * frequency[-1] * end_abscissa + argument[-1]
2599 end_frequencies = np.ones(ramp_samples) * frequency[-1]
2600 end_amplitudes = np.linspace(1, 0, ramp_samples) * amplitude[..., [-1]]
2601 end_phases = np.ones(ramp_samples) * phase[..., [-1]]
2602 end_signal = end_amplitudes * np.cos(end_arguments + end_phases)
2603
2604 abscissa = np.concatenate((abscissa, end_abscissa), axis=-1)
2605 ordinate = np.concatenate((ordinate, end_signal), axis=-1)
2606 frequency = np.concatenate((frequency, end_frequencies), axis=-1)
2607 argument = np.concatenate((argument, end_arguments), axis=-1)
2608 amplitude = np.concatenate((amplitude, end_amplitudes), axis=-1)
2609 phase = np.concatenate((phase, end_phases), axis=-1)
2610 end_index = abscissa.shape[-1] - ramp_samples
2611
2612 return (
2613 ordinate,
2614 frequency,
2615 argument,
2616 amplitude,
2617 phase,
2618 abscissa,
2619 start_index,
2620 end_index,
2621 )
2622
2623 def interpolate_warning(self, channel_index, frequencies):
2624 """
2625 Interpolates the warning array at the specified frequencies
2626
2627 Parameters
2628 ----------
2629 channel_index : int
2630 The channel to compute the warning levels with.
2631 frequencies : ndarray
2632 The frequencies at which to compute the warning levels.
2633
2634 Returns
2635 -------
2636 warning_levels
2637 A 2 x num_frequencies array containing the warning levels. The
2638 first index is the lower warning level and the second index is the
2639 upper warning level. If warnings are not specified for certain
2640 values, they will be set to NaN.
2641
2642 """
2643 abscissa = np.repeat(self.breakpoint_table["frequency"], 2)
2644 lower_ordinate = self.breakpoint_table["warning"][:, 0, :, channel_index].flatten()
2645 upper_ordinate = self.breakpoint_table["warning"][:, 1, :, channel_index].flatten()
2646 return np.array(
2647 [
2648 np.interp(frequencies, abscissa, lower_ordinate),
2649 np.interp(frequencies, abscissa, upper_ordinate),
2650 ]
2651 )
2652
2653 def interpolate_abort(self, channel_index, frequencies):
2654 """
2655 Interpolates the abort array at the specified frequencies
2656
2657 Parameters
2658 ----------
2659 channel_index : int
2660 The channel to compute the abort levels with.
2661 frequencies : ndarray
2662 The frequencies at which to compute the abort levels.
2663
2664 Returns
2665 -------
2666 abort_levels
2667 A 2 x num_frequencies array containing the abort levels. The
2668 first index is the lower warning level and the second index is the
2669 upper warning level. If warnings are not specified for certain
2670 values, they will be set to NaN.
2671
2672 """
2673 abscissa = np.repeat(self.breakpoint_table["frequency"], 2)
2674 lower_ordinate = self.breakpoint_table["abort"][:, 0, :, channel_index].flatten()
2675 upper_ordinate = self.breakpoint_table["abort"][:, 1, :, channel_index].flatten()
2676 return np.array(
2677 [
2678 np.interp(frequencies, abscissa, lower_ordinate),
2679 np.interp(frequencies, abscissa, upper_ordinate),
2680 ]
2681 )
2682
2683 @staticmethod
2684 def structured_array_equal(arr1, arr2):
2685 """
2686 A method to check if two sine specification breakpoint tables are equal
2687
2688 Parameters
2689 ----------
2690 arr1 : ndarray
2691 A structured array representing a breakpoint table.
2692 arr2 : ndarray
2693 A structured array representing a breakpoint table.
2694
2695 Returns
2696 -------
2697 bool
2698 True if the two arrays are equal. False otherwise.
2699
2700 """
2701 if arr1.dtype != arr2.dtype:
2702 # print('DTypes Not Equal')
2703 return False
2704 for field in arr1.dtype.names:
2705 field1 = arr1[field]
2706 field2 = arr2[field]
2707 if not np.array_equal(field1, field2, equal_nan=True):
2708 # print(f'Field {field} Not Equal')
2709 # print(field1)
2710 # print(field2)
2711 return False
2712 return True
2713
2714 def __eq__(self, other):
2715 """
2716 A method to check if two sine specifications are equal.
2717
2718 Parameters
2719 ----------
2720 other : SineSpecification
2721 The SineSpecification object to compare against.
2722
2723 Returns
2724 -------
2725 bool
2726 True if the two SineSpecification objects are equal.
2727
2728 """
2729 if not SineSpecification.structured_array_equal(
2730 self.breakpoint_table, other.breakpoint_table
2731 ):
2732 return False
2733 if self.start_time != other.start_time:
2734 return False
2735 return True
2736
2737 @staticmethod
2738 def create_combined_signals(specifications, sample_rate, ramp_samples, control_index=None):
2739 """
2740 Creates a combined signal from many specifications
2741
2742 Parameters
2743 ----------
2744 specifications : list of SineSpecification objects
2745 The various specification objects to combine together to creat the
2746 combined signals.
2747 sample_rate : float
2748 The sample rate at which the signals should be generated.
2749 ramp_samples : int
2750 The number of samples to use in the ramp portion of the signals.
2751 control_index : int, optional
2752 The control channel index at which the signals should be generated.
2753 The default is to generate all channels.
2754
2755 Returns
2756 -------
2757 signal : ndarray
2758 A number of channels by number of timesteps array of time history
2759 values.
2760 order_signals : ndarray
2761 Separate signals for each of the specification tones, in a
2762 num_tones, num_channels, num_timesteps array.
2763 order_frequencies : ndarray
2764 The frequencies associated with each sine tone at each time step
2765 in a num_tones, num_timesteps array.
2766 order_arguments : ndarray
2767 The arguments associated with each sine tone at each time step in a
2768 num_tones, num_timestep array
2769 order_amplitudes : ndarray
2770 The instantaneous amplitudes associated with each sine tone and
2771 channel at each time step in a num_tones, num_channels,
2772 num_timesteps shaped array
2773 order_phases : ndarray
2774 The instantaneous phase associated with each sine tone and
2775 channel at each time step in a num_tones, num_channels,
2776 num_timesteps shaped array. Phases in Radians.
2777 order_start_samples : ndarray
2778 The starting sample for each tone taking into account the start time
2779 and the ramp samples
2780 order_end_samples : ndarray
2781 The end sample for each tone taking into account the ramp samples.
2782
2783 """
2784 order_signals = []
2785 order_arguments = []
2786 order_frequencies = []
2787 order_amplitudes = []
2788 order_phases = []
2789 order_start_samples = []
2790 order_end_samples = []
2791 longest_signal = 0
2792 for spec in specifications:
2793 (
2794 ordinate,
2795 frequency,
2796 argument,
2797 amplitude,
2798 phase,
2799 _,
2800 start_index,
2801 end_index,
2802 ) = spec.create_signal(sample_rate, ramp_samples, control_index)
2803 order_signals.append(ordinate)
2804 order_frequencies.append(frequency)
2805 order_amplitudes.append(amplitude)
2806 order_phases.append(phase)
2807 order_arguments.append(argument)
2808 order_start_samples.append(start_index)
2809 order_end_samples.append(end_index)
2810
2811 if order_signals[-1].shape[-1] > longest_signal:
2812 longest_signal = order_signals[-1].shape[-1]
2813
2814 # Now that we know the longest signals, we know how much we need to pad
2815 # to make all signals the same length
2816 for i, signal in enumerate(order_signals):
2817 extra_samples = longest_signal - signal.shape[-1]
2818 end_abscissa = np.arange(1, extra_samples + 1) / sample_rate
2819 end_arguments = (
2820 2 * np.pi * order_frequencies[i][-1] * end_abscissa + order_arguments[i][-1]
2821 )
2822 end_frequencies = np.ones(extra_samples) * order_frequencies[i][-1]
2823 end_amplitudes = np.zeros((extra_samples)) * order_amplitudes[i][..., [-1]]
2824 end_phases = np.ones(extra_samples) * order_phases[i][..., [-1]]
2825 end_signal = np.zeros(extra_samples) * signal[..., [-1]]
2826 order_signals[i] = np.concatenate((order_signals[i], end_signal), axis=-1)
2827 order_frequencies[i] = np.concatenate((order_frequencies[i], end_frequencies), axis=-1)
2828 order_arguments[i] = np.concatenate((order_arguments[i], end_arguments), axis=-1)
2829 order_amplitudes[i] = np.concatenate((order_amplitudes[i], end_amplitudes), axis=-1)
2830 order_phases[i] = np.concatenate((order_phases[i], end_phases), axis=-1)
2831
2832 order_signals = np.array(order_signals)
2833 order_frequencies = np.array(order_frequencies)
2834 order_arguments = np.array(order_arguments)
2835 order_amplitudes = np.array(order_amplitudes)
2836 order_phases = np.array(order_phases)
2837 order_start_samples = np.array(order_start_samples)
2838 order_end_samples = np.array(order_end_samples)
2839 signal = np.sum(order_signals, axis=0)
2840
2841 return (
2842 signal,
2843 order_signals,
2844 order_frequencies,
2845 order_arguments,
2846 order_amplitudes,
2847 order_phases,
2848 order_start_samples,
2849 order_end_samples,
2850 )
2851
2852
class FilterExplorer(QtWidgets.QDialog):
    """Dialog box for exploring the sine tracking filter settings.

    The dialog synthesizes the specification signals for the selected control
    channel. It filters them with either a digital tracking filter (DTF) or a
    Vold-Kalman (VK) filter, then plots the reconstructed tones against the
    ideal ones.
    """

    @staticmethod
    def explore_filter_settings(
        channel_names,
        order_names,
        specifications,
        current_filter_type,
        current_tracking_filter_cutoff,
        current_tracking_filter_order,
        current_filter_order,
        current_bandwidth,
        current_block_size,
        current_overlap,
        sample_rate,
        ramp_time,
        acquire_size,
        parent=None,
    ):
        """
        Brings up the explore filter dialog box

        Parameters
        ----------
        channel_names : list of str
            Channel names to use in the dialog box
        order_names : list of str
            Tone names to use in the dialog box
        specifications : list of SineSpecification
            Sine specifications used to compute the signals
        current_filter_type : int
            Choose the starting filter type, 0-DTF, 1-VK
        current_tracking_filter_cutoff : float
            The cutoff for the tracking filter
        current_tracking_filter_order : int
            The filter order for the tracking filter.
        current_filter_order : int
            The filter order for the Vold Kalman filter
        current_bandwidth : float
            The bandwidth for the Vold Kalman filter.
        current_block_size : int
            The number of samples to use when computing the Vold Kalman filter.
        current_overlap : float
            The percentage overlap of the frame size (in percent, so 15, not
            0.15).
        sample_rate : float
            The sample rate of the signals to generate.
        ramp_time : float
            The ramp time added to the signal to ramp to full level.
        acquire_size : int
            The acquisition size in number of samples.
        parent : QWidget, optional
            A parent widget to the dialog box. The default is None.

        Returns
        -------
        result : bool
            True if the dialog box was accepted, false if not.
        filter_type : int
            0 if DTF, 1 if VK.
        filter_cutoff : float
            The cutoff value for the DTF.
        tracking_filter_order : int
            The filter order for the DTF.
        filter_order : int
            The filter order for the VK filter.
        filter_bandwidth : float
            The bandwidth for the VK filter.
        filter_blocksize : int
            The number of samples in the analysis block for the VK filter.
        filter_overlap : float
            The overlap percentage (in percent not fraction) of the VK filter.

        Notes
        -----
        The filter settings are read back from the dialog whether or not it
        was accepted; callers should check ``result``.

        """
        dialog = FilterExplorer(
            parent,
            channel_names,
            order_names,
            specifications,
            current_filter_type,
            current_tracking_filter_cutoff,
            current_tracking_filter_order,
            current_filter_order,
            current_bandwidth,
            current_block_size,
            current_overlap,
            sample_rate,
            ramp_time,
            acquire_size,
        )
        result = dialog.exec_() == QtWidgets.QDialog.Accepted
        filter_type = dialog.filter_type_selector.currentIndex()
        # The VK order selector is zero-based; filter orders start at 1.
        filter_order = dialog.filter_order_selector.currentIndex() + 1
        filter_bandwidth = dialog.filter_bandwidth_selector.value()
        filter_blocksize = dialog.filter_block_size_selector.value()
        filter_overlap = dialog.filter_block_overlap_selector.value()
        filter_cutoff = dialog.tracking_filter_cutoff_selector.value()
        tracking_filter_order = dialog.tracking_filter_order_selector.value()
        return (
            result,
            filter_type,
            filter_cutoff,
            tracking_filter_order,
            filter_order,
            filter_bandwidth,
            filter_blocksize,
            filter_overlap,
        )

    def __init__(
        self,
        parent,
        channel_names,
        order_names,
        specifications,
        current_filter_type,
        current_tracking_filter_cutoff,
        current_tracking_filter_order,
        current_filter_order,
        current_bandwidth,
        current_block_size,
        current_overlap,
        sample_rate,
        ramp_time,
        acquire_size,
    ):
        """Builds the dialog, seeds its widgets and draws the ideal signals.

        See ``explore_filter_settings`` for a description of the arguments.
        """
        super().__init__(parent)
        uic.loadUi(filter_explorer_ui_path, self)

        for channel_name in channel_names:
            self.channel_selector.addItem(channel_name)

        self.order_selector.setSelectionMode(QtWidgets.QListWidget.SingleSelection)
        for order_name in order_names:
            self.order_selector.addItem(order_name)

        self.full_time_history_plotter = VaryingNumberOfLinePlot(
            self.full_time_history_plot.getPlotItem()
        )
        self.order_time_history_plotter = VaryingNumberOfLinePlot(
            self.order_time_history_plot.getPlotItem()
        )
        self.order_phase_plotter = VaryingNumberOfLinePlot(self.order_phase_plot.getPlotItem())
        self.order_amplitude_plotter = VaryingNumberOfLinePlot(
            self.order_amplitude_plot.getPlotItem()
        )

        self.filter_type_selector.setCurrentIndex(current_filter_type)
        self.tracking_filter_order_selector.setValue(current_tracking_filter_order)
        self.tracking_filter_cutoff_selector.setValue(current_tracking_filter_cutoff)
        self.filter_order_selector.setCurrentIndex(current_filter_order - 1)
        self.filter_bandwidth_selector.setValue(current_bandwidth)
        self.filter_block_overlap_selector.setValue(current_overlap)
        self.filter_block_size_selector.setValue(current_block_size)

        self.ramp_time = ramp_time
        self.sample_rate = sample_rate
        self.specifications = specifications
        self.acquire_size = acquire_size

        # Ideal (synthesized) signals, filled by create_signals().
        self.signal = None
        self.order_signals = None
        self.order_frequencies = None
        self.order_arguments = None
        self.order_amplitudes = None
        self.order_phases = None
        # Filter output, filled by compute_filter(); lists with one entry per
        # processed block.
        self.reconstructed_signal = None
        self.reconstructed_order_signals = None
        self.reconstructed_order_amplitudes = None
        self.reconstructed_order_phases = None
        self.setWindowTitle("Sine Filter Explorer")

        self.change_filter_setting_visibility()

        self.create_signals()

        self.update_plots()

        # Callbacks are connected last, so seeding the widgets above does not
        # fire them.
        self.connect_callbacks()

    def connect_callbacks(self):
        """
        Connects callback functions to the filter widgets
        """
        self.accept_button.clicked.connect(self.accept)
        self.reject_button.clicked.connect(self.reject)
        self.order_selector.itemSelectionChanged.connect(self.update_plots)
        self.channel_selector.currentIndexChanged.connect(self.create_and_plot_signals)
        self.filter_type_selector.currentIndexChanged.connect(self.remove_filter_data_and_replot)
        self.filter_order_selector.currentIndexChanged.connect(self.remove_filter_data_and_replot)
        self.filter_bandwidth_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.filter_block_size_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.filter_block_overlap_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.tracking_filter_cutoff_selector.valueChanged.connect(
            self.remove_filter_data_and_replot
        )
        self.tracking_filter_order_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.noise_selector.valueChanged.connect(self.remove_filter_data_and_replot)
        self.compute_button.clicked.connect(self.compute_filter)
        self.filter_type_selector.currentIndexChanged.connect(self.change_filter_setting_visibility)
        # update_plots() reads this selector's checked state, so redraw when it
        # changes; the filter results themselves remain valid.
        self.plot_separate_frames_selector.toggled.connect(self.update_plots)

    @property
    def ramp_samples(self):
        """Number of ramp samples computed from sample rate and ramp time"""
        return int(self.ramp_time * self.sample_rate)

    @property
    def channel_index(self):
        """Currently selected channel index"""
        return self.channel_selector.currentIndex()

    def change_filter_setting_visibility(self):
        """Shows only the widgets that apply to the selected filter type"""
        isdtf = self.filter_type_selector.currentIndex() == 0
        for widget in [
            self.filter_order_label,
            self.filter_order_selector,
            self.filter_block_overlap_label,
            self.filter_block_overlap_selector,
            self.filter_bandwidth_label,
            self.filter_bandwidth_selector,
            self.filter_block_size_label,
            self.filter_block_size_selector,
        ]:
            widget.setVisible(not isdtf)
        for widget in [
            self.tracking_filter_cutoff_label,
            self.tracking_filter_cutoff_selector,
            self.tracking_filter_order_label,
            self.tracking_filter_order_selector,
        ]:
            widget.setVisible(isdtf)

    def _clear_filter_results(self):
        """Discards any previously computed filter reconstruction"""
        self.reconstructed_signal = None
        self.reconstructed_order_signals = None
        self.reconstructed_order_amplitudes = None
        self.reconstructed_order_phases = None

    def create_signals(self):
        """
        Creates the ideal signals for the selected channel from the
        specifications and discards any stale filter results
        """
        (
            self.signal,
            self.order_signals,
            self.order_frequencies,
            self.order_arguments,
            self.order_amplitudes,
            self.order_phases,
            _,
            _,
        ) = SineSpecification.create_combined_signals(
            self.specifications, self.sample_rate, self.ramp_samples, self.channel_index
        )
        self._clear_filter_results()

    def remove_filter_data_and_replot(self):
        """Removes existing filter data and updates the plots"""
        self._clear_filter_results()
        self.update_plots()

    def compute_filter(self):
        """Runs the selected filter over the ideal signal block by block

        Gaussian noise, scaled by the noise selector value, is added to each
        block before it is filtered. The reconstructions are stored as lists
        with one entry per block, so they can be plotted either frame by frame
        or concatenated.
        """
        use_dtf = self.filter_type_selector.currentIndex() == 0
        if use_dtf:
            block_size = self.acquire_size
            cutoff_ratio = self.tracking_filter_cutoff_selector.value() / 100
            dtf_order = self.tracking_filter_order_selector.value()
            # One independent tracking filter per tone.
            generator = [
                digital_tracking_filter_generator(
                    dt=1 / self.sample_rate,
                    cutoff_frequency_ratio=cutoff_ratio,
                    filter_order=dtf_order,
                )
                for _ in range(self.order_signals.shape[0])
            ]
            # Advance each fresh generator to its first yield before sending data.
            for gen in generator:
                gen.send(None)
        else:
            block_size = self.filter_block_size_selector.value()
            generator = vold_kalman_filter_generator(
                sample_rate=self.sample_rate,
                num_orders=self.order_signals.shape[0],
                block_size=block_size,
                overlap=self.filter_block_overlap_selector.value(),
                bandwidth=self.filter_bandwidth_selector.value(),
                filter_order=self.filter_order_selector.currentIndex() + 1,
            )
            generator.send(None)

        noise_level = self.noise_selector.value()
        start_index = 0
        reconstructed_signals = []
        reconstructed_amplitudes = []
        reconstructed_phases = []

        last_data = False
        while not last_data:
            end_index = start_index + block_size
            block = self.signal[start_index:end_index]
            block = block + noise_level * np.random.randn(block.size)
            block_arguments = self.order_arguments[:, start_index:end_index]
            block_frequencies = self.order_frequencies[:, start_index:end_index]
            last_data = end_index >= self.signal.size
            if use_dtf:
                amps = []
                phss = []
                for arg, freq, gen in zip(block_arguments, block_frequencies, generator):
                    amp, phs = gen.send((block, freq, arg))
                    amps.append(amp)
                    phss.append(phs)
                amps = np.array(amps)
                phss = np.array(phss)
                reconstructed_amplitudes.append(amps)
                reconstructed_phases.append(phss)
                reconstructed_signals.append(amps * np.cos(block_arguments + phss))
            else:
                vk_signals, vk_amplitudes, vk_phases = generator.send(
                    (block, block_arguments, last_data)
                )
                # The VK generator may return None when it has nothing ready.
                if vk_signals is not None:
                    reconstructed_signals.append(vk_signals)
                    reconstructed_amplitudes.append(vk_amplitudes)
                    reconstructed_phases.append(vk_phases)
            start_index += block_size

        self.reconstructed_order_signals = reconstructed_signals
        self.reconstructed_order_amplitudes = reconstructed_amplitudes
        self.reconstructed_order_phases = reconstructed_phases
        # Sum over tones to get the reconstructed total signal of each block.
        self.reconstructed_signal = [
            np.sum(value, axis=0) for value in self.reconstructed_order_signals
        ]

        self.update_plots()

    def update_plots(self):
        """Redraws all plots for the currently selected tone

        The ideal signals are always drawn first. If filter results exist, they
        are added either as one curve per processed block or as a single
        concatenated curve, depending on ``plot_separate_frames_selector``.
        Amplitude and phase are plotted against frequency; phase in degrees.
        """
        abscissa = np.arange(self.signal.size) / self.sample_rate

        # Fall back to the first tone when nothing is selected.
        selected_rows = self.order_selector.selectionModel().selectedIndexes()
        selected_index = selected_rows[0].row() if selected_rows else 0

        abscissa_plot_vals = [abscissa]
        full_signal_plot_vals = [self.signal]
        signal_plot_vals = [self.order_signals[selected_index]]
        frequency_plot_vals = [self.order_frequencies[selected_index]]
        amplitude_plot_vals = [self.order_amplitudes[selected_index]]
        phase_plot_vals = [self.order_phases[selected_index] * 180 / np.pi]

        if self.reconstructed_order_signals is not None:
            if self.plot_separate_frames_selector.isChecked():
                start_index = 0
                for (
                    block_order_signal,
                    block_order_amplitude,
                    block_order_phase,
                    block_signal,
                ) in zip(
                    self.reconstructed_order_signals,
                    self.reconstructed_order_amplitudes,
                    self.reconstructed_order_phases,
                    self.reconstructed_signal,
                ):
                    end_index = start_index + block_signal.shape[-1]
                    abscissa_plot_vals.append(abscissa[start_index:end_index])
                    full_signal_plot_vals.append(block_signal)
                    signal_plot_vals.append(block_order_signal[selected_index])
                    frequency_plot_vals.append(
                        self.order_frequencies[selected_index, start_index:end_index]
                    )
                    amplitude_plot_vals.append(block_order_amplitude[selected_index])
                    phase_plot_vals.append(block_order_phase[selected_index] * 180 / np.pi)
                    start_index = end_index
            else:
                abscissa_plot_vals.append(abscissa)
                full_signal_plot_vals.append(np.concatenate(self.reconstructed_signal, axis=-1))
                signal_plot_vals.append(
                    np.concatenate(self.reconstructed_order_signals, axis=-1)[selected_index]
                )
                amplitude_plot_vals.append(
                    np.concatenate(self.reconstructed_order_amplitudes, axis=-1)[selected_index]
                )
                phase_plot_vals.append(
                    np.concatenate(self.reconstructed_order_phases, axis=-1)[selected_index]
                    * 180
                    / np.pi
                )
                frequency_plot_vals.append(self.order_frequencies[selected_index])

        self.full_time_history_plotter.set_data(abscissa_plot_vals, full_signal_plot_vals)
        self.order_time_history_plotter.set_data(abscissa_plot_vals, signal_plot_vals)
        self.order_amplitude_plotter.set_data(frequency_plot_vals, amplitude_plot_vals)
        self.order_phase_plotter.set_data(frequency_plot_vals, phase_plot_vals)

    def create_and_plot_signals(self):
        """Creates signals then plots the signals"""
        self.create_signals()
        self.update_plots()
3262
3263
3264class PlotSineWindow(QtWidgets.QDialog):
3265 """Class defining a subwindow that displays specific channel information"""
3266
3267 def __init__(self, parent, ui, tone_index, channel_index):
3268 """
3269 Creates a window showing amplitude and phase information.
3270
3271 Parameters
3272 ----------
3273 parent : QWidget
3274 Parent of the window.
3275 ui : SineUI
3276 The User Interface of the Sine Controller
3277 tone_index : int
3278 Index specifying the tone to plot
3279 channel_index : int
3280 Index specifying the channel to plot
3281 """
3282 super(QtWidgets.QDialog, self).__init__(parent)
3283 self.setWindowFlags(self.windowFlags() & Qt.Tool)
3284 self.tone_index = tone_index
3285 self.channel_index = channel_index
3286 spec_frequency = ui.specification_frequencies[tone_index]
3287 spec_amplitude = ui.specification_amplitudes[tone_index, channel_index]
3288 spec_phase = wrap(ui.specification_phases[tone_index, channel_index])
3289 spec = ui.environment_parameters.specifications[tone_index]
3290 warn_freq = np.repeat(spec.breakpoint_table["frequency"], 2)
3291 warn_low = spec.breakpoint_table["warning"][:, 0, :, channel_index].flatten()
3292 warn_high = spec.breakpoint_table["warning"][:, 1, :, channel_index].flatten()
3293 abort_low = spec.breakpoint_table["abort"][:, 0, :, channel_index].flatten()
3294 abort_high = spec.breakpoint_table["abort"][:, 1, :, channel_index].flatten()
3295 tone_name = spec.name
3296 channel_name = ui.initialized_control_names[channel_index]
3297 # Now plot the data
3298 layout = QtWidgets.QVBoxLayout()
3299 amp_plotwidget = pqtg.PlotWidget()
3300 layout.addWidget(amp_plotwidget)
3301 phs_plotwidget = pqtg.PlotWidget()
3302 layout.addWidget(phs_plotwidget)
3303 self.setLayout(layout)
3304 amp_plot_item = amp_plotwidget.getPlotItem()
3305 phs_plot_item = phs_plotwidget.getPlotItem()
3306 for plot_item in [amp_plot_item, phs_plot_item]:
3307 plot_item.showGrid(True, True, 0.25)
3308 plot_item.enableAutoRange()
3309 plot_item.getViewBox().enableAutoRange(enable=True)
3310 amp_plot_item.plot(spec_frequency, spec_amplitude, pen={"color": "b", "width": 1})
3311 phs_plot_item.plot(spec_frequency, spec_phase, pen={"color": "b", "width": 1})
3312 amp_plot_item.plot(
3313 warn_freq,
3314 warn_low,
3315 pen={"color": (255, 204, 0), "width": 1, "style": Qt.DashLine},
3316 )
3317 amp_plot_item.plot(
3318 warn_freq,
3319 warn_high,
3320 pen={"color": (255, 204, 0), "width": 1, "style": Qt.DashLine},
3321 )
3322 amp_plot_item.plot(
3323 warn_freq,
3324 abort_low,
3325 pen={"color": (153, 0, 0), "width": 1, "style": Qt.DashLine},
3326 )
3327 amp_plot_item.plot(
3328 warn_freq,
3329 abort_high,
3330 pen={"color": (153, 0, 0), "width": 1, "style": Qt.DashLine},
3331 )
3332 if ui.achieved_excitation_frequencies is not None:
3333 achieved_frequency = np.concatenate(
3334 [fh[tone_index] for fh in ui.achieved_excitation_frequencies]
3335 )
3336 achieved_amplitude = np.concatenate(
3337 [ah[tone_index, channel_index] for ah in ui.achieved_response_amplitudes]
3338 )
3339 achieved_phase = np.concatenate(
3340 [ph[tone_index, channel_index] for ph in ui.achieved_response_phases]
3341 )
3342 else:
3343 achieved_frequency = np.array([0, 1])
3344 achieved_amplitude = np.nan * np.ones(2)
3345 achieved_phase = np.nan * np.ones(2)
3346 self.amp_curve = amp_plot_item.plot(
3347 achieved_frequency, achieved_amplitude, pen={"color": "r", "width": 1}
3348 )
3349 self.phs_curve = phs_plot_item.plot(
3350 achieved_frequency, achieved_phase, pen={"color": "r", "width": 1}
3351 )
3352 self.setWindowTitle(f"{tone_name} {channel_name}")
3353 self.ui = ui
3354 self.show()
3355
3356 def update_plot(self):
3357 """Updates the plots with new data"""
3358 if self.ui.achieved_excitation_frequencies is not None:
3359 achieved_frequency = np.concatenate(
3360 [fh[self.tone_index] for fh in self.ui.achieved_excitation_frequencies]
3361 )
3362 achieved_amplitude = np.concatenate(
3363 [
3364 ah[self.tone_index, self.channel_index]
3365 for ah in self.ui.achieved_response_amplitudes
3366 ]
3367 )
3368 achieved_phase = np.concatenate(
3369 [ph[self.tone_index, self.channel_index] for ph in self.ui.achieved_response_phases]
3370 )
3371 else:
3372 achieved_frequency = np.array([0, 1])
3373 achieved_amplitude = np.nan * np.ones(2)
3374 achieved_phase = np.nan * np.ones(2)
3375 self.amp_curve.setData(achieved_frequency, achieved_amplitude)
3376 self.phs_curve.setData(achieved_frequency, achieved_phase)