Welcome to the pancax book!
Description
Pancax is a set of tools written on top of jax and equinox to facilitate research that leverages physics informed neural networks (PINNs) in challenging application domains such as computational solid mechanics.
Remainder of the book
The remainder of the book is outlined as follows...
Theory
In this section, the basic theory behind the PDEs of interest when using pancax is summarized.
Strong Form
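Consider a general PDE defined in the strong form as follows

$$ \mathcal{D}\left[u\right]\left(\mathbf{x}, t\right) = f\left(\mathbf{x}, t\right) \quad \mathbf{x}\in\Omega,\; t\in\left(0, T\right] $$

where $u$ is an arbitrary solution field (potentially vector or tensor valued), $\mathcal{D}$ is a differential operator, and $f$ is a known forcing term. To fully close an initial boundary value problem (IBVP) we also need appropriate initial and boundary conditions. We can write these generally as

$$
\begin{aligned}
u\left(\mathbf{x}, 0\right) &= u_0\left(\mathbf{x}\right), \quad \dot{u}\left(\mathbf{x}, 0\right) = v_0\left(\mathbf{x}\right) && \mathbf{x}\in\Omega \\
u\left(\mathbf{x}, t\right) &= g\left(\mathbf{x}, t\right) && \mathbf{x}\in\partial\Omega_D \\
\mathcal{B}\left[u\right]\left(\mathbf{x}, t\right) &= h\left(\mathbf{x}, t\right) && \mathbf{x}\in\partial\Omega_N
\end{aligned}
$$

where the above represent general initial conditions on the solution field and its time derivative, and general boundary conditions on the solution field in the form of Dirichlet BCs on $\partial\Omega_D$ and Robin/Neumann BCs on $\partial\Omega_N$ (with boundary operator $\mathcal{B}$).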
Our goal in pancax is to seek approximate solutions to the above general equation via physics informed neural networks (PINNs). We can do this in a number of ways, but the simplest and most classical is the following approach. We approximate the solution field with a multi-layer perceptron (MLP)

$$ u\left(\mathbf{x}, t\right) \approx \mathcal{N}\left(\mathbf{x}, t; \boldsymbol{\theta}\right) $$

where $\boldsymbol{\theta}$ are the trainable weights and biases. With this in hand, we can take derivatives with respect to $\mathbf{x}$ or $t$ to obtain the derivatives needed in our operator. For example

$$ \frac{\partial u}{\partial t} \approx \frac{\partial \mathcal{N}}{\partial t}\left(\mathbf{x}, t; \boldsymbol{\theta}\right) $$

which, put another way, says to differentiate the outputs of the neural network with respect to a subset of its inputs. This can be achieved in any modern neural network library through automatic differentiation. However, this is not always optimal in terms of either runtime or numerical accuracy, especially for nested derivatives.
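Other derivatives can be written similarly; some are given below, where $d$ is the spatial dimension

$$ \nabla u \approx \frac{\partial \mathcal{N}}{\partial \mathbf{x}}, \quad \nabla^2 u \approx \sum_{i=1}^{d}\frac{\partial^2 \mathcal{N}}{\partial x_i^2}, \quad \frac{\partial^2 u}{\partial t^2} \approx \frac{\partial^2 \mathcal{N}}{\partial t^2} $$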
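As a minimal sketch of the automatic-differentiation route, the snippet below builds a tiny tanh MLP from raw arrays in plain JAX and differentiates its scalar output with respect to a subset of its inputs. The helpers `init_mlp` and `mlp` are hypothetical and are not pancax's (equinox-based) network classes; the point is only to show $\partial\mathcal{N}/\partial t$, $\partial\mathcal{N}/\partial\mathbf{x}$, and a Laplacian obtained through nested differentiation.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: list of (W, b) pairs for a dense tanh MLP."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        W = jax.random.normal(subkey, (n_out, n_in)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def mlp(params, x, t):
    """Scalar output N(x, t; theta) for a point x of shape (d,) and a time t."""
    h = jnp.concatenate([x, jnp.atleast_1d(t)])
    for W, b in params[:-1]:
        h = jnp.tanh(W @ h + b)
    W, b = params[-1]
    return (W @ h + b)[0]


params = init_mlp(jax.random.PRNGKey(0), [3, 16, 16, 1])  # inputs (x1, x2, t)
x = jnp.array([0.3, -0.2])
t = 0.5

# dN/dt: differentiate with respect to the time input only
dN_dt = jax.grad(mlp, argnums=2)(params, x, t)

# grad_x N: differentiate with respect to the spatial inputs only
grad_x = jax.grad(mlp, argnums=1)(params, x, t)

# Laplacian: trace of the Hessian with respect to x (nested differentiation)
hess_x = jax.hessian(mlp, argnums=1)(params, x, t)
lap_x = jnp.trace(hess_x)
```

Nesting transformations such as `jax.hessian` compounds the cost of each derivative, which is the runtime concern raised above for nested differentiation.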
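One of the more common examples encountered in the PINN literature is Burgers' equation. In one spatial dimension this is given by

$$ \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} - \nu\frac{\partial^2 u}{\partial x^2} = 0 $$

where $\nu$ is the viscosity.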
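A minimal JAX sketch of the pointwise Burgers' residual $u_t + u\,u_x - \nu u_{xx}$ for any differentiable candidate $u(x, t)$ follows. The helper `burgers_residual` is hypothetical and is not the `pancax.physics_kernels.burgers_equation` API. As a check, $u = x/(1 + t)$ solves the equation exactly for any $\nu$, because $u_{xx} = 0$ and $u_t = -x/(1+t)^2 = -u\,u_x$.

```python
import jax
import jax.numpy as jnp


def burgers_residual(u, x, t, nu):
    """Pointwise residual u_t + u u_x - nu u_xx for a scalar field u(x, t)."""
    u_val = u(x, t)
    u_t = jax.grad(u, argnums=1)(x, t)
    u_x = jax.grad(u, argnums=0)(x, t)
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, t)
    return u_t + u_val * u_x - nu * u_xx


def exact(x, t):
    # Exact solution for any nu, since its second x-derivative vanishes
    return x / (1.0 + t)


def non_solution(x, t):
    # A smooth field that does not satisfy Burgers' equation
    return jnp.sin(jnp.pi * x) * jnp.exp(-t)


# Evaluate the residual on a cloud of (x, t) collocation points
xs = jnp.linspace(-1.0, 1.0, 11)
ts = jnp.linspace(0.0, 1.0, 11)
res = jax.vmap(lambda x, t: burgers_residual(exact, x, t, 0.01))(xs, ts)
res_bad = jax.vmap(lambda x, t: burgers_residual(non_solution, x, t, 0.01))(xs, ts)
```

In a PINN, `u` would be the network evaluated at fixed parameters, and the squared residual would feed the loss.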
Lagrangian Form
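The Lagrangian is given by the difference of the kinetic and potential energy,

$$ L = K - \Pi $$

where for a continuum occupying a domain $\Omega$ with mass density $\rho$ we have

$$ K = \int_\Omega \frac{1}{2}\rho\,\dot{\mathbf{u}}\cdot\dot{\mathbf{u}}\,dV $$

and $\Pi$ is problem dependent. In the quasi-static limit we can assume $K \approx 0$, so that $L = -\Pi$.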
Solid Mechanics
The original motivation for writing pancax was applications in solid mechanics, specifically leveraging digital image correlation (DIC) data in PINN-based inverse problems.
The best approach to solid mechanics problems is to build on what the PINN community calls the "Deep Energy Method" (DEM), which is really just a variational principle from solid mechanics with a neural network approximation of the solution field.
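We can write this variational principle (neglecting body forces) as follows

$$ \mathbf{u} = \arg\min_{\mathbf{v}} \Pi\left[\mathbf{v}\right], \quad \Pi\left[\mathbf{v}\right] = \int_\Omega \psi\left(\mathbf{F}\right)dV - \int_{\partial\Omega_N}\bar{\mathbf{t}}\cdot\mathbf{v}\,dA $$

where the minimization is over displacement fields satisfying the Dirichlet BCs, $\psi$ is the strain energy density, $\mathbf{F} = \mathbf{I} + \nabla\mathbf{v}$ is the deformation gradient, and $\bar{\mathbf{t}}$ is the applied traction on $\partial\Omega_N$.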
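The first variation of this gives rise to the principle of virtual work

$$ \delta\Pi = \int_\Omega \frac{\partial\psi}{\partial\mathbf{F}}:\delta\mathbf{F}\,dV - \int_{\partial\Omega_N}\bar{\mathbf{t}}\cdot\delta\mathbf{u}\,dA = 0 $$

which, since $\delta\mathbf{F} = \nabla\delta\mathbf{u}$, can be simplified to

$$ \int_\Omega \mathbf{P}:\nabla\delta\mathbf{u}\,dV - \int_{\partial\Omega_N}\bar{\mathbf{t}}\cdot\delta\mathbf{u}\,dA = 0 $$

and using the identities

$$ \nabla\cdot\left(\mathbf{P}^T\delta\mathbf{u}\right) = \left(\nabla\cdot\mathbf{P}\right)\cdot\delta\mathbf{u} + \mathbf{P}:\nabla\delta\mathbf{u}, \qquad \int_\Omega \nabla\cdot\left(\mathbf{P}^T\delta\mathbf{u}\right)dV = \int_{\partial\Omega}\left(\mathbf{P}\mathbf{N}\right)\cdot\delta\mathbf{u}\,dA $$

where $\mathbf{P} = \partial\psi/\partial\mathbf{F}$ is the first Piola-Kirchhoff stress and $\mathbf{N}$ is the outward unit normal in the reference configuration, and noting that $\delta\mathbf{u} = \mathbf{0}$ on $\partial\Omega_D$, we can re-write the above as

$$ -\int_\Omega\left(\nabla\cdot\mathbf{P}\right)\cdot\delta\mathbf{u}\,dV + \int_{\partial\Omega_N}\left(\mathbf{P}\mathbf{N} - \bar{\mathbf{t}}\right)\cdot\delta\mathbf{u}\,dA = 0 $$

Using either of the above forms gives rise to traction boundary condition enforcement by construction, since $\mathbf{P}\mathbf{N} = \bar{\mathbf{t}}$ on $\partial\Omega_N$ emerges as a natural boundary condition rather than an extra loss term.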
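To see the traction condition emerge from the energy alone, here is a small NumPy sketch (not pancax code) for a 1D linear elastic bar on $[0, 1]$ with $\psi = \frac{1}{2}E\left(u'\right)^2$, fixed at $X = 0$ and loaded by a traction $\bar{t}$ at $X = 1$. The trial family $u = aX$, the sampled search over $a$, and the trapezoidal rule are assumptions of this sketch. The minimizer satisfies $E\,u'(1) = \bar{t}$, the 1D form of $\mathbf{P}\mathbf{N} = \bar{\mathbf{t}}$, without any traction penalty term.

```python
import numpy as np

E = 200.0     # Young's modulus (illustrative value)
t_bar = 5.0   # applied traction at X = 1 (illustrative value)
X = np.linspace(0.0, 1.0, 101)


def total_potential(a):
    """Pi[u] for the trial field u(X) = a * X, which satisfies u(0) = 0."""
    u = a * X
    du_dX = np.gradient(u, X)                 # exact here since u is linear
    psi = 0.5 * E * du_dX**2                  # strain energy density
    # Trapezoidal rule for the integral of psi over the bar
    internal = np.sum(0.5 * (psi[1:] + psi[:-1]) * np.diff(X))
    external = t_bar * u[-1]                  # traction work at X = 1
    return internal - external


candidates = np.linspace(0.0, 0.1, 2001)
energies = np.array([total_potential(a) for a in candidates])
a_star = candidates[np.argmin(energies)]

# Stress at the loaded end; it matches the applied traction
stress_at_end = E * a_star
```

Analytically $\Pi(a) = \frac{1}{2}Ea^2 - \bar{t}a$, so the minimum sits at $a = \bar{t}/E$.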
Implementations
This chapter describes the different implementation types available for PINNs in pancax. They include collocation PINNs, variational PINNs, and Lagrangian PINNs. If you would like a new style of implementation, please open a pull request.
Collocation PINNs
Collocation PINNs are the most commonly encountered type of PINN in the literature. The main concept is that there is an unconnected point cloud (i.e. meshless) of inputs $\left(\mathbf{x}_i, t_i\right)$ to a neural network $\mathcal{N}\left(\mathbf{x}, t; \boldsymbol{\theta}\right)$. The goal is to minimize the residual of the strong form of the governing equation of interest via standard neural network optimizers such as stochastic gradient descent, Adam, etc.
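Let's consider a general equation of the form $\mathcal{D}\left[u\right] = f$, define its residual to be $R\left[u\right] = \mathcal{D}\left[u\right] - f$, and approximate the solution by $u \approx \mathcal{N}\left(\mathbf{x}, t; \boldsymbol{\theta}\right)$.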
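We can define a loss function for the residual as follows

$$ \mathcal{L}_R = \frac{1}{N_R}\sum_{i=1}^{N_R}\left\|R\left[u\right]\left(\mathbf{x}_i, t_i\right)\right\|^2 $$

which can be expanded to

$$ \mathcal{L}_R\left(\boldsymbol{\theta}\right) = \frac{1}{N_R}\sum_{i=1}^{N_R}\left\|\mathcal{D}\left[\mathcal{N}\right]\left(\mathbf{x}_i, t_i; \boldsymbol{\theta}\right) - f\left(\mathbf{x}_i, t_i\right)\right\|^2 $$

where $N_R$ is the number of collocation points.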
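A minimal JAX sketch of $\mathcal{L}_R$ for the 1D Poisson problem $-u'' = f$ with $f = \pi^2\sin(\pi x)$. A hypothetical one-parameter model `model(theta, x) = theta * sin(pi x)` stands in for an MLP so the result can be checked by hand: the residual is $(\theta - 1)\pi^2\sin(\pi x)$, so the loss vanishes at $\theta = 1$ and its gradient pushes $\theta$ toward it. None of these names are pancax APIs.

```python
import jax
import jax.numpy as jnp


def model(theta, x):
    # Hypothetical stand-in for N(x; theta); any differentiable callable works
    return theta * jnp.sin(jnp.pi * x)


def forcing(x):
    return jnp.pi**2 * jnp.sin(jnp.pi * x)


def residual(theta, x):
    """R = D[u] - f with D[u] = -u'' for the 1D Poisson equation."""
    u_xx = jax.grad(jax.grad(model, argnums=1), argnums=1)(theta, x)
    return -u_xx - forcing(x)


def residual_loss(theta, xs):
    """Mean squared residual over an unstructured cloud of collocation points."""
    r = jax.vmap(residual, in_axes=(None, 0))(theta, xs)
    return jnp.mean(r**2)


xs = jax.random.uniform(jax.random.PRNGKey(0), (64,))   # meshless points in (0, 1)
loss_at_solution = residual_loss(1.0, xs)
loss_elsewhere = residual_loss(0.5, xs)
grad_elsewhere = jax.grad(residual_loss)(0.5, xs)      # negative: raising theta lowers the loss
```

An optimizer such as Adam would repeatedly step along `-grad` of this loss with respect to the network parameters.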
Regarding boundary conditions (BCs), the most common choice (although usually sub-optimal, since the BCs are then only satisfied approximately) is to weakly enforce BCs via additional loss function terms. Alternatively, signed distance functions (SDFs) can be used to exactly enforce Dirichlet and Neumann BCs by imposing structure on the solution space.
Sub-optimal Enforcement of BCs
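For the loss function approach, terms for different BC types are as follows

$$ \mathcal{L}_{D} = \frac{1}{N_D}\sum_{i=1}^{N_D}\left\|\mathcal{N}\left(\mathbf{x}_i, t_i; \boldsymbol{\theta}\right) - g\left(\mathbf{x}_i, t_i\right)\right\|^2, \quad \mathbf{x}_i\in\partial\Omega_D $$

$$ \mathcal{L}_{N} = \frac{1}{N_N}\sum_{i=1}^{N_N}\left\|\mathcal{B}\left[\mathcal{N}\right]\left(\mathbf{x}_i, t_i; \boldsymbol{\theta}\right) - h\left(\mathbf{x}_i, t_i\right)\right\|^2, \quad \mathbf{x}_i\in\partial\Omega_N $$

where $g$ is the prescribed Dirichlet data and $\mathcal{B}$ is the Neumann/Robin boundary operator with data $h$. Initial conditions can be penalized in the same way.
The total loss function is then

$$ \mathcal{L} = \mathcal{L}_R + \lambda_D\mathcal{L}_D + \lambda_N\mathcal{L}_N $$

where $\mathcal{L}_R$ is the residual loss and the $\lambda$ are user-chosen penalty weights.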
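The penalty approach can be sketched in NumPy for $-u'' = \pi^2\sin(\pi x)$ on $(0, 1)$ with $u(0) = u(1) = 0$. The helper `total_loss`, the weight `lam_d=10.0`, and the hand-written derivatives are assumptions for illustration. The candidate $u = \sin(\pi x) + 0.1$ satisfies the PDE exactly but misses the boundary data, so only the Dirichlet term detects the error.

```python
import numpy as np


def mse(r):
    return float(np.mean(np.square(r)))


def total_loss(u, u_xx, f, x_interior, x_dirichlet, g, lam_d=10.0):
    """L = L_R + lambda_D * L_D for -u'' = f with Dirichlet data g (hypothetical helper)."""
    loss_r = mse(-u_xx(x_interior) - f(x_interior))   # residual term
    loss_d = mse(u(x_dirichlet) - g(x_dirichlet))      # Dirichlet penalty term
    return loss_r + lam_d * loss_d, loss_r, loss_d


f = lambda x: np.pi**2 * np.sin(np.pi * x)
g = lambda x: np.zeros_like(x)
x_int = np.linspace(0.05, 0.95, 19)
x_bc = np.array([0.0, 1.0])

# Candidate that solves the PDE but is shifted off the boundary data
u_shift = lambda x: np.sin(np.pi * x) + 0.1
u_shift_xx = lambda x: -np.pi**2 * np.sin(np.pi * x)

total, loss_r, loss_d = total_loss(u_shift, u_shift_xx, f, x_int, x_bc, g)
```

The residual term alone is blind to the shift; the weight $\lambda_D$ is what pulls the solution toward the boundary data, and choosing it well is problem dependent.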
Optimal Enforcement of BCs
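For Dirichlet BCs

$$ u\left(\mathbf{x}, t\right) \approx g\left(\mathbf{x}, t\right) + d\left(\mathbf{x}\right)\mathcal{N}\left(\mathbf{x}, t; \boldsymbol{\theta}\right) $$

where $g$ is a smooth extension of the Dirichlet data into $\Omega$ and $d$ is a function constructed such that $d = 0$ on $\partial\Omega_D$. This enforces the Dirichlet BCs by construction and removes the burden of learning this relationship from the network.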
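A minimal NumPy sketch of this construction on $[0, 1]$ with Dirichlet data at both ends, taking $d(x) = x(1 - x)$ and $g$ as the linear interpolant of the boundary values (both choices are assumptions of the sketch). The function `net` is a hypothetical stand-in for a network output; whatever it returns, the boundary values are reproduced exactly. The `(x, t, z)` argument order mirrors the `dirichlet_bc_func = lambda x, t, z: z` default listed for `BasePhysics` in the API reference, but this function is illustrative rather than pancax's implementation.

```python
import numpy as np

u_left, u_right = 1.0, -2.0   # Dirichlet data at x = 0 and x = 1


def g(x, t):
    # Smooth extension of the boundary data into the domain (linear interpolant)
    return u_left * (1.0 - x) + u_right * x


def d(x):
    # Vanishes on the Dirichlet boundary {0, 1} and is positive inside
    return x * (1.0 - x)


def dirichlet_ansatz(x, t, z):
    """Map a raw network output z to a field that satisfies the Dirichlet BCs exactly."""
    return g(x, t) + d(x) * z


def net(x, t):
    # Hypothetical stand-in for an (untrained) network output
    return np.cos(3.0 * x) + t**2


x = np.linspace(0.0, 1.0, 5)
t = 0.7
u = dirichlet_ansatz(x, t, net(x, t))
```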
More complex relationships are also possible for enforcing Neumann and Robin BCs. TODO write this down.
In this case, the total loss function is greatly simplified to just the residual term

$$ \mathcal{L} = \mathcal{L}_R $$
Design
This chapter goes through the design of pancax for those interested in embarking on development, whether that be adding a simple physics kernel, or developing an entirely new PINN formulation where pancax is just being leveraged for IO and other utilities common to different PINN formulations.
Examples
All examples require writing a Python script (or a Jupyter notebook, if that is your preference) that, at a minimum, contains the line below.
from pancax import *
This brings in all of the well-tested and most commonly used portions of the pancax codebase. Portions of the code that are still in development may require more specific import statements.
Table of Contents
- pancax
- pancax.loss_functions.bc_loss_functions
- pancax.loss_functions.weak_form_loss_functions
- pancax.loss_functions.ic_loss_function
- pancax.loss_functions.data_loss_functions
- pancax.loss_functions.strong_form_loss_functions
- pancax.loss_functions.utils
- pancax.loss_functions.base_loss_function
- pancax.loss_functions
- pancax.optimizers.base
- pancax.optimizers.adam
- pancax.optimizers.utils
- pancax.optimizers
- pancax.optimizers.lbfgs
- pancax.fem.traction_bc
- pancax.fem.sparse_matrix_assembler
- pancax.fem.elements.tet4_element
- pancax.fem.elements.quad9_element
- pancax.fem.elements.quad4_element
- pancax.fem.elements
- pancax.fem.elements.line_element
- pancax.fem.elements.hex8_element
- pancax.fem.elements.base_element
- pancax.fem.elements.tet10_element
- pancax.fem.elements.simplex_tri_element
- pancax.fem.read_exodus_mesh
- pancax.fem.function_space
- pancax.fem
- pancax.fem.surface
- pancax.fem.dof_manager
- pancax.fem.quadrature_rules
- pancax.fem.mesh
- pancax.trainer
- pancax.__main__
- pancax.logging
- pancax.math.tensor_math
- pancax.math
- pancax.math.math
- pancax.post_processor
- pancax.deprecated.properties
- pancax.deprecated.physics_loss_functions_strong_form
- pancax.history_writer
- pancax.networks.initialization
- pancax.networks.field_physics_pair
- pancax.networks.base
- pancax.networks.fields
- pancax.networks.parameters
- pancax.networks.elm
- pancax.networks.rbf
- pancax.networks.mlp
- pancax.networks
- pancax.networks.ml_dirichlet_field
- pancax.timer
- pancax.utils
- pancax.data.full_field_data
- pancax.data
- pancax.data.global_data
- pancax.problems.inverse_problem
- pancax.problems.forward_problem
- pancax.problems
- pancax.bvps.biaxial_tension
- pancax.bvps.simple_shear
- pancax.bvps.uniaxial_tension
- pancax.bvps
- pancax.domains.base
- pancax.domains.variational_domain
- pancax.domains.delta_pinn_domain
- pancax.domains
- pancax.domains.collocation_domain
- pancax.physics_kernels.laplace_beltrami
- pancax.physics_kernels.base
- pancax.physics_kernels.beer_lambert_law
- pancax.physics_kernels.poisson
- pancax.physics_kernels.heat_equation
- pancax.physics_kernels.burgers_equation
- pancax.physics_kernels
- pancax.physics_kernels.solid_mechanics
- pancax.constitutive_models.properties
- pancax.constitutive_models.base
- pancax.constitutive_models.mechanics.base
- pancax.constitutive_models.mechanics.hyperelasticity.neohookean
- pancax.constitutive_models.mechanics.hyperelasticity.gent
- pancax.constitutive_models.mechanics.hyperelasticity.blatz_ko
- pancax.constitutive_models.mechanics.hyperelasticity.hencky
- pancax.constitutive_models.mechanics.hyperelasticity.swanson
- pancax.constitutive_models.mechanics.hyperelasticity
- pancax.constitutive_models.mechanics
- pancax.constitutive_models
- pancax.bcs.dirichlet_bc
- pancax.bcs
- pancax.bcs.distance_functions
- pancax.bcs.neumann_bc
pancax
pancax.loss_functions.bc_loss_functions
pancax.loss_functions.weak_form_loss_functions
EnergyLoss Objects
class EnergyLoss(PhysicsLossFunction)
Energy loss function akin to the deep energy method.
Calculates the following quantity
$$ \mathcal{L} = w\Pi\left[u\right] = w\int_\Omega\psi\left(\mathbf{F}\right)dV $$
Arguments:
- `weight`: weight for this loss function
EnergyAndResidualLoss Objects
class EnergyAndResidualLoss(PhysicsLossFunction)
Energy and residual loss function used in Hamel et al.
Calculates the following quantity
$$ \mathcal{L} = w_1\Pi\left[u\right] + w_2\delta\Pi\left[u\right]_{free} $$
Arguments:
- `energy_weight`: weight for the energy, $w_1$
- `residual_weight`: weight for the residual, $w_2$
pancax.loss_functions.ic_loss_function
pancax.loss_functions.data_loss_functions
pancax.loss_functions.strong_form_loss_functions
pancax.loss_functions.utils
pancax.loss_functions.base_loss_function
BaseLossFunction Objects
class BaseLossFunction(eqx.Module)
Base class for loss functions. Currently does nothing but helps build a type hierarchy.
BCLossFunction Objects
class BCLossFunction(BaseLossFunction)
Base class for boundary condition loss functions.
A `load_step` method is expected with the following type signature: `load_step(self, params, domain, t)`
PhysicsLossFunction Objects
class PhysicsLossFunction(BaseLossFunction)
Base class for physics loss functions.
A `load_step` method is expected with the following type signature: `load_step(self, params, domain, t)`
pancax.loss_functions
pancax.optimizers.base
pancax.optimizers.adam
pancax.optimizers.utils
pancax.optimizers
pancax.optimizers.lbfgs
pancax.fem.traction_bc
pancax.fem.sparse_matrix_assembler
pancax.fem.elements.tet4_element
pancax.fem.elements.quad9_element
pancax.fem.elements.quad4_element
pancax.fem.elements
pancax.fem.elements.line_element
pancax.fem.elements.hex8_element
pancax.fem.elements.base_element
ShapeFunctions Objects
class ShapeFunctions(NamedTuple)
Shape functions and shape function gradients (in the parametric space).
Arguments:
- `values`: Values of the shape functions at a discrete set of points. Shape is `(nPts, nNodes)`, where `nPts` is the number of points at which the shape functions are evaluated, and `nNodes` is the number of nodes in the element (which is equal to the number of shape functions).
- `gradients`: Values of the parametric gradients of the shape functions. Shape is `(nPts, nDim, nNodes)`, where `nDim` is the number of spatial dimensions. Line elements are an exception, which have shape `(nPts, nNodes)`.
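To make the `values`/`gradients` layout concrete, here is a NumPy sketch that evaluates linear (3-node) triangle shape functions on the reference triangle (0, 0)-(1, 0)-(0, 1). The helper `tri3_shapes` and its local `ShapeFunctions` tuple are hypothetical stand-ins that follow the shapes documented above; they are not pancax's `simplex_tri_element`.

```python
from typing import NamedTuple

import numpy as np


class ShapeFunctions(NamedTuple):
    values: np.ndarray     # (nPts, nNodes)
    gradients: np.ndarray  # (nPts, nDim, nNodes)


def tri3_shapes(points: np.ndarray) -> ShapeFunctions:
    """Linear triangle: N0 = 1 - xi - eta, N1 = xi, N2 = eta."""
    xi, eta = points[:, 0], points[:, 1]
    values = np.stack([1.0 - xi - eta, xi, eta], axis=1)
    # Gradients are constant for a linear element; rows are d/dxi and d/deta
    grad = np.array([[-1.0, 1.0, 0.0],
                     [-1.0, 0.0, 1.0]])
    gradients = np.broadcast_to(grad, (points.shape[0], 2, 3)).copy()
    return ShapeFunctions(values, gradients)


pts = np.array([[1.0 / 3.0, 1.0 / 3.0], [0.5, 0.0], [0.0, 0.0]])
sf = tri3_shapes(pts)
```

The values at each point sum to one (partition of unity), and correspondingly the gradients sum to zero over the nodes.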
BaseElement Objects
class BaseElement(eqx.Module)
Base class for different element technologies
Arguments:
- `elementType`: Element type name
- `degree`: Polynomial degree
- `coordinates`: Nodal coordinates in the reference configuration
- `vertexNodes`: Vertex node numbers, 0-based
- `faceNodes`: Nodes associated with each face, 0-based
- `interiorNodes`: Nodes in the interior, 0-based or empty
compute_shapes
@abstractmethod
def compute_shapes(nodalPoints, evaluationPoints)
Method to be defined to calculate shape function values and gradients given a list of nodal points (usually the vertexNodes) and a list of evaluation points (usually the quadrature points).
pancax.fem.elements.tet10_element
pancax.fem.elements.simplex_tri_element
pancax.fem.read_exodus_mesh
read_exodus_mesh
def read_exodus_mesh(fileName: str)
Arguments:
- `fileName`: file name of exodus mesh to read
Returns:
A mesh object
pancax.fem.function_space
NonAllocatedFunctionSpace Objects
class NonAllocatedFunctionSpace(eqx.Module)
compute_field_gradient
def compute_field_gradient(u, X)
Takes in element-level coordinates X and field u and computes the gradient of u.
evaluate_on_element
def evaluate_on_element(U, X, state, dt, props, func)
Takes in element level field, coordinates, states, etc. and evaluates the function func
FunctionSpace Objects
class FunctionSpace(eqx.Module)
Data needed for calculus on functions in the discrete function space.
In describing the shape of the attributes, `ne` is the number of elements in the mesh, `nqpe` is the number of quadrature points per element, `npe` is the number of nodes per element, and `nd` is the spatial dimension of the domain.
Arguments:
- `shapes`: Shape function values on each element, shape `(ne, nqpe, npe)`
- `vols`: Volume attributed to each quadrature point. That is, the quadrature weight (on the parametric element domain) multiplied by the Jacobian determinant of the map from the parent element to the element in the domain. Shape `(ne, nqpe)`.
- `shapeGrads`: Derivatives of the shape functions with respect to the spatial coordinates of the domain. Shape `(ne, nqpe, npe, nd)`.
- `mesh`: The `Mesh` object of the domain.
- `quadratureRule`: The `QuadratureRule` on which to sample the shape functions.
- `isAxisymmetric`: boolean indicating if the function space data are axisymmetric.
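How `vols` and `shapeGrads` relate to the parent element can be sketched for a single linear triangle in NumPy: the Jacobian $J_{ij} = \partial X_i/\partial\xi_j$ maps parametric gradients to spatial ones, and its determinant scales the quadrature weight. The one-point centroid rule (weight 1/2, the reference triangle area) and the function name `element_vols_and_grads` are assumptions of this sketch, not pancax internals.

```python
import numpy as np

# Parametric gradients of the linear triangle shape functions, (nDim, nNodes)
DN_DXI = np.array([[-1.0, 1.0, 0.0],
                   [-1.0, 0.0, 1.0]])


def element_vols_and_grads(coords: np.ndarray, weights: np.ndarray):
    """coords: (npe, nd) nodal coordinates; weights: (nqpe,) parametric weights."""
    # J[i, j] = dX_i / dxi_j, constant for an affine (linear) element
    J = coords.T @ DN_DXI.T
    detJ = np.linalg.det(J)
    vols = weights * detJ                              # (nqpe,)
    # Spatial gradients dN/dX = J^{-T} dN/dxi, stored as (nqpe, npe, nd)
    dN_dX = np.linalg.solve(J.T, DN_DXI).T             # (npe, nd)
    shape_grads = np.broadcast_to(dN_dX, (weights.size,) + dN_dX.shape)
    return vols, shape_grads


X = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 1.0]])    # triangle with area 1
vols, shape_grads = element_vols_and_grads(X, np.array([0.5]))

# Gradient of the nodal field u = 3x - 2y, recovered exactly by linear shapes
u = 3.0 * X[:, 0] - 2.0 * X[:, 1]
grad_u = u @ shape_grads[0]
```

Summing `vols` over an element recovers its area, which is a quick sanity check on the determinant.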
construct_function_space
def construct_function_space(mesh, quadratureRule, mode2D='cartesian')
Construct a discrete function space.
Arguments:
- `mesh`: The mesh of the domain.
- `quadratureRule`: The quadrature rule to be used for integrating on the domain.
- `mode2D`: A string indicating how the 2D domain is interpreted for integration. Valid values are `cartesian` and `axisymmetric`. Axisymmetric mode will include the factor of $2\pi r$ in the `vols` attribute.
Returns:
The `FunctionSpace` object.
construct_function_space_from_parent_element
def construct_function_space_from_parent_element(mesh,
shapeOnRef,
quadratureRule,
mode2D='cartesian')
Construct a function space with precomputed shape function data on the parent element.
This version of the function space constructor is Jax-transformable, and in particular can be jitted. The computation of the shape function values and derivatives on the parent element is not transformable in general. However, the mapping of the shape function data to the elements in the mesh is transformable. One can precompute the parent element shape functions once and for all, and then use this special factory function to construct the function space and avoid the non-transformable part of the operation. The primary use case is for shape sensitivities: the coordinates of the mesh change, and we want Jax to pick up the sensitivities of the shape function derivatives in space to the coordinate changes (which occurs through the mapping from the parent element to the spatial domain).
Arguments:
- `mesh`: The mesh of the domain.
- `shapeOnRef`: A tuple of the shape function values and gradients on the parent element, evaluated at the quadrature points. The caller must take care to ensure the shape functions are evaluated at the same points as contained in the `quadratureRule` parameter.
- `quadratureRule`: The quadrature rule to be used for integrating on the domain.
- `mode2D`: A string indicating how the 2D domain is interpreted for integration. See the default factory function for details.
Returns:
The `FunctionSpace` object.
integrate_over_block
def integrate_over_block(
functionSpace,
U,
X,
stateVars,
props,
dt,
func,
block,
*params,
modify_element_gradient=default_modify_element_gradient)
Integrates a density function over a block of the mesh.
Arguments:
- `functionSpace`: Function space object to do the integration with.
- `U`: The vector of dofs for the primal field in the functional.
- `X`: Nodal coordinates
- `stateVars`: Internal state variable array.
- `dt`: Current time increment
- `func`: Lagrangian density function to integrate. Must have the signature `func(u, dudx, q, x, *params) -> scalar`, where `u` is the primal field, `q` is the value of the internal variables, `x` is the current point coordinates, and `*params` is a variadic set of additional parameters, which correspond to the `*params` argument.
- `block`: Group of elements to integrate over. This is an array of element indices. For performance, the elements within the block should be numbered consecutively.
- `modify_element_gradient`: Optional function that modifies the gradient at the element level. This can be to set the particular 2D mode, and additionally to enforce volume averaging on the gradient operator. This is a keyword-only argument.
Returns
A scalar value for the integral of the density functional `func` integrated over the block of elements.
evaluate_on_block
def evaluate_on_block(functionSpace,
U,
X,
stateVars,
dt,
props,
func,
block,
*params,
modify_element_gradient=default_modify_element_gradient)
Evaluates a density function at every quadrature point in a block of the mesh.
Arguments:
- `functionSpace`: Function space object to do the evaluation with.
- `U`: The vector of dofs for the primal field in the functional.
- `X`: Nodal coordinates
- `stateVars`: Internal state variable array.
- `dt`: Current time increment
- `func`: Lagrangian density function to evaluate. Must have the signature `func(u, dudx, q, x, *params) -> scalar`, where `u` is the primal field, `q` is the value of the internal variables, `x` is the current point coordinates, and `*params` is a variadic set of additional parameters, which correspond to the `*params` argument.
- `block`: Group of elements to evaluate over. This is an array of element indices. For performance, the elements within the block should be numbered consecutively.
- `modify_element_gradient`: Optional function that modifies the gradient at the element level. This can be to set the particular 2D mode, and additionally to enforce volume averaging on the gradient operator. This is a keyword-only argument.
Returns
An array of shape `(numElements, numQuadPtsPerElement)` that contains the scalar values of the density functional `func` at every quadrature point in the block.
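The relationship between evaluating a density at every quadrature point and integrating it (a `vols`-weighted sum of those values) can be sketched on a 1D mesh of linear line elements in NumPy. Everything here, including the two-point Gauss rule, the simplified density signature `density(u, dudx, x)`, and helpers that also return `vols` for convenience, is a hypothetical stand-in for the pancax machinery. For $u = 2x$ and density $\frac{1}{2}(u')^2 + ux$ the exact integral over $[0, 1]$ is $2 + 2/3 = 8/3$.

```python
import numpy as np

# Two-point Gauss-Legendre rule mapped to the parent interval [0, 1]
xi = 0.5 + 0.5 * np.array([-1.0, 1.0]) / np.sqrt(3.0)
w = np.array([0.5, 0.5])

nodes = np.linspace(0.0, 1.0, 5)                  # 4 linear line elements
conns = np.stack([np.arange(4), np.arange(1, 5)], axis=1)
U = 2.0 * nodes                                    # nodal values of u = 2x


def density(u, dudx, x):
    # Hypothetical Lagrangian density: 0.5 * (du/dx)^2 + u * x
    return 0.5 * dudx**2 + u * x


def evaluate_on_block(U, nodes, conns, func):
    """Density at every quadrature point, shape (numElements, numQuadPtsPerElement)."""
    X_e, U_e = nodes[conns], U[conns]              # (ne, 2) each
    h = X_e[:, 1] - X_e[:, 0]                      # element lengths (map Jacobian)
    N = np.stack([1.0 - xi, xi], axis=1)           # (nq, 2) linear shape functions
    x_q = X_e @ N.T                                # (ne, nq) point coordinates
    u_q = U_e @ N.T
    dudx_q = ((U_e[:, 1] - U_e[:, 0]) / h)[:, None] * np.ones_like(xi)
    vols = h[:, None] * w[None, :]                 # weight times Jacobian
    return func(u_q, dudx_q, x_q), vols


def integrate_over_block(U, nodes, conns, func):
    values, vols = evaluate_on_block(U, nodes, conns, func)
    return float(np.sum(values * vols))


values, vols = evaluate_on_block(U, nodes, conns, density)
total = integrate_over_block(U, nodes, conns, density)
```

The integrand is quadratic, so the two-point rule integrates it exactly.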
integrate_element_from_local_field
def integrate_element_from_local_field(
elemNodalField,
elemNodalCoords,
elemStates,
dt,
elemShapes,
elemShapeGrads,
elemVols,
func,
modify_element_gradient=default_modify_element_gradient)
Integrates over an element with the element nodal field as input. This allows element residuals and element stiffness matrices to be computed.
get_nodal_values_on_edge
def get_nodal_values_on_edge(functionSpace, nodalField, edge)
Get nodal values of a field on an element edge.
Arguments:
- `functionSpace`: a FunctionSpace object
- `nodalField`: The nodal vector defined over the mesh (shape is number of nodes by number of field components)
- `edge`: tuple containing the element number containing the edge and the permutation (0, 1, or 2) of the edge within the triangle
interpolate_nodal_field_on_edge
def interpolate_nodal_field_on_edge(functionSpace, U, interpolationPoints,
edge)
Interpolate a nodal field to specified points on an element edge.
Arguments:
- `functionSpace`: a FunctionSpace object
- `U`: the nodal values array
- `interpolationPoints`: coordinates of points (in the 1D parametric space) to interpolate to
- `edge`: tuple containing the element number containing the edge and the permutation (0, 1, or 2) of the edge within the triangle
pancax.fem
pancax.fem.surface
pancax.fem.dof_manager
DofManager Objects
class DofManager()
Collection of arrays needed to differentiate between fixed and free dofs for fem like calculations.
TODO: better document the parameters of this class.
__init__
def __init__(mesh, dim: int, DirichletBCs: List[DirichletBC]) -> None
Arguments:
- `functionSpace`: `FunctionSpace` object
- `dim`: The number of dims (really the number of active dofs for the physics)
- `DirichletBCs`: A list of `DirichletBC` objects
get_bc_size
def get_bc_size() -> int
Returns:
the number of fixed dofs
get_unknown_size
def get_unknown_size() -> int
Returns:
the size of the unknowns vector
create_field
def create_field(Uu, Ubc=0.0) -> Float[Array, "nn nd"]
Arguments:
- `Uu`: Vector of unknown values
- `Ubc`: Values for bc to apply
Returns:
U, a field of unknowns and bcs combined.
get_bc_values
def get_bc_values(U) -> Float[Array, "nb"]
Arguments:
- `U`: a nodal field
Returns:
the bc values in the field U
get_unknown_values
def get_unknown_values(U) -> Float[Array, "nu"]
Arguments:
- `U`: a nodal field
Returns:
the unknown values in the field U
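The bookkeeping behind `create_field`, `get_bc_values`, and `get_unknown_values` can be sketched with a boolean mask over an `(nn, nd)` nodal array. The class `SimpleDofManager` and its `is_bc`/`is_unknown` attributes are hypothetical NumPy stand-ins, not pancax's `DofManager`, which is built from a mesh and a list of `DirichletBC` objects.

```python
import numpy as np


class SimpleDofManager:
    """Hypothetical minimal dof manager: splits an (nn, nd) field into fixed and free dofs."""

    def __init__(self, num_nodes: int, dim: int, bc_nodes, bc_components):
        self.field_shape = (num_nodes, dim)
        self.is_bc = np.zeros(self.field_shape, dtype=bool)
        self.is_bc[np.asarray(bc_nodes), np.asarray(bc_components)] = True
        self.is_unknown = ~self.is_bc

    def get_bc_size(self) -> int:
        return int(self.is_bc.sum())

    def get_unknown_size(self) -> int:
        return int(self.is_unknown.sum())

    def create_field(self, Uu, Ubc=0.0):
        U = np.zeros(self.field_shape)
        U[self.is_unknown] = Uu   # scatter free dofs (row-major order)
        U[self.is_bc] = Ubc       # scatter prescribed values
        return U

    def get_bc_values(self, U):
        return U[self.is_bc]

    def get_unknown_values(self, U):
        return U[self.is_unknown]


# 3 nodes in 2D; fix the x-component of node 0 and both components of node 2
dofs = SimpleDofManager(3, 2, bc_nodes=[0, 2, 2], bc_components=[0, 0, 1])
Uu = np.array([1.0, 2.0, 3.0])
U = dofs.create_field(Uu, Ubc=np.array([0.0, 0.5, -0.5]))
```

In a jitted JAX setting, boolean-mask indexing is not shape-static, so one would typically precompute integer indices (for example with `np.flatnonzero`) and scatter with `U.at[...].set(...)`.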
pancax.fem.quadrature_rules
QuadratureRule Objects
class QuadratureRule(eqx.Module)
Quadrature rule points and weights.
A `namedtuple` containing `xigauss`, a numpy array of the coordinates of the sample points in the reference domain, and `wgauss`, a numpy array with the weights.
Arguments:
- `xigauss`: coordinates of gauss points in reference element
- `wgauss`: weights of gauss points in reference element
create_quadrature_rule_1D
def create_quadrature_rule_1D(degree: int) -> QuadratureRule
Creates a Gauss-Legendre quadrature on the unit interval.
The rule can exactly integrate polynomials of degree up to `degree`.
Parameters
degree: Highest degree polynomial to be exactly integrated by the quadrature rule
Returns
A `QuadratureRule` named tuple containing the quadrature point coordinates and the weights.
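A NumPy sketch of the same idea, using `numpy.polynomial.legendre.leggauss` on $[-1, 1]$ and mapping to the unit interval. An $n$-point Gauss-Legendre rule is exact up to degree $2n - 1$, which fixes the point count; whether pancax builds its rule this way is an assumption, and the function name is hypothetical.

```python
import math

import numpy as np


def gauss_legendre_unit_interval(degree: int):
    """Points and weights on [0, 1] exact for polynomials up to `degree`.

    An n-point Gauss-Legendre rule is exact up to degree 2n - 1, so we use
    n = ceil((degree + 1) / 2) points.
    """
    n = max(1, math.ceil((degree + 1) / 2))
    xi, w = np.polynomial.legendre.leggauss(n)   # rule on [-1, 1]
    return 0.5 * (xi + 1.0), 0.5 * w             # affine map to [0, 1]


points, weights = gauss_legendre_unit_interval(5)
approx = np.sum(weights * points**5)   # exact value is 1/6
```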
create_quadrature_rule_on_quad
def create_quadrature_rule_on_quad(quad_degree: int) -> QuadratureRule
Arguments:
- `quad_degree`: degree of quadrature rule to create
create_quadrature_rule_on_triangle
def create_quadrature_rule_on_triangle(degree: int) -> QuadratureRule
Creates a Gauss-Legendre quadrature on the unit triangle.
The rule can exactly integrate 2D polynomials up to the value of `degree`. The domain is the triangle between the vertices (0, 0)-(1, 0)-(0, 1). The rules here are guaranteed to be cyclically symmetric in triangular coordinates and to have strictly positive weights.
Parameters
degree: Highest degree polynomial to be exactly integrated by the quadrature rule
Returns
A `QuadratureRule` named tuple containing the quadrature point coordinates and the weights.
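For concreteness, here is one well-known rule with these properties, the classic three-point rule that is exact for degree 2, written in NumPy with the same `xigauss`/`wgauss` naming. Its points are the cyclic permutations of the barycentric coordinates (2/3, 1/6, 1/6) and its weights are equal and positive. Whether pancax returns exactly this rule for `degree=2` is an assumption.

```python
import numpy as np

# Three-point rule on the triangle (0, 0)-(1, 0)-(0, 1); exact for degree <= 2.
# Weights sum to 1/2, the area of the reference triangle.
xigauss = np.array([[1.0 / 6.0, 1.0 / 6.0],
                    [2.0 / 3.0, 1.0 / 6.0],
                    [1.0 / 6.0, 2.0 / 3.0]])
wgauss = np.full(3, 1.0 / 6.0)


def integrate_on_triangle(f):
    """Apply the rule to a function f(x, y) defined on the reference triangle."""
    return float(np.sum(wgauss * f(xigauss[:, 0], xigauss[:, 1])))


int_x2 = integrate_on_triangle(lambda x, y: x**2)   # exact: 1/12
int_xy = integrate_on_triangle(lambda x, y: x * y)  # exact: 1/24
```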
create_padded_quadrature_rule_1D
def create_padded_quadrature_rule_1D(degree)
Creates 1D Gauss quadrature rule data that are padded to maintain a uniform size, which makes this function jit-able.
This function is intended to be used only when jit compilation of calls to the quadrature rules is needed. Otherwise, prefer the standard quadrature rules. The standard rules do not contain extra 0s for padding, which makes them more efficient when used repeatedly (such as in the global energy).
Arguments:
- `degree`: degree of highest polynomial to be integrated exactly
pancax.fem.mesh
Mesh Objects
class Mesh(eqx.Module)
Triangle mesh representing a domain.
Arguments:
- `coords`: Coordinates of the nodes, shape `(nNodes, nDim)`.
- `conns`: Nodal connectivity table of the elements.
- `simplexNodesOrdinals`: Indices of the nodes that are vertices.
- `parentElement`: A `ParentElement` that is the element type in parametric space. A mesh can contain only 1 element type.
- `parentElement1d`:
- `blocks`: A dictionary mapping element block names to the indices of the elements in the block.
- `nodeSets`: A dictionary mapping node set names to the indices of the nodes.
- `sideSets`: A dictionary mapping side set names to the edges. The edge data structure is a tuple of the element index and the local number of the edge within that element. For example, triangle elements will have edge 0, 1, or 2 for this entry.
num_dimensions
@property
def num_dimensions() -> int
Returns:
dimension number of mesh
num_elements
@property
def num_elements() -> int
Returns:
number of elements in mesh
num_nodes
@property
def num_nodes() -> int
Returns:
number of nodes in mesh
create_edges
def create_edges(conns)
Generate topological information about edges in a triangulation.
Parameters
conns : (nTriangles, 3) array
Connectivity table of the triangulation.
Returns
edgeConns : (nEdges, 2) array
Vertices of each edge. Boundary edges are always in the counter-clockwise sense, so that the interior of the body is on the left side when walking from the first vertex to the second.
edges : (nEdges, 4) array
Edge-to-triangle topological information. Each row provides the following information for each edge: [leftT, leftP, rightT, rightP], where leftT is the ID of the triangle to the left, leftP is the permutation of the edge in the left triangle (edge 0, 1, or 2), rightT is the ID of the triangle to the right, and rightP is the permutation of the edge in the right triangle. If the edge is a boundary edge, the values of rightT and rightP are -1.
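The outputs described above can be reproduced with a dictionary keyed on sorted vertex pairs. This is a hypothetical NumPy sketch, not pancax's implementation, and it assumes (labeled in the code) that edge `p` of a triangle runs from local vertex `p` to `(p + 1) % 3` and that every triangle is counter-clockwise. Under those assumptions the first triangle to visit an edge lies to its left, which also gives boundary edges their counter-clockwise orientation.

```python
import numpy as np


def create_edges(conns: np.ndarray):
    """Edge topology for a triangulation with counter-clockwise triangles.

    Assumption of this sketch: edge p of a triangle runs from local vertex p
    to local vertex (p + 1) % 3.
    """
    edge_index = {}
    edge_conns, edges = [], []
    for t, tri in enumerate(conns):
        for p in range(3):
            a, b = int(tri[p]), int(tri[(p + 1) % 3])
            key = (min(a, b), max(a, b))
            if key not in edge_index:
                # The first triangle to see the edge lies on its left, because
                # it traverses its own edges counter-clockwise
                edge_index[key] = len(edges)
                edge_conns.append([a, b])
                edges.append([t, p, -1, -1])
            else:
                e = edge_index[key]
                edges[e][2], edges[e][3] = t, p
    return np.array(edge_conns), np.array(edges)


# Unit square split into two counter-clockwise triangles
conns = np.array([[0, 1, 2], [0, 2, 3]])
edge_conns, edges = create_edges(conns)
```

The two triangles share one interior edge, leaving four boundary edges whose `rightT` and `rightP` entries are -1.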
get_edge_field
def get_edge_field(mesh: Mesh, edge, field)
Evaluate field on nodes of an element edge.
Arguments:
- `mesh`: a Mesh object
- `edge`: tuple containing the element number containing the edge and the permutation (0, 1, or 2) of the edge within the triangle
compute_edge_vectors
def compute_edge_vectors(mesh: Mesh, edgeCoords)
Get geometric vectors for an element edge.
Assumes that the edge has a constant shape Jacobian, that is, the transformation from the parent element is affine.
Arguments:
- `mesh`: a Mesh object
- `edgeCoords`: coordinates of all nodes on the edge, in the order defined by the 1D parent element convention
Returns:
A tuple (t, n, j) with
- `t`: the unit tangent vector
pancax.trainer
pancax.__main__
pancax.logging
pancax.math.tensor_math
sqrtm_dbp
def sqrtm_dbp(A)
Matrix square root by product form of Denman-Beavers iteration.
Translated from the Matrix Function Toolbox (http://www.ma.man.ac.uk/~higham/mftoolbox). See Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7.
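A NumPy sketch of the unscaled product-form Denman-Beavers iteration, $M_{k+1} = \frac{1}{2}\left(I + \frac{M_k + M_k^{-1}}{2}\right)$ and $Y_{k+1} = Y_k\left(I + M_k^{-1}\right)/2$ with $M_0 = Y_0 = A$, where $Y_k \to A^{1/2}$ and $M_k \to I$. The stopping test on $\|M_k - I\|$ and the iteration cap are assumptions of this sketch; the toolbox implementation may additionally apply scaling to speed up convergence.

```python
import numpy as np


def sqrtm_dbp(A: np.ndarray, tol: float = 1e-12, max_iters: int = 100) -> np.ndarray:
    """Matrix square root by the (unscaled) product form of Denman-Beavers."""
    eye = np.eye(A.shape[0])
    M = A.copy()
    Y = A.copy()
    for _ in range(max_iters):
        M_inv = np.linalg.inv(M)
        # Both updates use the previous iterate M_k
        Y = Y @ (eye + M_inv) / 2.0
        M = 0.5 * (eye + 0.5 * (M + M_inv))
        if np.linalg.norm(M - eye, ord="fro") < tol:
            break
    return Y


A = np.array([[4.0, 1.0], [1.0, 3.0]])
S = sqrtm_dbp(A)
```

For a symmetric positive definite input every iterate is a rational function of `A`, so the result stays symmetric.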
log_pade_pf
def log_pade_pf(A, n)
Logarithmic map by Padé approximant and partial fractions
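A NumPy sketch of the partial-fraction idea: $\log(I + X) = \int_0^1 X\left(I + sX\right)^{-1}ds$, and applying $n$-point Gauss-Legendre quadrature to this integral yields the $[n/n]$ Padé approximant in partial-fraction form. This sketch approximates $\log(I + X)$, as in Higham's toolbox; whether pancax's `log_pade_pf(A, n)` expects $X$ or $I + X$ is not stated here, so treat the argument convention as an assumption. The approximation is accurate only for $X$ of small norm.

```python
import numpy as np


def log_pade_pf(X: np.ndarray, n: int) -> np.ndarray:
    """[n/n] Pade approximant of log(I + X) in partial-fraction form."""
    s, w = np.polynomial.legendre.leggauss(n)
    s, w = 0.5 * (s + 1.0), 0.5 * w          # map nodes and weights to [0, 1]
    eye = np.eye(X.shape[0])
    result = np.zeros_like(X)
    for s_j, w_j in zip(s, w):
        # (I + s_j X)^{-1} X equals X (I + s_j X)^{-1}, since the factors commute
        result += w_j * np.linalg.solve(eye + s_j * X, X)
    return result


X = np.array([[0.10, 0.02], [0.00, -0.05]])
L = log_pade_pf(X, 8)
```

Because `I + X` is upper triangular here, the exact logarithm is known in closed form, which makes the sketch easy to check.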
pancax.math
pancax.math.math
sum2
def sum2(a)
Sum a vector to much higher accuracy than numpy.sum.
Parameters
a : ndarray, with only one axis (shape [n,])
Returns
sum : real
The sum of the numbers in the array.
This special sum method computes the result as accurately as if it were computed in twice the working precision (roughly quadruple precision for double-precision inputs).
Reference: T. Ogita, S. M. Rump, and S. Oishi. Accurate sum and dot product. SIAM J. Sci. Comput., Vol 26, No 6, pp. 1955-1988. doi: 10.1137/030601818
dot2
def dot2(x, y)
Compute inner product of 2 vectors to much higher accuracy than numpy.dot.
Arguments:
- `x`: ndarray, with only one axis (shape [n,])
- `y`: ndarray, with only one axis (shape [n,])
Returns:
- `dotprod`: real. The inner product of the input vectors.
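Both `sum2` and `dot2` rest on error-free transformations from the Ogita, Rump, and Oishi reference cited above: `two_sum` (Knuth), a Veltkamp `split`, and `two_product` (Dekker). The pure-Python sketch below works on Python floats (IEEE doubles) rather than arrays and is not pancax's implementation. The explicit loop for the naive sum is deliberate, because Python 3.12+ makes the built-in `sum` of floats compensated.

```python
def two_sum(a: float, b: float):
    """Error-free transformation: a + b = s + e exactly (Knuth)."""
    s = a + b
    bp = s - a
    e = (a - (s - bp)) + (b - bp)
    return s, e


def split(a: float):
    """Veltkamp split of a double into two halves with at most 26 significant bits."""
    c = 134217729.0 * a          # 2**27 + 1
    hi = c - (c - a)
    return hi, a - hi


def two_product(a: float, b: float):
    """Error-free transformation: a * b = p + e exactly (Dekker)."""
    p = a * b
    a_hi, a_lo = split(a)
    b_hi, b_lo = split(b)
    e = a_lo * b_lo - (((p - a_hi * b_hi) - a_lo * b_hi) - a_hi * b_lo)
    return p, e


def sum2(values):
    """Compensated summation (Sum2)."""
    s, err = 0.0, 0.0
    for v in values:
        s, e = two_sum(s, float(v))
        err += e
    return s + err


def dot2(x, y):
    """Compensated inner product (Dot2)."""
    p, s = two_product(float(x[0]), float(y[0]))
    for xi, yi in zip(x[1:], y[1:]):
        h, r = two_product(float(xi), float(yi))
        p, q = two_sum(p, h)
        s += q + r
    return p + s


naive = 0.0
for v in [1e16, 1.0, -1e16]:
    naive += v                                 # the 1.0 is lost to rounding
compensated = sum2([1e16, 1.0, -1e16])
dotted = dot2([1e8, 1.0, -1e8], [1e8, 1.0, 1e8])
```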
pancax.post_processor
VtkPostProcessor Objects
class VtkPostProcessor()
convert_exodus_to_xml
def convert_exodus_to_xml(exodus_file, xml_file)
Converts an Exodus II file to a VTK XML MultiBlock dataset.
get_vtk_cell_type
def get_vtk_cell_type(exodus_element_type)
Map Exodus element types to VTK cell types.
pancax.deprecated.properties
Properties Objects
class Properties(eqx.Module)
Arguments:
- `prop_mins`: Minimum allowable properties
- `prop_maxs`: Maximum allowable properties
- `prop_params`: Actual tunable parameters
__init__
def __init__(prop_mins: jax.Array,
prop_maxs: jax.Array,
key: jax.random.PRNGKey,
activation_func: Optional[Callable] = jax.nn.sigmoid) -> None
Arguments:
- `prop_mins`: Minimum allowable properties
- `prop_maxs`: Maximum allowable properties
- `key`: rng key
__call__
def __call__() -> Float[Array, "np"]
Returns:
Predicted properties
FixedProperties Objects
class FixedProperties(Properties)
__init__
def __init__(props: jax.Array) -> None
Arguments:
- `props`: Property values to be fixed
pancax.deprecated.physics_loss_functions_strong_form
pancax.history_writer
pancax.networks.initialization
zero_init
def zero_init(key: jax.random.PRNGKey, shape) -> Float[Array, "no ni"]
Arguments:
- `weight`: current weight array for sizing
- `key`: rng key
Returns:
A new set of weights
trunc_init
def trunc_init(key: jax.random.PRNGKey, shape) -> Float[Array, "no ni"]
Arguments:
- `weight`: current weight array for sizing
- `key`: rng key
Returns:
A new set of weights
init_linear_weight
def init_linear_weight(model: eqx.Module, init_fn: Callable,
key: jax.random.PRNGKey) -> eqx.Module
Arguments:
- `model`: equinox model
- `init_fn`: function to initialize weight with
- `key`: rng key
Returns:
a new equinox model
init_linear
def init_linear(model: eqx.Module, init_fn: Callable, key: jax.random.PRNGKey)
Arguments:
- `model`: equinox model
- `init_fn`: function to initialize weight with
- `key`: rng key
Returns:
a new equinox model
pancax.networks.field_physics_pair
FieldPhysicsPair Objects
class FieldPhysicsPair(BasePancaxModel)
Data structure for storing a set of field network parameters and a physics object
Arguments:
- `fields`: field network parameters object
- `physics`: physics object
__iter__
def __iter__()
Iterator for user friendliness
pancax.networks.base
BasePancaxModel Objects
class BasePancaxModel(eqx.Module)
Base class for pancax model parameters.
This includes a few helper methods
pancax.networks.fields
pancax.networks.parameters
Parameters Objects
class Parameters(BasePancaxModel)
Data structure for storing all parameters needed for a model
Arguments:
- `field`: field network parameters object
- `physics`: physics object
- `state`: state object
pancax.networks.elm
pancax.networks.rbf
pancax.networks.mlp
Linear
def Linear(n_inputs: int, n_outputs: int, key: jax.random.PRNGKey)
Arguments:
- `n_inputs`: Number of inputs to linear layer
- `n_outputs`: Number of outputs of the linear layer
- `key`: rng key
Returns:
Equinox Linear layer
pancax.networks
pancax.networks.ml_dirichlet_field
pancax.timer
TimerError Objects
class TimerError(Exception)
A custom exception used to report errors in use of Timer class
Timer Objects
@dataclass
class Timer(ContextDecorator)
Time your code using a class, context manager, or decorator
__post_init__
def __post_init__() -> None
Initialization: add timer to dict of timers
start
def start() -> None
Start a new timer
stop
def stop() -> float
Stop the timer, and report the elapsed time
__enter__
def __enter__() -> "Timer"
Start a new timer as a context manager
__exit__
def __exit__(*exc_info: Any) -> None
Stop the context manager timer
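The class, context-manager, and decorator usage described above can be sketched with the standard library alone. This is a hypothetical re-implementation for illustration; the class-level `timers` registry and the error messages are assumptions rather than a copy of `pancax.timer`, and `stop` returns the elapsed time instead of printing it.

```python
import time
from contextlib import ContextDecorator
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, Optional


class TimerError(Exception):
    """Raised when a timer is started twice or stopped before starting."""


@dataclass
class Timer(ContextDecorator):
    timers: ClassVar[Dict[str, float]] = {}
    name: Optional[str] = None
    _start_time: Optional[float] = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        # Register named timers so totals accumulate across uses
        if self.name is not None:
            self.timers.setdefault(self.name, 0.0)

    def start(self) -> None:
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")
        self._start_time = time.perf_counter()

    def stop(self) -> float:
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")
        elapsed = time.perf_counter() - self._start_time
        self._start_time = None
        if self.name is not None:
            self.timers[self.name] += elapsed
        return elapsed

    def __enter__(self) -> "Timer":
        self.start()
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.stop()


with Timer("setup"):          # context manager
    sum(range(10_000))


@Timer("work")                # decorator (via ContextDecorator)
def work():
    return sum(range(10_000))


work()
```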
pancax.utils
pancax.data.full_field_data
FullFieldData Objects
class FullFieldData(eqx.Module)
Data structure to store full field data used as ground truth for output fields of a PINN when solving inverse problems.
Arguments:
- `inputs`: Data that serves as inputs to the PINN
- `outputs`: Data that serves as outputs of the PINN
- `n_time_steps`: Variable used for bookkeeping
pancax.data
pancax.data.global_data
GlobalData Objects
class GlobalData(eqx.Module)
Data structure that holds global data to be used as ground truth for some global field calculated from PINN outputs, used in inverse modeling training.
Arguments:
- `times`: A set of times used to compare to physics calculations
- `displacements`: Currently hardcoded to use a displacement-force curve. TODO
- `outputs`: Field used as ground truth, currently hardcoded essentially to a reaction force
- `n_nodes`: Bookkeeping variable for the number of nodes on the node set used to measure the global response
- `n_time_steps`: Bookkeeping variable
- `reaction_nodes`: Node set nodes where reaction forces are measured
- `reaction_dof`: Degree of freedom to use for the reaction force calculation
`times`: change to inputs?
pancax.problems.inverse_problem
pancax.problems.forward_problem
pancax.problems
pancax.bvps.biaxial_tension
pancax.bvps.simple_shear
pancax.bvps.uniaxial_tension
pancax.bvps
pancax.domains.base
pancax.domains.variational_domain
pancax.domains.delta_pinn_domain
pancax.domains
pancax.domains.collocation_domain
pancax.physics_kernels.laplace_beltrami
pancax.physics_kernels.base
nodal_pp
def nodal_pp(func, has_props=False, jit=True)
Arguments:
- `func`: Function to use for a nodal property output variable
- `has_props`: Whether or not this function needs properties
- `jit`: Whether or not to jit this function
BasePhysics Objects
class BasePhysics(eqx.Module)
- `var_name_to_method = field(default_factory=lambda: {})`
- `dirichlet_bc_func = lambda x, t, z: z`
- `x_mins = jnp.zeros(3)`
- `x_maxs = jnp.zeros(3)`
pancax.physics_kernels.beer_lambert_law
pancax.physics_kernels.poisson
pancax.physics_kernels.heat_equation
pancax.physics_kernels.burgers_equation
pancax.physics_kernels
pancax.physics_kernels.solid_mechanics
pancax.constitutive_models.properties
pancax.constitutive_models.base
pancax.constitutive_models.mechanics.base
MechanicsModel Objects
class MechanicsModel(ConstitutiveModel)
energy
@abstractmethod
def energy(grad_u: Tensor, *args) -> Scalar
This method returns the algorithmic strain energy density.
I1
def I1(grad_u: Tensor) -> Scalar
Calculates the first invariant
- grad_u: the displacement gradient
$$ I_1 = tr\left(\mathbf{F}^T\mathbf{F}\right) $$
I1_bar
def I1_bar(grad_u: Tensor) -> Scalar
Calculates the first distortional invariant
- grad_u: the displacement gradient
$$ \bar{I}_1 = J^{-2/3}tr\left(\mathbf{F}^T\mathbf{F}\right) $$
jacobian
def jacobian(grad_u: Tensor) -> Scalar
This simply calculates the Jacobian, but with guard rails that return nonsensical numbers if a non-positive Jacobian is encountered during training.
- grad_u: the displacement gradient
$$ J = det(\mathbf{F}) $$
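A NumPy sketch of these kinematic quantities, assuming $\mathbf{F} = \mathbf{I} + \nabla\mathbf{u}$ and omitting the guard rails for non-positive $J$ that the pancax `jacobian` adds. Two sanity checks follow from the definitions: a pure dilation changes $I_1$ but leaves $\bar{I}_1 = 3$, and simple shear keeps $J = 1$.

```python
import numpy as np


def deformation_gradient(grad_u: np.ndarray) -> np.ndarray:
    return np.eye(grad_u.shape[0]) + grad_u         # F = I + grad u


def I1(grad_u: np.ndarray) -> float:
    F = deformation_gradient(grad_u)
    return float(np.trace(F.T @ F))                  # tr(C) with C = F^T F


def jacobian(grad_u: np.ndarray) -> float:
    return float(np.linalg.det(deformation_gradient(grad_u)))


def I1_bar(grad_u: np.ndarray) -> float:
    return jacobian(grad_u) ** (-2.0 / 3.0) * I1(grad_u)


# Pure dilation F = 1.1 I: I1 = 3 * 1.21, but I1_bar = 3 (no distortion)
grad_u_dilation = 0.1 * np.eye(3)
# Simple shear with amount 0.2: J = 1 and I1 = 3 + 0.2**2
grad_u_shear = np.array([[0.0, 0.2, 0.0],
                         [0.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0]])
```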
pancax.constitutive_models.mechanics.hyperelasticity.neohookean
NeoHookean Objects
class NeoHookean(HyperelasticModel)
NeoHookean model with the following model form
$$
\psi(\mathbf{F}) = \frac{1}{2}K\left[\frac{1}{2}\left(J^2 - 1\right) - \ln J\right] + \frac{1}{2}G\left(\bar{I}_1 - 3\right) $$
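A NumPy sketch of this energy density (a stand-alone function, not the pancax class). It uses the volumetric term $\frac{1}{2}K\left[\frac{1}{2}\left(J^2 - 1\right) - \ln J\right]$, for which both the energy and its derivative with respect to $J$ vanish at $J = 1$, so the reference configuration is stress free. The moduli values are illustrative.

```python
import numpy as np


def neohookean_energy(grad_u: np.ndarray, K: float, G: float) -> float:
    """psi = 0.5 K [0.5 (J^2 - 1) - ln J] + 0.5 G (I1_bar - 3), with F = I + grad u."""
    F = np.eye(3) + grad_u
    J = np.linalg.det(F)
    I1_bar = J ** (-2.0 / 3.0) * np.trace(F.T @ F)
    W_vol = 0.5 * K * (0.5 * (J**2 - 1.0) - np.log(J))
    W_dev = 0.5 * G * (I1_bar - 3.0)
    return float(W_vol + W_dev)


K, G = 10.0, 1.0
psi_ref = neohookean_energy(np.zeros((3, 3)), K, G)   # reference state: zero energy
psi_shear = neohookean_energy(np.array([[0.0, 0.2, 0.0],
                                        [0.0, 0.0, 0.0],
                                        [0.0, 0.0, 0.0]]), K, G)
```

For simple shear of amount $\gamma$, $J = 1$ and $\bar{I}_1 = 3 + \gamma^2$, so the energy reduces to $\frac{1}{2}G\gamma^2$.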
pancax.constitutive_models.mechanics.hyperelasticity.gent
Gent Objects
class Gent(HyperelasticModel)
Gent model with the following model form
$$
\psi(\mathbf{F}) = \frac{1}{2}K\left[\frac{1}{2}\left(J^2 - 1\right) - \ln J\right] - \frac{1}{2}GJ_m\ln\left(1 - \frac{\bar{I}_1 - 3}{J_m}\right) $$
pancax.constitutive_models.mechanics.hyperelasticity.blatz_ko
pancax.constitutive_models.mechanics.hyperelasticity.hencky
pancax.constitutive_models.mechanics.hyperelasticity.swanson
Swanson Objects
class Swanson(HyperelasticModel)
Swanson model truncated to 4 parameters
$$
\psi(\mathbf{F}) = K\left(J\ln J - J + 1\right) + \frac{3}{2}A_1\left(\frac{\bar{I}_1}{3} - 1\right)^{P_1} + \frac{3}{2}C_1\left(\frac{\bar{I}_1}{3} - 1\right)^{R_1} $$
pancax.constitutive_models.mechanics.hyperelasticity
pancax.constitutive_models.mechanics
pancax.constitutive_models
pancax.bcs.dirichlet_bc
DirichletBC Objects
class DirichletBC(eqx.Module)
Arguments:
- `nodeSet`: A name for a nodeset in the mesh
- `component`: The dof to apply the Dirichlet BC to
- `function`: A function f(x, t) = u that gives the value to enforce on the (nodeset, component) of a field. This defaults to the zero function.