Source code for optimism.test.test_LinAlg

import jax
from jax import numpy as np
from jax.test_util import check_grads
from scipy.spatial.transform import Rotation
import unittest

from optimism import LinAlg
from .TestFixture import TestFixture

def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0):
    """Return a stack of ``n`` symmetric 3x3 matrices, shape ``(n, 3, 3)``.

    Each matrix is built as ``A.T @ A`` from a 3x3 factor ``A`` whose entries
    are sampled by ``jax.random.uniform`` between ``minval`` and ``maxval``,
    so every result is symmetric positive semi-definite.

    The PRNG key is fixed, so calls with the same arguments return the same
    matrices.
    """
    rng_key = jax.random.PRNGKey(0)
    factors = jax.random.uniform(
        rng_key, (n, 3, 3), minval=minval, maxval=maxval
    )

    def gram(A):
        # A.T A is symmetric by construction.
        return np.dot(A.T, A)

    # Map over the leading (batch) axis of the stacked factors.
    return jax.vmap(gram, in_axes=(0,))(factors)
# Wrap the LinAlg matrix functions with jax.jit once at module level so that
# all tests share the same jitted callables.
sqrtm_jit = jax.jit(LinAlg.sqrtm)
logm_iss_jit = jax.jit(LinAlg.logm_iss)
class TestLinAlg(TestFixture):
    """Tests for ``LinAlg.sqrtm`` and ``LinAlg.logm_iss``.

    Covers round-trip identities (``sqrtm(C) @ sqrtm(C) == C`` and
    ``expm(logm_iss(C)) == C``), execution under ``jax.jit``, and
    derivative checks via ``jax.test_util.check_grads``, including
    matrices with repeated eigenvalues.

    NOTE(review): the last argument of ``assertArrayNear`` is defined by
    ``TestFixture``, which is not shown here; presumably it is a number of
    decimal places -- confirm against ``TestFixture``.
    """

    def setUp(self):
        """Create the matrices shared by several tests."""
        self.sym_mat = generate_n_random_symmetric_matrices(1)[0]

        # make a matrix with 2 identical eigenvalues
        R = Rotation.random(random_state=41).as_matrix()
        eigvals = np.array([2., 0.5, 2.])
        self.sym_mat_double_degeneracy = R@np.diag(eigvals)@R.T

    @staticmethod
    def _random_10x10_gram_matrix():
        """Return ``F.T @ F`` for a fixed-key uniform random 10x10 ``F``.

        Shared by the 10x10 tests so they all operate on the same matrix.
        """
        key = jax.random.PRNGKey(0)
        F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0)
        return F.T@F

    ### sqrtm ###

    def test_sqrtm_jit(self):
        """The jitted ``sqrtm`` produces no NaNs on a symmetric matrix."""
        sqrtC = sqrtm_jit(self.sym_mat)
        self.assertFalse(np.isnan(sqrtC).any())

    def test_sqrtm(self):
        """Squaring ``sqrtm(A)`` recovers ``A`` for a batch of 100 matrices."""
        mats = generate_n_random_symmetric_matrices(100)
        sqrtMats = jax.vmap(sqrtm_jit, (0,))(mats)
        shouldBeMats = jax.vmap(lambda A: np.dot(A,A), (0,))(sqrtMats)
        self.assertArrayNear(shouldBeMats, mats, 10)

    def test_sqrtm_fwd_mode_derivative(self):
        """Forward-mode derivatives of ``sqrtm`` pass ``check_grads`` to order 2."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["fwd"])

    def test_sqrtm_rev_mode_derivative(self):
        """Reverse-mode derivatives of ``sqrtm`` pass ``check_grads`` to order 2."""
        check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["rev"])

    def test_sqrtm_on_degenerate_eigenvalues(self):
        """``sqrtm`` value and reverse-mode derivatives with a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC, sqrtC)
        self.assertArrayNear(shouldBeC, C, 12)
        check_grads(LinAlg.sqrtm, (C,), order=2, modes=["rev"])

    def test_sqrtm_on_10x10(self):
        """Squaring ``sqrtm(C)`` recovers ``C`` for a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        sqrtC = LinAlg.sqrtm(C)
        shouldBeC = np.dot(sqrtC,sqrtC)
        self.assertArrayNear(shouldBeC, C, 11)

    def test_sqrtm_derivatives_on_10x10(self):
        """First-order fwd and rev derivatives of ``sqrtm`` on a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        check_grads(LinAlg.sqrtm, (C,), order=1, modes=["fwd", "rev"])

    ### logm_iss ###

    def test_logm_iss_on_matrix_near_identity(self):
        """``logm_iss`` of a diagonal matrix near identity is the log of its diagonal."""
        key = jax.random.PRNGKey(0)
        id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01)
        A = np.diag(id_perturbation)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.diag(np.log(id_perturbation)), 12)

    def test_logm_iss_on_double_degenerate_eigenvalues(self):
        """``expm(logm_iss(C))`` recovers ``C`` when ``C`` has a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(C, explogC, 8)

    def test_logm_iss_on_triple_degenerate_eigvalues(self):
        """``logm_iss(4 I)`` equals ``log(4) I``."""
        A = 4.0*np.identity(3)
        logA = LinAlg.logm_iss(A)
        self.assertArrayNear(logA, np.log(4.0)*np.identity(3), 12)

    def test_logm_iss_jit(self):
        """The jitted ``logm_iss`` produces no NaNs on a symmetric matrix."""
        C = generate_n_random_symmetric_matrices(1)[0]
        logC = logm_iss_jit(C)
        self.assertFalse(np.isnan(logC).any())

    def test_logm_iss_on_full_3x3s(self):
        """``expm(logm_iss(A))`` recovers ``A`` for a batch of 1000 3x3 matrices."""
        mats = generate_n_random_symmetric_matrices(1000)
        logMats = jax.vmap(logm_iss_jit, (0,))(mats)
        shouldBeMats = jax.vmap(lambda A: jax.scipy.linalg.expm(A), (0,))(logMats)
        self.assertArrayNear(shouldBeMats, mats, 7)

    def test_logm_iss_fwd_mode_derivative(self):
        """First-order forward-mode derivative check of the jitted ``logm_iss``."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['fwd'])

    def test_logm_iss_rev_mode_derivative(self):
        """First-order reverse-mode derivative check of the jitted ``logm_iss``."""
        check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['rev'])

    def test_logm_iss_hessian_on_double_degenerate_eigenvalues(self):
        """Second derivatives of ``logm_iss`` with a repeated eigenvalue.

        Applies a forward-mode check to ``jax.jacrev(LinAlg.logm_iss)``,
        with explicit tolerances and finite-difference step.
        """
        C = self.sym_mat_double_degeneracy
        check_grads(jax.jacrev(LinAlg.logm_iss), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5)

    def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives of ``logm_iss`` with a repeated eigenvalue."""
        C = self.sym_mat_double_degeneracy
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (C,), order=1, modes=['rev'])

    def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(self):
        """First-order fwd and rev derivatives of ``logm_iss`` at ``4 I``."""
        A = 4.0*np.identity(3)
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['fwd'])
        check_grads(LinAlg.logm_iss, (A,), order=1, modes=['rev'])

    def test_logm_iss_on_10x10(self):
        """``expm(logm_iss(C))`` recovers ``C`` for a 10x10 matrix."""
        C = self._random_10x10_gram_matrix()
        logC = LinAlg.logm_iss(C)
        explogC = jax.scipy.linalg.expm(logC)
        self.assertArrayNear(explogC, C, 8)