import jax
from jax import numpy as np
from jax.test_util import check_grads
from scipy.spatial.transform import Rotation
import unittest
from optimism import LinAlg
from .TestFixture import TestFixture
def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0, *, seed=0):
    """Return ``n`` random symmetric 3x3 matrices, stacked as an (n, 3, 3) array.

    Each matrix is formed as ``A.T @ A`` from a 3x3 matrix ``A`` whose entries
    are drawn uniformly between ``minval`` and ``maxval``, so every result is
    symmetric (and positive semi-definite).

    Args:
        n: Number of matrices to generate.
        minval: Lower bound of the uniform distribution for the entries of A.
        maxval: Upper bound of the uniform distribution for the entries of A.
        seed: Integer seed for the JAX PRNG key. The default of 0 reproduces
            the fixed key this helper has always used, so repeated calls with
            the same arguments return identical matrices.

    Returns:
        Array of shape (n, 3, 3) holding the symmetric matrices.
    """
    key = jax.random.PRNGKey(seed)
    factors = jax.random.uniform(key, (n, 3, 3), minval=minval, maxval=maxval)

    def gram(A):
        # A.T @ A is symmetric by construction.
        return np.dot(A.T, A)

    return jax.vmap(gram, in_axes=0)(factors)
# JIT-wrapped versions of the LinAlg routines, created once at import time
# and shared by the tests that exercise the jitted code path.
sqrtm_jit = jax.jit(LinAlg.sqrtm)
logm_iss_jit = jax.jit(LinAlg.logm_iss)
class TestLinAlg(TestFixture):
    """Tests for ``LinAlg.sqrtm`` and ``LinAlg.logm_iss``.

    Values are checked by round-tripping (squaring the square root,
    exponentiating the logarithm) and derivatives are checked with
    ``jax.test_util.check_grads``, which compares automatic differentiation
    against finite differences.

    NOTE(review): ``assertArrayNear`` comes from ``TestFixture``, which is
    defined elsewhere; its third argument is presumably a number of decimal
    places -- confirm against TestFixture.
    """

    def setUp(self):
        """Build the matrices shared by several tests."""
        self.sym_mat = generate_n_random_symmetric_matrices(1)[0]

        # make a matrix with 2 identical eigenvalues
        R = Rotation.random(random_state=41).as_matrix()
        eigvals = np.array([2., 0.5, 2.])
        self.sym_mat_double_degeneracy = R @ np.diag(eigvals) @ R.T

    @staticmethod
    def _random_10x10_gram_matrix():
        """Return the fixed 10x10 matrix ``F.T @ F`` used by the 10x10 tests."""
        key = jax.random.PRNGKey(0)
        F = jax.random.uniform(key, (10, 10), minval=1e-8, maxval=10.0)
        return F.T @ F

    ### sqrtm ###

    def test_sqrtm_jit(self):
        """The jitted sqrtm produces no NaNs on a random symmetric matrix."""
        sqrtC = sqrtm_jit(self.sym_mat)
        self.assertFalse(np.isnan(sqrtC).any())

    def test_sqrtm(self):
        """Squaring sqrtm(A) recovers A for a batch of 100 matrices."""
        mats = generate_n_random_symmetric_matrices(100)
        sqrtMats = jax.vmap(sqrtm_jit, (0,))(mats)
        shouldBeMats = jax.vmap(lambda A: np.dot(A, A), (0,))(sqrtMats)
        self.assertArrayNear(shouldBeMats, mats, 10)

    def test_sqrtm_fwd_mode_derivative(self):
        """Forward-mode derivatives of sqrtm pass check_grads up to order 2."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["fwd"])

    def test_sqrtm_rev_mode_derivative(self):
        """Reverse-mode derivatives of sqrtm pass check_grads up to order 2."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["rev"])

    def test_sqrtm_on_degenerate_eigenvalues(self):
        """sqrtm value and reverse-mode derivatives with a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC, sqrtC)
        self.assertArrayNear(shouldBeC, C, 12)
        check_grads(LinAlg.sqrtm, (C,), order=2, modes=["rev"])

    def test_sqrtm_on_10x10(self):
        """Squaring sqrtm(C) recovers C for a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC, sqrtC)
        self.assertArrayNear(shouldBeC, C, 11)

    def test_sqrtm_derivatives_on_10x10(self):
        """First-order fwd and rev derivatives of sqrtm on a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        check_grads(LinAlg.sqrtm, (C,), order=1, modes=["fwd", "rev"])

    ### logm ###

    def test_logm_iss_on_matrix_near_identity(self):
        """logm_iss of a near-identity diagonal matrix is the log of its diagonal."""
        key = jax.random.PRNGKey(0)
        id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01)
        A = np.diag(id_perturbation)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.diag(np.log(id_perturbation)), 12)

    def test_logm_iss_on_double_degenerate_eigenvalues(self):
        """expm(logm_iss(C)) recovers C when C has a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(C, explogC, 8)

    def test_logm_iss_on_triple_degenerate_eigvalues(self):
        """logm_iss of 4*I equals log(4)*I."""
        A = 4.0 * np.identity(3)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.log(4.0) * np.identity(3), 12)

    def test_logm_iss_jit(self):
        """The jitted logm_iss produces no NaNs on a random symmetric matrix."""
        C = generate_n_random_symmetric_matrices(1)[0]
        logC = logm_iss_jit(C)
        self.assertFalse(np.isnan(logC).any())

    def test_logm_iss_on_full_3x3s(self):
        """expm(logm_iss(A)) recovers A for a batch of 1000 matrices."""
        mats = generate_n_random_symmetric_matrices(1000)
        logMats = jax.vmap(logm_iss_jit, (0,))(mats)
        shouldBeMats = jax.vmap(lambda A: jax.scipy.linalg.expm(A), (0,))(logMats)
        self.assertArrayNear(shouldBeMats, mats, 7)

    def test_logm_iss_fwd_mode_derivative(self):
        """First-order forward-mode derivative of the jitted logm_iss."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['fwd'])

    def test_logm_iss_rev_mode_derivative(self):
        """First-order reverse-mode derivative of the jitted logm_iss."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['rev'])

    def test_logm_iss_hessian_on_double_degenerate_eigenvalues(self):
        """Second derivatives of logm_iss with a repeated eigenvalue.

        The reverse-mode Jacobian is itself differentiated in forward mode, so
        a first-order check of ``jacrev(logm_iss)`` exercises the Hessian.
        """
        C = self.sym_mat_double_degeneracy
        check_grads(jax.jacrev(LinAlg.logm_iss), (C,), order=1, modes=['fwd'],
                    rtol=1e-9, atol=1e-9, eps=1e-5)

    def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives of logm_iss, repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['rev'])

    def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives of logm_iss at 4*I."""
        A = 4.0 * np.identity(3)
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['rev'])

    def test_logm_iss_on_10x10(self):
        """expm(logm_iss(C)) recovers C for a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(explogC, C, 8)