Coverage for src / sdynpy / signal_processing / sdynpy_correlation.py: 12%

90 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 16:22 +0000

1# -*- coding: utf-8 -*- 

2""" 

3Functions to compute correlation metrics between datasets 

4""" 

5""" 

6Copyright 2022 National Technology & Engineering Solutions of Sandia, 

7LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. 

8Government retains certain rights in this software. 

9 

10This program is free software: you can redistribute it and/or modify 

11it under the terms of the GNU General Public License as published by 

12the Free Software Foundation, either version 3 of the License, or 

13(at your option) any later version. 

14 

15This program is distributed in the hope that it will be useful, 

16but WITHOUT ANY WARRANTY; without even the implied warranty of 

17MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18GNU General Public License for more details. 

19 

20You should have received a copy of the GNU General Public License 

21along with this program. If not, see <https://www.gnu.org/licenses/>. 

22""" 

23 

24import numpy as np 

25import matplotlib.pyplot as plt 

26import matplotlib.ticker as ticker 

27import copy 

28 

29# Look at FRAC and TRAC values 

30 

31 

def mac(phi_1, phi_2=None):
    """
    Compute the Modal Assurance Criterion (MAC) between two sets of shapes.

    Entry ``[i, j]`` of the result is

        |phi_1[:, i]^H phi_2[:, j]|^2
        / ((phi_1[:, i]^H phi_1[:, i]) * (phi_2[:, j]^H phi_2[:, j]))

    Parameters
    ----------
    phi_1 : ndarray
        2D array with one shape per column, shape ``(n_dof, n_1)``.
    phi_2 : ndarray, optional
        2D array with one shape per column, shape ``(n_dof, n_2)``.
        Defaults to `phi_1`.

    Returns
    -------
    ndarray
        Real array of shape ``(n_1, n_2)``. Columns with zero norm produce
        NaN entries.
    """
    if phi_2 is None:
        phi_2 = phi_1
    # cross[i, j] = |sum_k phi_1[k, i] * conj(phi_2[k, j])|**2
    cross = np.abs(phi_1.T @ phi_2.conj())**2
    # Auto terms are mathematically real; taking .real keeps the result real
    # for complex shapes (avoids writing complex values into a real array).
    auto_1 = np.sum(phi_1 * phi_1.conj(), axis=0).real
    auto_2 = np.sum(phi_2 * phi_2.conj(), axis=0).real
    return cross / np.outer(auto_1, auto_2)

41 

42 

def frac(fft_1, fft_2=None):
    """
    Compute the Frequency Response Assurance Criterion (FRAC).

    The correlation is evaluated along the last axis of the inputs; all
    leading dimensions are flattened and compared row by row.

    Parameters
    ----------
    fft_1 : ndarray
        Spectra with frequency along the last axis.
    fft_2 : ndarray, optional
        Spectra to compare against. After flattening its leading dimensions,
        it must have the same number of rows as `fft_1` (or broadcast
        against it). Defaults to `fft_1`.

    Returns
    -------
    ndarray
        Real-valued array of shape ``fft_1.shape[:-1]``.
    """
    if fft_2 is None:
        fft_2 = fft_1
    fft_1_original_shape = fft_1.shape
    fft_1_flattened = fft_1.reshape(-1, fft_1.shape[-1])
    fft_2_flattened = fft_2.reshape(-1, fft_2.shape[-1])
    cross = np.abs(np.sum(fft_1_flattened * fft_2_flattened.conj(), axis=-1))**2
    # x * conj(x) has zero imaginary part; .real keeps FRAC real-valued for
    # complex spectra instead of returning a complex dtype.
    auto_1 = np.sum(fft_1_flattened * fft_1_flattened.conj(), axis=-1).real
    auto_2 = np.sum(fft_2_flattened * fft_2_flattened.conj(), axis=-1).real
    result = cross / (auto_1 * auto_2)
    return result.reshape(fft_1_original_shape[:-1])

53 

54 

def trac(th_1, th_2=None):
    """
    Compute the Time Response Assurance Criterion (TRAC).

    The correlation is evaluated along the last axis of the inputs; all
    leading dimensions are flattened and compared row by row.

    Parameters
    ----------
    th_1 : ndarray
        Time histories with time along the last axis.
    th_2 : ndarray, optional
        Time histories to compare against. After flattening its leading
        dimensions, it must have the same number of rows as `th_1` (or
        broadcast against it). Defaults to `th_1`.

    Returns
    -------
    ndarray
        Real-valued array of shape ``th_1.shape[:-1]``.
    """
    if th_2 is None:
        th_2 = th_1
    th_1_original_shape = th_1.shape
    th_1_flattened = th_1.reshape(-1, th_1.shape[-1])
    th_2_flattened = th_2.reshape(-1, th_2.shape[-1])
    cross = np.abs(np.sum(th_1_flattened * th_2_flattened.conj(), axis=-1))**2
    # x * conj(x) has zero imaginary part; .real keeps TRAC real-valued for
    # complex (e.g. analytic) signals instead of returning a complex dtype.
    auto_1 = np.sum(th_1_flattened * th_1_flattened.conj(), axis=-1).real
    auto_2 = np.sum(th_2_flattened * th_2_flattened.conj(), axis=-1).real
    result = cross / (auto_1 * auto_2)
    return result.reshape(th_1_original_shape[:-1])

65 

66 

def msf(shapes, reference_shapes=None):
    """
    Compute the Modal Scale Factor of `shapes` relative to `reference_shapes`.

    For each column ``j`` the factor is
    ``sum_i shapes[..., i, j] * conj(ref[..., i, j]) / sum_i |ref[..., i, j]|**2``,
    i.e. the scalar that best scales the reference shape onto `shapes`.
    Leading dimensions are broadcast.

    Parameters
    ----------
    shapes : ndarray
        Array of shape ``(..., n_dof, n_shapes)``.
    reference_shapes : ndarray, optional
        Reference shapes with a compatible shape. Defaults to `shapes`.

    Returns
    -------
    ndarray
        Array of shape ``(..., n_shapes)``; complex when inputs are complex.
    """
    reference = shapes if reference_shapes is None else reference_shapes
    reference_conj = reference.conj()
    projection = np.einsum('...ij,...ij->...j', shapes, reference_conj)
    reference_norm = np.einsum('...ij,...ij->...j', reference, reference_conj)
    return projection / reference_norm

73 

74 

def orthog(shapes_1, mass_matrix, shapes_2=None, scaling=None):
    """
    Compute the orthogonality matrix ``shapes_1^T @ mass_matrix @ shapes_2``.

    A plain transpose (no conjugation) is used on `shapes_1`.

    Parameters
    ----------
    shapes_1 : ndarray
        Shapes of shape ``(..., n_dof, n_1)``.
    mass_matrix : ndarray
        Matrix of shape ``(n_dof, n_dof)`` (or a broadcastable stack).
    shapes_2 : ndarray, optional
        Shapes of shape ``(..., n_dof, n_2)``. Defaults to `shapes_1`.
    scaling : {None, 'none', 'unity'}, optional
        With ``'unity'``, the matrix is scaled as ``D @ mat @ D``, where
        ``D = diag(1 / sqrt(diag(mat)))``, so its diagonal becomes one. This
        requires a square matrix (``n_1 == n_2``). None or ``'none'`` returns
        the unscaled matrix.

    Returns
    -------
    ndarray
        Array of shape ``(..., n_1, n_2)``.

    Raises
    ------
    ValueError
        If `scaling` is not one of the accepted values.
    """
    if scaling not in ['none', 'unity', None]:
        raise ValueError('Invalid scaling, should be one of "none", "unity", or None')
    if shapes_2 is None:
        shapes_2 = shapes_1
    mat = np.moveaxis(shapes_1, -2, -1) @ mass_matrix @ shapes_2
    if scaling == 'unity':
        diagonal = np.einsum('...ii->...i', mat)
        factors = 1 / np.sqrt(diagonal)
        # Same as D @ mat @ D with D = diag(factors), but broadcasting keeps
        # complex factors intact (a real zeros matrix would drop their
        # imaginary parts) and avoids two dense matrix products.
        mat = factors[..., :, np.newaxis] * mat * factors[..., np.newaxis, :]
    return mat

90 

91 

def matrix_plot(shape_matrix, ax=None, display_values=(0.1, 1.1), text_size=12, vmin=0, vmax=1,
                boundaries=None):
    """
    Plot a shape-correlation matrix (for example a MAC matrix) as an image.

    Parameters
    ----------
    shape_matrix : ndarray
        2D array of values to display. When `boundaries` is given, the same
        index mapping is applied to rows and columns, so the array is
        expected to be square.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created if None.
    display_values : sequence of two (float or None), or None, optional
        A cell is annotated with ``value * 100`` (formatted with no decimals)
        when ``display_values[0] < value <= display_values[1]``. Either bound
        may be None to leave that side open. If None, no cells are annotated.
        Default ``(0.1, 1.1)``.
    text_size : float, optional
        Font size of the annotations. Default 12.
    vmin, vmax : float, optional
        Color limits passed to ``imshow``. Defaults 0 and 1.
    boundaries : int or sequence of int, optional
        Shape indices at which a new group of shapes starts. A blank (NaN)
        row and column is inserted before each boundary index. Tick labels
        become ``"group,index"``, both 1-based. A single int is accepted.

    Returns
    -------
    ax : matplotlib Axes
        The axes containing the plot.
    """
    if boundaries is None:
        # Label ticks with 1-based shape numbers instead of 0-based indices
        @ticker.FuncFormatter
        def major_formatter(x, pos):
            return '{:0.0f}'.format(x + 1)
        cm = plt.get_cmap()
    else:
        # Accept a single boundary index as well as a sequence of them
        boundaries = np.atleast_1d(boundaries)
        shape_matrix_original = np.asanyarray(shape_matrix)
        n_shapes = shape_matrix_original.shape[0]
        # One NaN row and column per boundary; NaN cells are drawn with the
        # colormap's "bad" color set below.
        shape_matrix = np.full([v + len(boundaries) for v in shape_matrix_original.shape],
                               np.nan)
        outputs = np.arange(n_shapes)
        # Position of each original shape in the padded matrix: its index
        # shifted by the number of boundaries at or below it (always integer,
        # even when there are no shapes or no boundaries).
        inputs = outputs + np.sum(boundaries[np.newaxis, :] <= outputs[:, np.newaxis], axis=1)
        shape_matrix[inputs[:, np.newaxis], inputs] = \
            shape_matrix_original[outputs[:, np.newaxis], outputs]
        # Padded index -> original shape index, used by the tick formatter
        index_map = dict(zip(inputs.tolist(), outputs.tolist()))
        # Copy so set_bad does not modify the colormap object returned by get_cmap
        cm = copy.copy(plt.get_cmap())
        cm.set_bad(color='w')

        @ticker.FuncFormatter
        def major_formatter(x, pos):
            x = int(np.round(x))
            if x not in index_map:
                # Tick sits on an inserted blank row/column
                return ''
            x = index_map[x]
            # Group number = how many boundaries this shape is at or past
            level = np.sum(x >= boundaries)
            shape_index = x if level == 0 else x - boundaries[level - 1]
            return '{:},{:}'.format(level + 1, shape_index + 1)
    if ax is None:
        _, ax = plt.subplots()
    out = ax.imshow(shape_matrix, vmin=vmin, vmax=vmax, cmap=cm)
    plt.colorbar(out, ax=ax)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(major_formatter)
        axis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.set_xlabel('Shape Number')
    ax.set_ylabel('Shape Number')
    if display_values is not None:
        lower, upper = display_values[0], display_values[1]
        for (row, col), val in np.ndenumerate(shape_matrix):
            # NaN padding cells fail both comparisons and are never annotated
            above_lower = lower is None or val > lower
            below_upper = upper is None or val <= upper
            if above_lower and below_upper:
                ax.text(col, row, '{:0.0f}'.format(val * 100),
                        fontdict={'size': text_size}, ha='center', va='center')
    return ax