Source code for optimism.test.test_JaxConfig

import unittest

from optimism.JaxConfig import *
from . import TestFixture

from jax import config


[docs] class TestJaxConfiguration(TestFixture.TestFixture):
[docs] def test_double_precision_mode_is_on(self): a = np.array([1.0]) self.assertTrue(a.dtype == np.float64)
[docs] def test_debug_nans_is_off(self): try: np.log(-1.0) except FloatingPointError: self.fail("Floating point exception raised. Check to see if 'jax_debug_nans' has been left activated from a debugging session.")
[docs] def test_debug_infs_is_off(self): try: a = np.array([1.0]) a/0 except FloatingPointError: self.fail("Floating point exception raised. Check to see if 'jax_debug_infs' has been left activated from a debugging session.")
[docs] def test_jit_is_enabled(self): self.assertFalse(config.jax_disable_jit)
if __name__ == "__main__": unittest.main()