Kernel-based Assembly (JIT-Friendly)

This tutorial shows how to assemble using explicit, JIT-compiled element kernels. The idea is to separate:

  • element kernels (what gets JIT-compiled), and

  • assembly (scatter into global vectors/matrices).

This makes the JIT boundary explicit and lets the same compiled kernel be reused across assembly calls.
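
The split between kernel and assembly can be sketched in plain JAX, without any FluxFEM calls. This is a minimal illustration, not FluxFEM's implementation: the 1D mesh, the connectivity array `conn`, and the helper names `element_load_kernel` and `assemble_vector` are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D mesh: 3 nodes, 2 two-node elements of length h.
h = 0.5
conn = jnp.array([[0, 1], [1, 2]])  # element -> global dof indices


@jax.jit
def element_load_kernel(f_elem):
    # Element kernel: integrated load vector for a constant source
    # f_elem on a linear two-node element (this is the JIT boundary).
    return f_elem * h / 2.0 * jnp.ones(2)


def assemble_vector(kernel, f_per_elem, conn, n_dofs):
    # Assembly: evaluate the kernel on every element, then scatter-add
    # the element vectors into the global vector.
    elem_vecs = jax.vmap(kernel)(f_per_elem)  # (n_elems, 2)
    return jnp.zeros(n_dofs).at[conn.reshape(-1)].add(elem_vecs.reshape(-1))


F = assemble_vector(element_load_kernel, jnp.array([2.0, 2.0]), conn, 3)
```

The shared middle node receives a contribution from both elements, which is the scatter-add step that assembly is responsible for; the kernel itself never sees global indices.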

Core idea

make_element_*_kernel returns a JIT-compiled function that operates on one element at a time. You can pass that kernel to space.assemble, space.assemble_residual, or space.assemble_jacobian through the kernel= argument.

The kernel must return the integrated element contribution (not the integrand). By contrast, a form (form(ctx, params)) returns the per-quadrature integrand, and the assembler applies wJ and sums over quadrature points. If you pass an untagged raw kernel to space.assemble with kind=, FluxFEM emits a one-time warning; prefer tagging with @ff.kernel.

import fluxfem as ff
import jax
import jax.numpy as jnp

mesh = ff.StructuredHexBox(nx=2, ny=1, nz=1, lx=1.0, ly=1.0, lz=1.0).build()
space = ff.make_hex_space(mesh, dim=1, intorder=2)

# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
def linear_kernel(ctx):
    integrand = ff.scalar_body_force_form(ctx, 2.0)
    wJ = ctx.w * ctx.test.detJ
    return (integrand * wJ[:, None]).sum(axis=0)

ker_F = jax.jit(linear_kernel)
F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
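
The weighting step in linear_kernel above can be checked in isolation. The following sketch uses made-up numbers for the integrand and for wJ (both hypothetical) to show the difference between what a form returns (one row per quadrature point) and what a kernel must return (the weighted sum over quadrature points).

```python
import jax.numpy as jnp

# Hypothetical element data: 3 quadrature points, 2 local dofs.
integrand = jnp.array([[1.0, 2.0],
                       [3.0, 4.0],
                       [5.0, 6.0]])   # what a form returns: (n_q, n_ldofs)
wJ = jnp.array([0.5, 1.0, 0.5])       # quadrature weight * detJ per point

# What a kernel must return: the integrated element contribution.
element_vec = (integrand * wJ[:, None]).sum(axis=0)   # (n_ldofs,)

# The same contraction written with einsum.
element_vec_einsum = jnp.einsum("q,qi->i", wJ, integrand)
```

For a bilinear kernel the same pattern would apply with one more trailing axis, for example `wJ[:, None, None]` against an integrand of shape (n_q, n_ldofs, n_ldofs).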

You can also use the unified entry point for all kernel kinds:

# bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
ker_K = ff.make_element_kernel(ff.diffusion_form, 1.0, kind="bilinear")
K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)

# linear: kernel(ctx) -> (n_ldofs,)
ker_F = ff.make_element_kernel(ff.scalar_body_force_form, 2.0, kind="linear")
F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)

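# res_form, params, and u are placeholders here; see "Residual and Jacobian"
# for a concrete residual form and solution vector.
# residual: kernel(ctx, u_elem) -> (n_ldofs,)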
ker_R = ff.make_element_kernel(res_form, params, kind="residual")
R = space.assemble_residual(res_form, u, params, kernel=ker_R)

# jacobian: kernel(u_elem, ctx) -> (n_ldofs, n_ldofs)
ker_J = ff.make_element_kernel(res_form, params, kind="jacobian")
J = space.assemble_jacobian(res_form, u, params, kernel=ker_J)
J_dense = J.to_dense()

Using compiled DSL forms

The kernel helpers also work with weak-form DSL objects as long as you pass the compiled form (get_compiled()).

import fluxfem as ff
import fluxfem.helpers_wf as h_wf

form_wf = ff.BilinearForm.volume(
    lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
).get_compiled()

params = ff.Params(kappa=1.0)
ker = ff.make_element_bilinear_kernel(form_wf, params, jit=True)
K = space.assemble(form_wf, params, kernel=ker)

Residual and Jacobian

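Residual kernels take (ctx, u_elem), while Jacobian kernels take (u_elem, ctx). Putting u_elem first matches JAX's jacrev convention, which differentiates with respect to the first positional argument by default (argnums=0).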

def simple_residual(ctx, u_elem, _params):
    # Return integrand only; assembly applies wJ and sums.
    return jnp.broadcast_to(u_elem, (ctx.w.shape[0], u_elem.shape[0]))

u = jnp.zeros(space.n_dofs)
ker_R = ff.make_element_residual_kernel(simple_residual, params=None)
ker_J = ff.make_element_jacobian_kernel(simple_residual, params=None)

R = space.assemble_residual(simple_residual, u, params=None, kernel=ker_R)
J = space.assemble_jacobian(simple_residual, u, params=None, kernel=ker_J)
J_dense = J.to_dense()
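
The argument-order convention can be illustrated in plain JAX. This is only a sketch of why u_elem comes first in a Jacobian kernel, not how FluxFEM builds its kernels: the dict-based `ctx` and the toy residual are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


def residual_kernel(ctx, u_elem):
    # Toy integrated residual: (sum of wJ) * u_elem**2.
    # ctx is a plain dict here, standing in for a real element context.
    return ctx["wJ"].sum() * u_elem**2


def jacobian_kernel(u_elem, ctx):
    # u_elem is the first argument, so jacrev's default argnums=0
    # differentiates with respect to it and treats ctx as fixed data.
    return jax.jacrev(lambda u, c: residual_kernel(c, u))(u_elem, ctx)


ctx = {"wJ": jnp.array([0.25, 0.75])}
u_elem = jnp.array([1.0, 2.0, 3.0])

R_elem = residual_kernel(ctx, u_elem)   # shape (3,)
J_elem = jacobian_kernel(u_elem, ctx)   # shape (3, 3), equals diag(2 * u_elem)
```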

Notes

  • The kernel is where JIT compilation happens. JAX caches compiled code by input shape and dtype, so as long as the element shapes stay the same, the compiled kernel is reused across calls.

  • You can combine this with chunked assembly (policy=ff.AssemblyPolicy.chunked(...)) to avoid recompilation when the last chunk is smaller than the others.
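
Both notes can be observed in plain JAX by counting how often a jitted function is traced. The sketch below is illustrative only: `chunk_kernel`, `chunk_size`, and the padding strategy are hypothetical, and whether FluxFEM's chunked policy pads the last chunk in this way is an assumption, not something this page states.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def chunk_kernel(x):
    global trace_count
    trace_count += 1            # Python side effect: runs only while tracing
    return (x ** 2).sum(axis=1)


chunk_size = 4
# 5 "elements" with 2 values each: one full chunk of 4, one short chunk of 1.
data = jnp.arange(10, dtype=jnp.float32).reshape(5, 2)

chunk_kernel(data[:4])          # first call: traced and compiled for (4, 2)
chunk_kernel(data[:4] + 1.0)    # same shape and dtype: compiled code reused
traces_same_shape = trace_count

# Pad the short last chunk to the full chunk size, then drop the padded rows.
last = data[4:]
n_valid = last.shape[0]
padded = jnp.pad(last, ((0, chunk_size - n_valid), (0, 0)))
out_last = chunk_kernel(padded)[:n_valid]
traces_after_padding = trace_count

# Without padding, the new shape (1, 2) forces another trace.
chunk_kernel(last)
traces_unpadded = trace_count
```

Padding trades a little wasted work on the dummy rows for a single compiled version of the kernel, which usually pays off when compilation is expensive relative to one chunk of evaluation.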