import jax
from jax import numpy as np
from jax.test_util import check_grads
from scipy.spatial.transform import Rotation
import unittest
from optimism import LinAlg
from .TestFixture import TestFixture
def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0, seed=0):
    """Return ``n`` random symmetric 3x3 matrices stacked with shape ``(n, 3, 3)``.

    Each matrix is ``A.T @ A`` for an ``A`` whose entries are drawn uniformly
    from ``[minval, maxval)``, so every result is symmetric positive
    semi-definite.

    Args:
        n: number of matrices to generate.
        minval: lower bound of the uniform distribution for entries of ``A``.
        maxval: upper bound of the uniform distribution for entries of ``A``.
        seed: integer seed for the PRNG key. The default of 0 matches the key
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        Array of shape ``(n, 3, 3)``.
    """
    key = jax.random.PRNGKey(seed)
    As = jax.random.uniform(key, (n, 3, 3), minval=minval, maxval=maxval)
    # Batched A^T A: transpose the trailing two axes of every matrix in the stack.
    return np.swapaxes(As, -1, -2) @ As
# JIT-compiled versions of the LinAlg routines under test. They are built once
# at import time and reused by several tests below.
sqrtm_jit, logm_iss_jit = map(jax.jit, (LinAlg.sqrtm, LinAlg.logm_iss))
class TestLinAlg(TestFixture):
    """Tests for ``LinAlg.sqrtm`` and ``LinAlg.logm_iss``.

    Each routine is checked for value correctness (squaring the square root,
    exponentiating the logarithm), for use under ``jax.jit``, for forward- and
    reverse-mode derivatives via ``check_grads``, and on inputs with repeated
    eigenvalues.
    """

    def setUp(self):
        # Random symmetric 3x3 matrix of the form A^T A.
        self.sym_mat = generate_n_random_symmetric_matrices(1)[0]
        # make a matrix with 2 identical eigenvalues
        R = Rotation.random(random_state=41).as_matrix()
        eigvals = np.array([2., 0.5, 2.])
        self.sym_mat_double_degeneracy = R@np.diag(eigvals)@R.T

    @staticmethod
    def _random_10x10_gram_matrix():
        """Return ``F.T @ F`` for a fixed-seed random 10x10 ``F``.

        Shared by the 10x10 tests so they all exercise the same input.
        """
        key = jax.random.PRNGKey(0)
        F = jax.random.uniform(key, (10, 10), minval=1e-8, maxval=10.0)
        return F.T@F

    ### sqrtm ###

    def test_sqrtm_jit(self):
        """The jitted sqrtm produces no NaNs on a random symmetric matrix."""
        sqrtC = sqrtm_jit(self.sym_mat)
        self.assertFalse(np.isnan(sqrtC).any())

    def test_sqrtm(self):
        """Squaring sqrtm recovers each of 100 random symmetric matrices."""
        mats = generate_n_random_symmetric_matrices(100)
        sqrtMats = jax.vmap(sqrtm_jit, (0,))(mats)
        shouldBeMats = jax.vmap(lambda A: np.dot(A, A), (0,))(sqrtMats)
        self.assertArrayNear(shouldBeMats, mats, 10)

    def test_sqrtm_fwd_mode_derivative(self):
        """Forward-mode derivatives of sqrtm, up to second order."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["fwd"])

    def test_sqrtm_rev_mode_derivative(self):
        """Reverse-mode derivatives of sqrtm, up to second order."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["rev"])

    def test_sqrtm_on_degenerate_eigenvalues(self):
        """sqrtm value and reverse-mode derivatives with a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC, sqrtC)
        self.assertArrayNear(shouldBeC, C, 12)
        check_grads(LinAlg.sqrtm, (C,), order=2, modes=["rev"])

    def test_sqrtm_on_10x10(self):
        """Squaring sqrtm recovers a 10x10 Gram matrix."""
        C = self._random_10x10_gram_matrix()
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC, sqrtC)
        self.assertArrayNear(shouldBeC, C, 11)

    def test_sqrtm_derivatives_on_10x10(self):
        """First-order fwd and rev derivatives of sqrtm on a 10x10 input."""
        C = self._random_10x10_gram_matrix()
        check_grads(LinAlg.sqrtm, (C,), order=1, modes=["fwd", "rev"])

    ### logm_iss ###

    def test_logm_iss_on_matrix_near_identity(self):
        """logm_iss of a near-identity diagonal matrix is the log of its diagonal."""
        key = jax.random.PRNGKey(0)
        id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01)
        A = np.diag(id_perturbation)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.diag(np.log(id_perturbation)), 12)

    def test_logm_iss_on_double_degenerate_eigenvalues(self):
        """expm(logm_iss(C)) recovers C when C has a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(C, explogC, 8)

    def test_logm_iss_on_triple_degenerate_eigvalues(self):
        """logm_iss(4 I) equals log(4) I."""
        A = 4.0*np.identity(3)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.log(4.0)*np.identity(3), 12)

    def test_logm_iss_jit(self):
        """The jitted logm_iss produces no NaNs on a random symmetric matrix."""
        # self.sym_mat is the same matrix generate_n_random_symmetric_matrices(1)[0]
        # would regenerate here (same n, same default seed).
        logC = logm_iss_jit(self.sym_mat)
        self.assertFalse(np.isnan(logC).any())

    def test_logm_iss_on_full_3x3s(self):
        """expm(logm_iss(C)) recovers each of 1000 random symmetric matrices."""
        mats = generate_n_random_symmetric_matrices(1000)
        logMats = jax.vmap(logm_iss_jit, (0,))(mats)
        shouldBeMats = jax.vmap(jax.scipy.linalg.expm, (0,))(logMats)
        self.assertArrayNear(shouldBeMats, mats, 7)

    def test_logm_iss_fwd_mode_derivative(self):
        """First-order forward-mode derivatives of the jitted logm_iss."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['fwd'])

    def test_logm_iss_rev_mode_derivative(self):
        """First-order reverse-mode derivatives of the jitted logm_iss."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['rev'])

    def test_logm_iss_hessian_on_double_degenerate_eigenvalues(self):
        """Forward-mode check of the reverse-mode Jacobian of logm_iss.

        This exercises second derivatives at an input with a repeated
        eigenvalue, using tight tolerances.
        """
        C = self.sym_mat_double_degeneracy
        check_grads(jax.jacrev(LinAlg.logm_iss), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5)

    def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives with a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['rev'])

    def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives at a multiple of the identity."""
        A = 4.0*np.identity(3)
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['rev'])

    def test_logm_iss_on_10x10(self):
        """expm(logm_iss(C)) recovers a 10x10 Gram matrix."""
        C = self._random_10x10_gram_matrix()
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(explogC, C, 8)