fromcollectionsimportnamedtuplefromoptimism.JaxConfigimport*fromoptimismimportSparseMatrixAssemblerfromoptimismimportFunctionSpacefromoptimismimportMeshfromoptimism.TensorMathimporttensor_2D_to_3DfromoptimismimportQuadratureRulefromoptimismimportInterpolantsfromtypingimportCallableimportequinoxaseqximportjax# TODO# eventually let's move to some kind of class hierarchy like below# that we can derive off of with shared behavior## normal python inheritance rules apply to equinox Modules# class BaseFunctions(eqx.Module):# compute_output_energy_densities_and_stresses: Callable# compute_initial_state: Callable# TODO further type below so Callable refelcts the actual called arguments and returns
class MechanicsFunctions(eqx.Module):
    """Container for the closures produced by ``create_mechanics_functions``.

    ``create_mechanics_functions`` constructs instances by passing its closures
    positionally, so the field order below is part of the interface and must
    not be changed.
    """
    # (U, stateVariables, props, dt=0.0) -> result of _compute_strain_energy
    compute_strain_energy: Callable
    # (U, stateVariables, props, dt=0.0) -> updated state variables
    compute_updated_internal_variables: Callable
    # (U, stateVariables, props, dt=0.0) -> per-element stiffnesses
    compute_element_stiffnesses: Callable
    # (U, stateVariables, props, dt=0.0) -> pointwise (energy density, stress) values
    compute_output_energy_densities_and_stresses: Callable
    # () -> initial state variables tiled over (n_elements, n_quadrature_points, ...)
    compute_initial_state: Callable
    integrated_material_qoi: Callable  # scalar material point QoI integrated over the domain
    compute_output_material_qoi: Callable  # array of scalar material point QoI computed at each quadrature point
class DynamicsFunctions(eqx.Module):
    """Container for the callables used by dynamics (time-integration) solves.

    NOTE(review): the constructor call is not visible here; if instances are
    built positionally, as MechanicsFunctions is, the field order is part of
    the interface — confirm before reordering.
    """
    # presumably the Newmark algorithmic energy (cf. compute_newmark_lagrangian) — confirm
    compute_algorithmic_energy: Callable
    compute_updated_internal_variables: Callable
    # presumably built on _compute_newmark_element_hessians — confirm
    compute_element_hessians: Callable
    compute_output_energy_densities_and_stresses: Callable
    compute_output_kinetic_energy: Callable
    compute_output_strain_energy: Callable
    compute_initial_state: Callable
    compute_element_masses: Callable  # not used for time integration, provided for convenience (spectral analysis, eg)
    # presumably the time-integrator predictor / corrector steps — confirm
    predict: Callable
    correct: Callable
# TODO once we map things to equinox classes, make these methods bound to the class
# TODO use jax.lax.cond below to make this jit safe
# apparently this is unjittable
def _compute_updated_internal_variables(functionSpace, U, states, props, dt, compute_state_new, modify_element_gradient):
    """Advance the internal (state) variables at every quadrature point.

    Args:
        functionSpace: function space supplying the mesh and quadrature data.
        U: nodal field, shape (n_nodes, n_dims).
        states: quadrature field of state variables, shape
            (n_els, n_quadrature_points, n_states).
        props: material properties. ``vmapPropValue`` decides whether they are
            mapped over quadrature points (axis 0) or broadcast (None).
        dt: time increment, broadcast to every quadrature point.
        compute_state_new: pointwise update called as
            ``compute_state_new(dispGrad, state, props, dt)``.
        modify_element_gradient: element-gradient transformation forwarded to
            ``FunctionSpace.compute_field_gradient``.

    Returns:
        The updated state array, reshaped to ``states.shape``.
    """
    # gradients -> (n_els, n_quadrature_points, n_dims, n_dims)
    gradients = FunctionSpace.compute_field_gradient(functionSpace, U, modify_element_gradient)
    nElems = gradients.shape[0]
    nQuadPts = gradients.shape[1]

    # Collapse the element and quadrature axes so a single vmap visits every point.
    flatGradients = gradients.reshape(nElems * nQuadPts, *gradients.shape[2:])
    flatStates = states.reshape(states.shape[0] * states.shape[1], *states.shape[2:])

    # 0 -> props vary per quadrature point; None -> the same props everywhere.
    propAxis = vmapPropValue(props)
    tiledProps = tile_props(props, nElems, nQuadPts)

    pointwiseUpdate = vmap(compute_state_new, (0, 0, propAxis, None))
    updated = pointwiseUpdate(flatGradients, flatStates, tiledProps, dt)
    return updated.reshape(states.shape)
# TODO add props
def _compute_updated_internal_variables_multi_block(functionSpace, U, states, props, dt, blockModels, modify_element_gradient):
    """Advance internal variables block by block for a multi-material mesh.

    Args:
        functionSpace: function space whose ``mesh.blocks[blockKey]`` gives the
            element ids of each block.
        U: nodal field.
        states: global state array indexed by element id; the first axes are
            (element, quadrature point).
        props: mapping from block key to that block's material properties.
        dt: time increment, broadcast to every quadrature point.
        blockModels: mapping from block key to a material model exposing
            ``compute_state_new(dispGrad, state, props, dt)``.
        modify_element_gradient: element-gradient transformation forwarded to
            ``FunctionSpace.compute_field_gradient``.

    Returns:
        A copy of ``states`` in which each block's leading
        ``blockStatesNew.shape[2]`` state slots are overwritten with the
        block's updated values.
    """
    dispGrads = FunctionSpace.compute_field_gradient(functionSpace, U, modify_element_gradient)
    statesNew = np.array(states)
    for blockKey in blockModels:
        elemIds = functionSpace.mesh.blocks[blockKey]
        blockDispGrads = dispGrads[elemIds]
        blockStates = states[elemIds]
        blockProps = props[blockKey]
        nBlockElems = blockDispGrads.shape[0]
        nQuadPts = blockDispGrads.shape[1]

        # 0 -> vmap props over all quadrature points, None -> broadcast them.
        prop_vmap_axes = vmapPropValue(blockProps)
        # Tile to THIS block's quadrature points. Using the global element count
        # (dispGrads.shape[0]) produced a vmap axis-size mismatch with the
        # block's raveled gradients whenever the mesh has more than one block.
        new_props = tile_props(blockProps, nBlockElems, nQuadPts)

        compute_state_new = blockModels[blockKey].compute_state_new
        dgQuadPointRavel = blockDispGrads.reshape(nBlockElems * nQuadPts, *blockDispGrads.shape[2:])
        stQuadPointRavel = blockStates.reshape(blockStates.shape[0] * blockStates.shape[1], -1)
        blockStatesNew = vmap(compute_state_new, (0, 0, prop_vmap_axes, None))(
            dgQuadPointRavel, stQuadPointRavel, new_props, dt
        ).reshape(blockStates.shape)
        statesNew = statesNew.at[elemIds, :, :blockStatesNew.shape[2]].set(blockStatesNew)
    return statesNew
def _compute_initial_state_multi_block(fs, blockModels):
    """Build the global initial-state array for a multi-material mesh.

    Every element stores the same number of state variables: the largest
    ``compute_initial_state().shape[0]`` over all block models, and at least 1.
    This keeps the global array rectangular, which makes vmapping easy. Blocks
    whose material needs fewer variables leave their trailing slots at zero.

    TODO(talamini1): Fix this so that every material only stores what it
    needs and doesn't waste memory.

    Returns:
        Array of shape (n_elements, n_quadrature_points, n_state_variables).
    """
    nQuadPts = len(fs.quadratureRule)

    # Width of the padded state array: max over blocks, floor of 1.
    widths = [1] + [blockModels[key].compute_initial_state().shape[0] for key in blockModels]
    nStateVars = max(widths)

    initialState = np.zeros((Mesh.num_elements(fs.mesh), nQuadPts, nStateVars))
    for key in blockModels:
        elemIds = fs.mesh.blocks[key]
        template = blockModels[key].compute_initial_state()
        blockInitial = np.tile(template, (elemIds.size, nQuadPts, 1))
        # Fill only the leading slots this material actually uses.
        initialState = initialState.at[elemIds, :, :blockInitial.shape[2]].set(blockInitial)
    return initialState
def create_mechanics_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
    """Build the closures for a single-material 2D mechanics problem.

    Args:
        functionSpace: the function space the problem is discretized on.
        mode2D: ``'plane strain'`` or ``'axisymmetric'``; selects how 2D
            displacement gradients are lifted to 3D.
        materialModel: material exposing ``compute_energy_density``,
            ``compute_state_new``, ``compute_initial_state`` and
            ``compute_material_qoi``.
        pressureProjectionDegree: if not None, gradients are first passed
            through ``volume_average_J_gradient_transformation`` using triangle
            shape functions of this degree evaluated at the quadrature points.

    Returns:
        A ``MechanicsFunctions`` instance. The internal-variable update, element
        stiffnesses, output energy densities/stresses and output material QoI
        are wrapped in ``jit``.

    Raises:
        ValueError: if ``mode2D`` is not a recognized mode. Previously a bare
            ``raise`` with no active exception produced a confusing
            ``RuntimeError: No active exception to reraise``.
    """
    fs = functionSpace

    grad_2D_to_3D = parse_2D_to_3D_gradient_transformation(mode2D)

    modify_element_gradient = grad_2D_to_3D
    if pressureProjectionDegree is not None:
        masterJ = Interpolants.make_master_tri_element(degree=pressureProjectionDegree)
        xigauss = functionSpace.quadratureRule.xigauss
        shapesJ = Interpolants.compute_shapes_on_tri(masterJ, xigauss)

        def modify_element_gradient(elemGrads, elemShapes, elemVols, elemNodalDisps, elemNodalCoords):
            # Volume-average J first, then apply the 2D -> 3D lift.
            elemGrads = volume_average_J_gradient_transformation(elemGrads, elemVols, shapesJ)
            return grad_2D_to_3D(elemGrads, elemShapes, elemVols, elemNodalDisps, elemNodalCoords)

    def compute_strain_energy(U, stateVariables, props, dt=0.0):
        return _compute_strain_energy(fs, U, stateVariables, props, dt,
                                      materialModel.compute_energy_density,
                                      modify_element_gradient)

    # TODO add props
    def compute_updated_internal_variables(U, stateVariables, props, dt=0.0):
        return _compute_updated_internal_variables(fs, U, stateVariables, props, dt,
                                                   materialModel.compute_state_new,
                                                   modify_element_gradient)

    def compute_element_stiffnesses(U, stateVariables, props, dt=0.0):
        return _compute_element_stiffnesses(U, stateVariables, props, dt, fs,
                                            materialModel.compute_energy_density,
                                            modify_element_gradient)

    output_lagrangian = strain_energy_density_to_lagrangian_density(materialModel.compute_energy_density)
    # Value and derivative w.r.t. argument 1 (the gradient) of the Lagrangian density.
    output_constitutive = value_and_grad(output_lagrangian, 1)

    def compute_output_energy_densities_and_stresses(U, stateVariables, props, dt=0.0):
        return FunctionSpace.evaluate_on_block(fs, U, stateVariables, props, dt,
                                               output_constitutive, slice(None),
                                               modify_element_gradient=modify_element_gradient)

    def compute_initial_state():
        shape = Mesh.num_elements(fs.mesh), len(fs.quadratureRule), 1
        return np.tile(materialModel.compute_initial_state(), shape)

    def lagrangian_qoi(U, gradU, Q, props, X, dt):
        return materialModel.compute_material_qoi(gradU, Q, props, dt)

    def integrated_material_qoi(U, stateVariables, props, dt=0.0):
        return FunctionSpace.integrate_over_block(fs, U, stateVariables, props, dt,
                                                  lagrangian_qoi, slice(None),
                                                  modify_element_gradient=modify_element_gradient)

    def compute_output_material_qoi(U, stateVariables, props, dt=0.0):
        return FunctionSpace.evaluate_on_block(fs, U, stateVariables, props, dt,
                                               lagrangian_qoi, slice(None),
                                               modify_element_gradient=modify_element_gradient)

    # Positional construction: order must match the MechanicsFunctions fields.
    return MechanicsFunctions(compute_strain_energy,
                              jit(compute_updated_internal_variables),
                              jit(compute_element_stiffnesses),
                              jit(compute_output_energy_densities_and_stresses),
                              compute_initial_state,
                              integrated_material_qoi,
                              jit(compute_output_material_qoi))
# TODO need to update this for props. Eigen won't work otherwise
# TODO need to update this for props. Eigen won't work otherwise
def _compute_element_masses(functionSpace, U, internals, props, density, modify_element_gradient):
    """Per-element matrices of a kinetic-energy-only Lagrangian density.

    Applies ``compute_element_stiffness_from_global_fields`` element by element
    (vmapped over elements), with a Lagrangian density that returns
    ``kinetic_energy_density(V, density)``. The time step argument is unused by
    that density, so 0.0 is passed.

    NOTE(review): the vmap here maps over elements (axis 0 of ``internals``),
    while ``tile_props`` is called with (n_els, n_quadrature_points) exactly as
    in the pointwise updates that vmap over n_els * n_quadrature_points. If
    props are per-point (axis 0), confirm the tiled leading dimension matches
    n_els here.
    """
    def kinetic_only_density(V, gradV, Q, props, X, dt):
        return kinetic_energy_density(V, density)

    # 0 -> map props along the element axis, None -> broadcast.
    propAxis = vmapPropValue(props)
    tiledProps = tile_props(props, internals.shape[0], internals.shape[1])

    perElement = vmap(compute_element_stiffness_from_global_fields,
                      (None, None, 0, propAxis, None, 0, 0, 0, 0, None, None))

    fs = functionSpace
    mesh = fs.mesh
    return perElement(U, mesh.coords, internals, tiledProps, 0.0,
                      mesh.conns, fs.shapes, fs.shapeGrads, fs.vols,
                      kinetic_only_density, modify_element_gradient)
def compute_newmark_lagrangian(functionSpace, U, UPredicted, internals, props, density, dt, newmarkBeta, strain_energy_density, modify_element_gradient):
    """Newmark algorithmic Lagrangian.

    Returns ``SE + KE``, where
      * ``KE`` is the integral of ``kinetic_energy_density(U - UPredicted, density)``,
        scaled by ``1 / (newmarkBeta * dt**2)``;
      * ``SE`` is the integral of the strain-energy Lagrangian density evaluated
        at ``U``, with ``modify_element_gradient`` applied to the gradients.
    """
    # We can't quite fuse these kernels because KE uses the velocity field and
    # the strain energy uses the displacements. If profiling suggests fusing
    # is beneficial, we could add the time derivative field to the Lagrangian
    # density definition.
    def kinetic_density(W, gradW, Q, props, X, dtime):
        return kinetic_energy_density(W, density)

    increment = U - UPredicted
    kineticTerm = FunctionSpace.integrate_over_block(functionSpace, increment, internals, props, dt,
                                                     kinetic_density, slice(None))
    kineticTerm *= 1 / (newmarkBeta * dt**2)

    strainDensity = strain_energy_density_to_lagrangian_density(strain_energy_density)
    strainTerm = FunctionSpace.integrate_over_block(functionSpace, U, internals, props, dt,
                                                    strainDensity, slice(None),
                                                    modify_element_gradient=modify_element_gradient)
    return strainTerm + kineticTerm
def _compute_newmark_element_hessians(functionSpace, U, UPredicted, internals, props, density, dt, newmarkBeta, strain_energy_density, modify_element_gradient):
    """Per-element matrices of the Newmark algorithmic Lagrangian density.

    The density combines the kinetic term, scaled by ``1 / (newmarkBeta * dtime**2)``,
    with ``strain_energy_density``. It is fed to
    ``compute_element_stiffness_from_global_fields``, vmapped over elements, on
    the field ``U - UPredicted``.

    NOTE(review): both terms are evaluated on ``U - UPredicted``, whereas
    ``compute_newmark_lagrangian`` evaluates the strain energy at ``U``. The two
    agree only when the strain energy is quadratic (constant Hessian). Confirm
    whether this is intentional, for example because it is only used as a
    preconditioner.
    """
    def algorithmic_density(W, gradW, Q, props, X, dtime):
        return kinetic_energy_density(W, density)/(newmarkBeta*dtime**2) + strain_energy_density(gradW, Q, props, dtime)

    # 0 -> map props along the element axis, None -> broadcast.
    propAxis = vmapPropValue(props)
    tiledProps = tile_props(props, internals.shape[0], internals.shape[1])

    perElement = vmap(compute_element_stiffness_from_global_fields,
                      (None, None, 0, propAxis, None, 0, 0, 0, 0, None, None))

    fs = functionSpace
    mesh = fs.mesh
    return perElement(U - UPredicted, mesh.coords, internals, tiledProps, dt,
                      mesh.conns, fs.shapes, fs.shapeGrads, fs.vols,
                      algorithmic_density, modify_element_gradient)
def parse_2D_to_3D_gradient_transformation(mode2D):
    """Return the gradient transformation that lifts 2D gradients to 3D.

    Args:
        mode2D: ``'plane strain'`` or ``'axisymmetric'``.

    Returns:
        ``plane_strain_gradient_transformation`` or
        ``axisymmetric_element_gradient_transformation`` respectively.

    Raises:
        ValueError: for any other value of ``mode2D``.
    """
    if mode2D == 'plane strain':
        return plane_strain_gradient_transformation
    if mode2D == 'axisymmetric':
        return axisymmetric_element_gradient_transformation
    raise ValueError("Unrecognized value for mode2D")
# TODO need to update this for props. Eigen won't work otherwise
def compute_traction_potential_energy(fs, U, quadRule, edges, load):
    """Compute potential energy of surface tractions.

    The integrand is ``-dot(u, load(X, n))``, i.e. the negative work of the
    applied traction. It is integrated along ``edges`` with ``quadRule``.

    Arguments:
        fs: a FunctionSpace object
        U: the nodal displacements
        quadRule: the 1D quadrature rule to use for the integration
        edges: array of edges, each row is an edge. Each edge has two entries,
            the element ID, and the permutation of that edge in the triangle
            (0, 1, 2).
        load: Callable that returns the traction vector. The signature is
            load(X, n), where X is coordinates of a material point, and n is
            the outward unit normal.

    Returns:
        The traction potential energy, as returned by
        ``FunctionSpace.integrate_function_on_edges``.
    """
    # (The previous docstring documented a nonexistent ``time`` argument.)
    def compute_energy_density(u, X, n):
        traction = load(X, n)
        return -np.dot(u, traction)
    return FunctionSpace.integrate_function_on_edges(fs, compute_energy_density, U, quadRule, edges)