Kernel-based Assembly (JIT-Friendly)
This tutorial shows how to assemble using explicit, JIT-compiled element kernels. The idea is to separate:
- element kernels (the per-element computation that gets JIT-compiled), and
- assembly (scattering element contributions into global vectors and matrices).
This makes the JIT boundary explicit, and the same kernel can be reused across assembly calls.
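The split between kernel and assembly can be illustrated in plain JAX. The sketch below is not FluxFEM's implementation, and every name in it is hypothetical: `element_load` is a toy JIT-compiled kernel that computes one element's contribution, `jax.vmap` maps it over all elements, and a scatter-add over a connectivity array `conn` assembles the global vector. Only the kernel sits behind the JIT boundary.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D mesh: 3 two-node elements on [0, 3], 4 global nodes.
conn = jnp.array([[0, 1], [1, 2], [2, 3]])  # element -> global DOF indices
lengths = jnp.array([1.0, 1.0, 1.0])        # element sizes h

@jax.jit
def element_load(h):
    # Element kernel: integrated load for f = 1 with linear shape functions.
    # Each of the two local DOFs receives h / 2.
    return (h / 2.0) * jnp.ones(2)

def assemble_vector(conn, lengths, n_dofs):
    # Evaluate the kernel on every element at once: shape (n_elems, 2).
    f_elem = jax.vmap(element_load)(lengths)
    # Scatter-add element contributions into the global vector.
    return jnp.zeros(n_dofs).at[conn.ravel()].add(f_elem.ravel())

F = assemble_vector(conn, lengths, n_dofs=4)
# F == [0.5, 1.0, 1.0, 0.5]: interior nodes collect from two elements.
```

The kernel never needs to know the global numbering. That is what lets the same compiled kernel serve any mesh with the same element type.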
Core idea
- make_element_*_kernel returns a JIT-compiled function that operates on a single element.
- You pass that kernel to the assembler through kernel= (space.assemble, space.assemble_residual, or space.assemble_jacobian).
- The kernel must return the integrated element contribution, not the integrand.
By contrast, a form (form(ctx, params)) returns the integrand at each quadrature point. The assembler then multiplies by wJ (quadrature weight times Jacobian determinant) and sums over the quadrature points.
If you pass an untagged raw kernel to space.assemble with kind=, FluxFEM emits a one-time warning; prefer tagging the kernel with @ff.kernel.
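The form/kernel distinction is easy to see on a toy example. This plain-JAX sketch uses a hypothetical `ToyCtx` holding quadrature weights `w`, Jacobian determinants `detJ`, and shape-function values `N`. FluxFEM's real context is different; it exposes the determinant as `ctx.test.detJ`, for example. The form returns one row per quadrature point. The kernel applies wJ and sums, which is the step the assembler performs for you when you pass a form.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyCtx(NamedTuple):
    # Hypothetical stand-in for an element context (not FluxFEM's type).
    w: jax.Array     # quadrature weights, shape (n_q,)
    detJ: jax.Array  # Jacobian determinants, shape (n_q,)
    N: jax.Array     # shape-function values, shape (n_q, n_ldofs)

def body_force_form(ctx, f):
    # Form: per-quadrature-point integrand f * N_i, shape (n_q, n_ldofs).
    return f * ctx.N

@jax.jit
def body_force_kernel(ctx):
    # Kernel: integrated element vector, shape (n_ldofs,).
    wJ = ctx.w * ctx.detJ
    return (body_force_form(ctx, 2.0) * wJ[:, None]).sum(axis=0)

# 2-point Gauss rule on a 1D linear element of length 1 (detJ = 0.5).
xi = jnp.array([-1.0, 1.0]) / jnp.sqrt(3.0)
N = jnp.stack([(1 - xi) / 2, (1 + xi) / 2], axis=1)
ctx = ToyCtx(w=jnp.array([1.0, 1.0]), detJ=jnp.array([0.5, 0.5]), N=N)

integrand = body_force_form(ctx, 2.0)  # shape (2, 2): one row per quadrature point
f_elem = body_force_kernel(ctx)        # approximately [1.0, 1.0], the exact integral of 2 * N_i
```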
```python
import fluxfem as ff
import jax
import jax.numpy as jnp

mesh = ff.StructuredHexBox(nx=2, ny=1, nz=1, lx=1.0, ly=1.0, lz=1.0).build()
space = ff.make_hex_space(mesh, dim=1, intorder=2)

# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
def linear_kernel(ctx):
    # Per-quadrature-point integrand, shape (n_q, n_ldofs).
    integrand = ff.scalar_body_force_form(ctx, 2.0)
    # Quadrature weight times Jacobian determinant, shape (n_q,).
    wJ = ctx.w * ctx.test.detJ
    # Integrate: weight each quadrature point and sum over them.
    return (integrand * wJ[:, None]).sum(axis=0)

ker_F = jax.jit(linear_kernel)
F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
```
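The hand-written linear kernel above has a direct bilinear analogue. The integrand gains a second DOF axis, so wJ is broadcast over two trailing axes instead of one. The sketch below is plain JAX with hypothetical arrays for a 1D linear element, where `dN` holds shape-function gradients at each quadrature point. It is not FluxFEM's `diffusion_form`.

```python
import jax
import jax.numpy as jnp

def diffusion_integrand(dN, kappa):
    # Per-quadrature-point integrand kappa * dN_i * dN_j, shape (n_q, n_ldofs, n_ldofs).
    return kappa * dN[:, :, None] * dN[:, None, :]

@jax.jit
def bilinear_kernel(dN, w, detJ):
    wJ = w * detJ  # shape (n_q,)
    integrand = diffusion_integrand(dN, 1.0)
    # Integrate over quadrature points: shape (n_ldofs, n_ldofs).
    return (integrand * wJ[:, None, None]).sum(axis=0)

# 1D element of length h = 1 with a 2-point Gauss rule:
# dN/dx = [-1/h, 1/h] at every point, w = [1, 1], detJ = h / 2.
h = 1.0
dN = jnp.tile(jnp.array([-1.0 / h, 1.0 / h]), (2, 1))
K_e = bilinear_kernel(dN, jnp.array([1.0, 1.0]), jnp.full((2,), h / 2))
# K_e == [[1, -1], [-1, 1]] / h, the classic 1D stiffness matrix.
```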
You can also use the unified entry point for all kernel kinds:
```python
# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_kernel(ff.diffusion_form, 1.0, kind="bilinear")
K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
ker_F = ff.make_element_kernel(ff.scalar_body_force_form, 2.0, kind="linear")
F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)

# res_form, params and u below are placeholders for your own residual form,
# its parameters, and the global DOF vector.

# residual: kernel(ctx, u_elem) -> (n_ldofs,)
ker_R = ff.make_element_kernel(res_form, params, kind="residual")
R = space.assemble_residual(res_form, u, params, kernel=ker_R)

# jacobian: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
ker_J = ff.make_element_kernel(res_form, params, kind="jacobian")
J = space.assemble_jacobian(res_form, u, params, kernel=ker_J)
J_dense = J.to_dense()  # convert the assembled Jacobian to a dense array
```
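The shape comments above act as a contract for each kernel kind. You can check a hand-written kernel against that contract without running it by using `jax.eval_shape`, which traces the function on abstract inputs. The toy kernels, the dict `ctx`, and the sizes `n_q` and `n_ldofs` below are hypothetical. Only the argument orders mirror the comments above.

```python
import jax
import jax.numpy as jnp

n_q, n_ldofs = 8, 8
# Abstract inputs: shapes and dtypes only, no data.
ctx = {"w": jax.ShapeDtypeStruct((n_q,), jnp.float32)}
u_elem = jax.ShapeDtypeStruct((n_ldofs,), jnp.float32)

def residual_kernel(ctx, u_elem):
    # Toy residual kernel(ctx, u_elem): (sum_q w_q) * u_elem, shape (n_ldofs,).
    return ctx["w"].sum() * u_elem

def jacobian_kernel(u_elem, ctx):
    # Toy Jacobian kernel(u_elem, ctx): derivative of the residual w.r.t. u_elem.
    return jax.jacrev(lambda u: residual_kernel(ctx, u))(u_elem)

r_shape = jax.eval_shape(residual_kernel, ctx, u_elem)
j_shape = jax.eval_shape(jacobian_kernel, u_elem, ctx)
# r_shape.shape == (8,) and j_shape.shape == (8, 8)
```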
Using compiled DSL forms
The kernel helpers also work with weak-form DSL objects, as long as you pass
the compiled form returned by get_compiled().
```python
import fluxfem as ff
import fluxfem.helpers_wf as h_wf

# Diffusion bilinear form written in the weak-form DSL, then compiled.
form_wf = ff.BilinearForm.volume(
    lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
).get_compiled()
params = ff.Params(kappa=1.0)

# `space` is the hex space built in the first example.
ker = ff.make_element_bilinear_kernel(form_wf, params, jit=True)
K = space.assemble(form_wf, params, kernel=ker)
```
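Written out by hand, the DSL expression `p.kappa * (v.grad @ u.grad) * dOmega()` is the standard diffusion element matrix K_ij = sum_q w_q detJ_q kappa grad N_i(x_q) . grad N_j(x_q). The sketch below evaluates that sum with `jnp.einsum`. The `grads` array of shape (n_q, n_ldofs, dim) is hypothetical random data, so the example illustrates the math only. It says nothing about what `get_compiled()` generates internally.

```python
import jax
import jax.numpy as jnp

def diffusion_element_matrix(grads, w, detJ, kappa):
    # grads: shape-function gradients at quadrature points, shape (n_q, n_ldofs, dim).
    # K_ij = sum_q w_q * detJ_q * kappa * (grad N_i . grad N_j)
    wJ = w * detJ
    return kappa * jnp.einsum("q,qid,qjd->ij", wJ, grads, grads)

n_q, n_ldofs, dim = 8, 8, 3  # sized like a trilinear hex with a 2x2x2 rule
grads = jax.random.normal(jax.random.PRNGKey(0), (n_q, n_ldofs, dim))  # hypothetical values
w = jnp.ones(n_q)
detJ = jnp.full((n_q,), 0.125)

K_e = diffusion_element_matrix(grads, w, detJ, kappa=1.0)

# Same sum written as an explicit loop over quadrature points, for comparison.
K_loop = sum(w[q] * detJ[q] * grads[q] @ grads[q].T for q in range(n_q))
```

Because each quadrature term is a weighted Gram matrix, K_e comes out symmetric and positive semi-definite for any gradients and positive weights.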
Residual and Jacobian
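Residual kernels take (ctx, u_elem), while Jacobian kernels take (u_elem, ctx).
Putting u_elem first matches the jax.jacrev convention, which differentiates with respect to the first argument by default (argnums=0).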
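The argument-order convention can be shown with `jax.jacrev` directly. In this sketch the cubic residual and the dict `ctx` are hypothetical toys. FluxFEM builds its own Jacobian kernels through `make_element_jacobian_kernel`, and this code does not reproduce that helper. It only shows why u_elem comes first, and how `argnums` reaches the same result without reordering.

```python
import jax
import jax.numpy as jnp

def residual_kernel(ctx, u_elem):
    # Toy nonlinear element residual r_i = c * u_i**3 (ctx is a plain dict here).
    return ctx["c"] * u_elem ** 3

def residual_u_first(u_elem, ctx):
    # Same residual with u_elem moved to the first position.
    return residual_kernel(ctx, u_elem)

# jax.jacrev differentiates with respect to argument 0 by default.
jacobian_kernel = jax.jit(jax.jacrev(residual_u_first))

# Equivalent without reordering: point argnums at u_elem explicitly.
jacobian_kernel_alt = jax.jit(jax.jacrev(residual_kernel, argnums=1))

ctx = {"c": 2.0}
u_elem = jnp.array([1.0, 2.0, 3.0])
J_e = jacobian_kernel(u_elem, ctx)  # diag(3 * c * u**2) = diag([6, 24, 54])
```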
```python
def simple_residual(ctx, u_elem, _params):
    # Return the integrand only; assembly applies wJ and sums.
    # Shape (n_q, n_ldofs): the element values repeated at every quadrature point.
    return jnp.broadcast_to(u_elem, (ctx.w.shape[0], u_elem.shape[0]))

u = jnp.zeros(space.n_dofs)
ker_R = ff.make_element_residual_kernel(simple_residual, params=None)
ker_J = ff.make_element_jacobian_kernel(simple_residual, params=None)
R = space.assemble_residual(simple_residual, u, params=None, kernel=ker_R)
J = space.assemble_jacobian(simple_residual, u, params=None, kernel=ker_J)
J_dense = J.to_dense()
```
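Because simple_residual repeats u_elem at every quadrature point, its element residual and Jacobian can be worked out by hand. That makes it a convenient sanity check for the assembled R and J. Here u_e is the element DOF vector and |Ω_e| = Σ_q w_q det J_q is the element volume; the volume identity holds when the rule integrates det J exactly, as it does for affine elements.

```latex
R_e = \sum_{q} w_q \,\det J_q \, u_e = |\Omega_e| \, u_e,
\qquad
J_e = \frac{\partial R_e}{\partial u_e} = |\Omega_e| \, I,
\qquad
|\Omega_e| = \sum_{q} w_q \det J_q .
```

With u = 0 as above, R is the zero vector. J is nonzero and does not depend on u.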
Notes
- The kernel controls JIT compilation. As long as input shapes and dtypes stay the same, the compiled code is reused across calls; a new input shape triggers recompilation.
- You can combine this with chunked assembly (policy=ff.AssemblyPolicy.chunked(...)) to keep compilation stable when the last chunk is smaller than the others.
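The following plain-JAX sketch illustrates both notes. A jitted function is traced once per distinct input shape, and padding a short final chunk, with a mask for the padded rows, keeps it at a single trace. The toy kernel, the helpers `make_counted_kernel` and `run_chunks`, and the trace counter are all hypothetical. Whether ff.AssemblyPolicy.chunked pads in exactly this way is an assumption, not something stated above.

```python
import jax
import jax.numpy as jnp

def make_counted_kernel():
    # Fresh jitted toy kernel plus a counter of how many times JAX traced it.
    counter = {"traces": 0}

    @jax.jit
    def chunk_kernel(x, mask):
        # Python side effects run only while tracing, i.e. once per new input shape.
        counter["traces"] += 1
        # Toy per-element work; the mask zeroes padded rows.
        return (x ** 2).sum(axis=1) * mask

    return chunk_kernel, counter

def run_chunks(kernel, x, chunk_size, pad):
    out = []
    for start in range(0, x.shape[0], chunk_size):
        chunk = x[start:start + chunk_size]
        n = chunk.shape[0]
        mask = jnp.ones(n, dtype=x.dtype)
        if pad and n < chunk_size:
            # Pad the short last chunk so every call sees the same shape.
            chunk = jnp.pad(chunk, ((0, chunk_size - n), (0, 0)))
            mask = jnp.pad(mask, (0, chunk_size - n))
        out.append(kernel(chunk, mask)[:n])  # drop padded rows again
    return jnp.concatenate(out)

x = jnp.arange(20.0).reshape(10, 2)  # 10 toy "elements" -> chunks of 4, 4, 2

k_plain, c_plain = make_counted_kernel()
y_plain = run_chunks(k_plain, x, chunk_size=4, pad=False)  # traced twice (sizes 4 and 2)

k_pad, c_pad = make_counted_kernel()
y_pad = run_chunks(k_pad, x, chunk_size=4, pad=True)  # traced once
```

Both runs produce the same per-element results. The padded run avoids compiling a second variant just for the smaller last chunk.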