import unittest
from optimism.JaxConfig import *
from . import TestFixture
from jax import config
class TestJaxConfiguration(TestFixture.TestFixture):
    """Sanity checks on the JAX configuration active while the tests run.

    ``np`` is not imported explicitly in this module; it is presumably
    provided by the star import of ``optimism.JaxConfig`` at the top of the
    file, which is also where the settings under test would be applied.
    """

    def test_double_precision_mode_is_on(self):
        """An array built from a Python float should have dtype float64."""
        a = np.array([1.0])
        # assertEqual reports the offending dtype on failure, which
        # assertTrue on a pre-evaluated comparison cannot do.
        self.assertEqual(a.dtype, np.float64)

    def test_debug_nans_is_off(self):
        """Evaluating log(-1.0) must not raise FloatingPointError."""
        try:
            np.log(-1.0)
        except FloatingPointError:
            self.fail("Floating point exception raised. Check to see if 'jax_debug_nans' has been left activated from a debugging session.")

    def test_debug_infs_is_off(self):
        """Dividing an array by zero must not raise FloatingPointError."""
        try:
            a = np.array([1.0])
            # The result is intentionally discarded; only the absence of an
            # exception matters here.
            a/0
        except FloatingPointError:
            self.fail("Floating point exception raised. Check to see if 'jax_debug_infs' has been left activated from a debugging session.")

    def test_jit_is_enabled(self):
        """The ``jax_disable_jit`` config flag must be falsy."""
        self.assertFalse(config.jax_disable_jit)
 
# Run this module's tests through unittest's command-line runner when the
# file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()