Coverage for src / sdynpy / signal_processing / sdynpy_correlation.py: 12%
90 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 16:22 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 16:22 +0000
1# -*- coding: utf-8 -*-
2"""
3Functions to compute correlation metrics between datasets
4"""
5"""
6Copyright 2022 National Technology & Engineering Solutions of Sandia,
7LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
8Government retains certain rights in this software.
10This program is free software: you can redistribute it and/or modify
11it under the terms of the GNU General Public License as published by
12the Free Software Foundation, either version 3 of the License, or
13(at your option) any later version.
15This program is distributed in the hope that it will be useful,
16but WITHOUT ANY WARRANTY; without even the implied warranty of
17MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18GNU General Public License for more details.
20You should have received a copy of the GNU General Public License
21along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""
24import numpy as np
25import matplotlib.pyplot as plt
26import matplotlib.ticker as ticker
27import copy
29# Look at FRAC and TRAC values
def mac(phi_1, phi_2=None):
    """Compute the Modal Assurance Criterion (MAC) matrix between two shape sets.

    Parameters
    ----------
    phi_1 : ndarray
        2D array whose columns are shape vectors (rows are degrees of freedom).
    phi_2 : ndarray, optional
        2D array of shape vectors (columns) with the same number of rows as
        `phi_1`.  If omitted, `phi_1` is compared against itself.

    Returns
    -------
    mac : ndarray
        Real array of shape ``(phi_1.shape[-1], phi_2.shape[-1])`` where entry
        ``[i, j]`` is ``|phi_1[:, i] . conj(phi_2[:, j])|**2 /
        (||phi_1[:, i]||**2 * ||phi_2[:, j]||**2)``.  A shape with zero norm
        produces NaN entries.
    """
    phi_1 = np.asarray(phi_1)
    phi_2 = phi_1 if phi_2 is None else np.asarray(phi_2)
    # All pairwise inner products at once instead of a double Python loop.
    cross = phi_1.T @ phi_2.conj()
    # np.abs keeps the squared norms real, so complex shapes yield a
    # real-valued MAC without any complex-to-float casting.
    norms_1 = np.sum(np.abs(phi_1)**2, axis=0)
    norms_2 = np.sum(np.abs(phi_2)**2, axis=0)
    return np.abs(cross)**2 / np.outer(norms_1, norms_2)
def frac(fft_1, fft_2=None):
    """Compute the FRAC correlation between spectra along the last axis.

    For each row ``a`` of `fft_1` (all leading axes flattened) and the
    corresponding row ``b`` of `fft_2`, the value is
    ``|sum(a * conj(b))|**2 / (sum(|a|**2) * sum(|b|**2))``.

    Parameters
    ----------
    fft_1 : array_like
        Spectra with the frequency axis last.
    fft_2 : array_like, optional
        Spectra to correlate against.  After flattening all but the last axis
        it must broadcast against the flattened `fft_1`.  If omitted, `fft_1`
        is correlated with itself.

    Returns
    -------
    frac : ndarray
        Real array of shape ``fft_1.shape[:-1]``.  Rows with zero energy give
        NaN.
    """
    fft_1 = np.asarray(fft_1)
    fft_2 = fft_1 if fft_2 is None else np.asarray(fft_2)
    flat_1 = fft_1.reshape(-1, fft_1.shape[-1])
    flat_2 = fft_2.reshape(-1, fft_2.shape[-1])
    cross = np.sum(flat_1 * flat_2.conj(), axis=-1)
    # |x|**2 keeps the auto-power sums real, so complex spectra give a real
    # result instead of a complex array with zero imaginary part.
    power_1 = np.sum(np.abs(flat_1)**2, axis=-1)
    power_2 = np.sum(np.abs(flat_2)**2, axis=-1)
    result = np.abs(cross)**2 / (power_1 * power_2)
    return result.reshape(fft_1.shape[:-1])
def trac(th_1, th_2=None):
    """Compute the TRAC correlation between time histories along the last axis.

    For each row ``a`` of `th_1` (all leading axes flattened) and the
    corresponding row ``b`` of `th_2`, the value is
    ``|sum(a * conj(b))|**2 / (sum(|a|**2) * sum(|b|**2))``.

    Parameters
    ----------
    th_1 : array_like
        Time histories with the time axis last.
    th_2 : array_like, optional
        Time histories to correlate against.  After flattening all but the
        last axis it must broadcast against the flattened `th_1`.  If omitted,
        `th_1` is correlated with itself.

    Returns
    -------
    trac : ndarray
        Real array of shape ``th_1.shape[:-1]``.  Rows with zero energy give
        NaN.
    """
    th_1 = np.asarray(th_1)
    th_2 = th_1 if th_2 is None else np.asarray(th_2)
    flat_1 = th_1.reshape(-1, th_1.shape[-1])
    flat_2 = th_2.reshape(-1, th_2.shape[-1])
    cross = np.sum(flat_1 * flat_2.conj(), axis=-1)
    # |x|**2 keeps the energy sums real, so complex input gives a real result
    # instead of a complex array with zero imaginary part.
    energy_1 = np.sum(np.abs(flat_1)**2, axis=-1)
    energy_2 = np.sum(np.abs(flat_2)**2, axis=-1)
    result = np.abs(cross)**2 / (energy_1 * energy_2)
    return result.reshape(th_1.shape[:-1])
def msf(shapes, reference_shapes=None):
    """Compute the modal scale factor of each shape relative to a reference.

    For every column ``j`` the factor is
    ``sum_i shapes[..., i, j] * conj(ref[..., i, j]) /
    sum_i ref[..., i, j] * conj(ref[..., i, j])``, i.e. the least-squares
    scalar that maps the reference column onto the corresponding shape column.

    Parameters
    ----------
    shapes : ndarray
        Shapes stored as columns, shape ``(..., n_dof, n_shapes)``.
    reference_shapes : ndarray, optional
        Reference shapes with the same trailing layout.  Defaults to `shapes`.

    Returns
    -------
    ndarray
        Scale factors with the trailing ``n_dof`` axis summed out.
    """
    reference = shapes if reference_shapes is None else reference_shapes
    projection = np.einsum('...ij,...ij->...j', shapes, reference.conj())
    reference_norm = np.einsum('...ij,...ij->...j', reference, reference.conj())
    return projection / reference_norm
def orthog(shapes_1, mass_matrix, shapes_2=None, scaling=None):
    """Compute the (cross-)orthogonality matrix ``shapes_1^T @ M @ shapes_2``.

    Parameters
    ----------
    shapes_1 : ndarray
        Shape vectors stored as columns, shape ``(..., n_dof, n_shapes_1)``.
    mass_matrix : ndarray
        Weighting matrix applied between the shape sets; must be
        matmul-compatible with the shapes (e.g. ``(n_dof, n_dof)``).
    shapes_2 : ndarray, optional
        Second set of column shape vectors.  Defaults to `shapes_1`.
    scaling : {None, 'none', 'unity'}, optional
        ``None`` or ``'none'`` returns the raw product.  ``'unity'`` scales
        entry ``[i, j]`` by ``1/sqrt(d[i] * d[j])`` where ``d`` is the diagonal
        of the product, so the diagonal becomes one.  ``'unity'`` requires a
        square product (same number of shapes in both sets).

    Returns
    -------
    ndarray
        Matrix of shape ``(..., n_shapes_1, n_shapes_2)``.

    Raises
    ------
    ValueError
        If `scaling` is not one of the accepted values.
    """
    if scaling not in ['none', 'unity', None]:
        raise ValueError('Invalid scaling, should be one of "none", "unity", or None')
    if shapes_2 is None:
        shapes_2 = shapes_1
    mat = np.moveaxis(shapes_1, -2, -1) @ mass_matrix @ shapes_2
    if scaling == 'unity':
        diagonal = np.einsum('...ii->...i', mat)
        factors = 1 / np.sqrt(diagonal)
        # Equivalent to D @ mat @ D with D = diag(factors), but broadcasting
        # keeps the dtype of `factors` (complex factors are not truncated to
        # real) and avoids two extra matrix products.
        mat = mat * factors[..., :, np.newaxis] * factors[..., np.newaxis, :]
    return mat
def matrix_plot(shape_matrix, ax=None, display_values=(0.1, 1.1), text_size=12, vmin=0, vmax=1,
                boundaries=None):
    """Plot a shape-correlation matrix (e.g. a MAC matrix) as an image.

    The matrix is drawn with ``ax.imshow`` using the current default colormap,
    a colorbar is added, and both axes are labeled 'Shape Number' with
    1-based tick labels.

    Parameters
    ----------
    shape_matrix : ndarray
        2D matrix to display.  When `boundaries` is given the code uses
        ``shape_matrix.shape[0]`` for both axes, so it assumes a square matrix.
    ax : matplotlib Axes, optional
        Axes to draw into.  A new figure/axes is created with
        ``plt.subplots()`` if omitted.
    display_values : tuple of (lower, upper) or None, optional
        Cells whose value satisfies ``lower < val <= upper`` are annotated
        with ``val * 100`` formatted with no decimals.  Either bound may be
        None to disable it; None disables all annotations.
    text_size : int, optional
        Font size of the cell annotations.
    vmin, vmax : float, optional
        Color limits passed to ``imshow``.
    boundaries : sequence of int, optional
        Shape indices before which a blank (NaN) row and column are inserted,
        splitting the shapes into groups.  Tick labels then read
        'group,index' (both 1-based).  Presumably sorted ascending — the
        label logic indexes ``boundaries[level - 1]``; TODO confirm.

    Returns
    -------
    ax : matplotlib Axes
        The axes that were drawn into.
    """
    if boundaries is None:
        # Display number not index
        @ticker.FuncFormatter
        def major_formatter(x, pos):
            # Tick positions are 0-based indices; show 1-based shape numbers.
            return '{:0.0f}'.format(x + 1)
        cm = plt.get_cmap()
    else:
        # Add boundaries to the shape matrix
        boundaries = np.array(boundaries)
        shape_matrix_original = shape_matrix.copy()
        n_shapes = shape_matrix_original.shape[0]
        # Enlarged matrix filled with NaN; one extra row/column per boundary
        # stays NaN and acts as a visual separator.
        shape_matrix = np.nan * np.empty([v + len(boundaries) for v in shape_matrix_original.shape])
        # Each original index i is shifted by the number of boundaries <= i.
        # index_map: displayed index -> original index (used by the formatter).
        index_map = {i + np.sum(boundaries <= i): i for i in np.arange(n_shapes)}
        # inverse_index_map: original index -> displayed index.
        inverse_index_map = {i: i + np.sum(boundaries <= i) for i in np.arange(n_shapes)}
        outputs = np.arange(n_shapes)
        inputs = np.array([inverse_index_map[i] for i in outputs])
        # Scatter the original square block into its shifted positions.
        shape_matrix[
            inputs[:, np.newaxis], inputs
        ] = shape_matrix_original[
            outputs[:, np.newaxis], outputs]
        # Copy so that set_bad modifies a private colormap rather than the
        # object returned by plt.get_cmap(); NaN separator cells render white.
        cm = copy.copy(plt.get_cmap())
        cm.set_bad(color='w')

        @ticker.FuncFormatter
        def major_formatter(x, pos):
            x = int(np.round(x))
            if x not in index_map:
                # Separator row/column (or off-matrix tick): no label.
                return ''
            else:
                x = index_map[x]
                # Group number = count of boundaries at or below the original index.
                level = np.sum(x >= boundaries)
                if level == 0:
                    shape_index = x
                else:
                    # Index relative to the start of its group.
                    shape_index = x - boundaries[level - 1]
                return '{:},{:}'.format(level + 1, shape_index + 1)
    if ax is None:
        fig, ax = plt.subplots()
    out = ax.imshow(shape_matrix, vmin=vmin, vmax=vmax, cmap=cm)
    plt.colorbar(out, ax=ax)
    ax.xaxis.set_major_formatter(major_formatter)
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.yaxis.set_major_formatter(major_formatter)
    ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.set_xlabel('Shape Number')
    ax.set_ylabel('Shape Number')
    if display_values is not None:
        for key, val in np.ndenumerate(shape_matrix):
            # A None bound is treated as always satisfied.  NaN separator
            # cells fail any numeric comparison, so they are only annotated
            # when both bounds are None.
            if ((True if display_values[0] is None else (val > display_values[0]))
                    and
                    (True if display_values[1] is None else (val <= display_values[1]))):
                # imshow puts column index on x and row index on y.
                ax.text(key[1], key[0], '{:0.0f}'.format(val * 100),
                        fontdict={'size': text_size}, ha='center', va='center')
    return ax