Chunked Assembly

Large meshes can cause JAX to recompile a kernel every time the shape of the assembly data changes (e.g., when the last chunk is smaller than the others). Chunked assembly caps and fixes the per-call compilation shape by splitting the mesh into n_chunks pieces and processing each chunk as a fixed-size, padded batch.
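
The recompilation problem can be reproduced with plain JAX, independent of FluxFEM. In this minimal sketch, `count_traces` and its toy kernel are hypothetical illustrations, not part of FluxFEM. Traces are counted with a Python side effect, which runs only when JAX traces the function for a new input shape; cached calls skip it.

```python
import jax
import jax.numpy as jnp

def count_traces(batch_sizes, pad_to=None):
    """Run a freshly jitted toy kernel on batches of the given sizes and
    return how many times JAX traced (and therefore compiled) it."""
    traces = []

    @jax.jit
    def kernel(batch):
        traces.append(batch.shape)  # Python side effect: runs only while tracing
        # One 2x2 "local matrix" per element (toy stand-in for a weak form)
        return jnp.einsum("ei,ej->eij", batch, batch)

    for n in batch_sizes:
        batch = jnp.ones((n, 2))
        if pad_to is not None:
            # Pad with zero rows so every call sees the same shape
            batch = jnp.pad(batch, ((0, pad_to - n), (0, 0)))
        kernel(batch).block_until_ready()
    return len(traces)

print(count_traces([4, 4, 2]))            # uneven final batch
print(count_traces([4, 4, 2], pad_to=4))  # final batch padded to 4 rows
```

The uneven run traces twice (shapes (4, 2) and (2, 2)), while the padded run traces once because every call reuses the cached (4, 2) compilation.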

Why chunked assembly?

  • Stable JIT shapes: Each chunk is padded to a constant chunk_size, so JAX traces the same array dimensions on every call instead of seeing a slightly smaller final batch that would trigger recompilation.

  • Smaller working sets: Chunking bounds the per-call intermediate footprint, which can reduce memory pressure when assembling very large meshes (or running batched jit + vmap loops).

  • Optional tracing hints: pad_trace=True emits statistics about the chosen chunk layout, which helps tune n_chunks for your problem.

How to use n_chunks

Chunking is configured via AssemblyPolicy. This policy is accepted by both functional APIs (fluxfem.core.assemble_*) and space.assemble* helpers. Set n_chunks to a positive integer:

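import fluxfem as ff

# `space` is an existing function space defined on your mesh
policy = ff.AssemblyPolicy.chunked(
    n_chunks=16,     # split the mesh into 16 fixed-size chunks
    pad_trace=True,  # report chunk-layout statistics while tuning
)
K = space.assemble(ff.diffusion_form, params=1.0, policy=policy)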

Note

Passing n_chunks=... or pad_trace=... directly as keyword arguments to space.assemble(...) is no longer supported. Pass tuning options through AssemblyPolicy instead.

FluxFEM automatically chooses chunk_size = ceil(n_elems / n_chunks) and pads the element data up to the next multiple of that size. The padded elements are evaluated during chunked assembly but never contribute to the final sparse matrix, because the assembled data vector is sliced back to n_elems * n_ldofs^2 entries (or n_elems * n_ldofs for RHS assembly).
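
The pad-evaluate-slice pattern described above can be sketched in plain JAX. This is an illustration under stated assumptions, not FluxFEM's implementation: `local_matrix` is a hypothetical per-element kernel producing an (n_ldofs, n_ldofs) block, padding repeats the last element (an assumed choice that keeps padded rows numerically valid), and `jax.lax.map` runs the fixed-size chunks sequentially while `jax.vmap` vectorizes within each chunk.

```python
import math

import jax
import jax.numpy as jnp

def local_matrix(x):
    # Hypothetical per-element kernel: an (n_ldofs, n_ldofs) local block
    return jnp.outer(x, x)

def chunked_data_vector(elem_data, n_chunks):
    """Compute the flattened local-matrix data chunk by chunk (sketch)."""
    n_elems, n_ldofs = elem_data.shape
    n_chunks = min(n_chunks, n_elems)                         # clamp to n_elems
    chunk_size = math.ceil(n_elems / n_chunks)
    n_padded = math.ceil(n_elems / chunk_size) * chunk_size   # next multiple of chunk_size
    # Assumed padding strategy: repeat the last element so padded rows are valid
    padded = jnp.pad(elem_data, ((0, n_padded - n_elems), (0, 0)), mode="edge")
    chunks = padded.reshape(-1, chunk_size, n_ldofs)
    # Every chunk has the same shape, so the vmapped kernel is traced once
    local = jax.lax.map(jax.vmap(local_matrix), chunks)
    data = local.reshape(-1)
    # Discard padded contributions: keep n_elems * n_ldofs**2 entries
    return data[: n_elems * n_ldofs**2]

x = jnp.arange(30, dtype=jnp.float32).reshape(10, 3)   # 10 elements, 3 local dofs
reference = jax.vmap(local_matrix)(x).reshape(-1)      # unchunked assembly data
chunked = chunked_data_vector(x, n_chunks=4)           # chunk_size 3, 2 padded elements
print(jnp.allclose(chunked, reference))
```

Because the sliced vector keeps the element-major order of the unpadded data, it lines up with row/column index arrays built for the original n_elems elements.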

Best practices

  • Pick powers of two when possible; they keep chunk_size regular and mesh partitions predictable.

  • Set n_chunks based on mesh size; values larger than n_elems are automatically clamped to n_elems.

  • Combine with pad_trace=True when tuning. The trace output reports how many elements were padded, so you can weigh padding overhead against JIT stability.

  • Use the same n_chunks for residual/Jacobian pairs so the sparsity pattern stays stable during nonlinear solves.
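
To compare candidate n_chunks values before a tuning run, the layout can be computed offline. The helper below is hypothetical and does not reproduce FluxFEM's pad_trace output format; it only applies the rules stated on this page: None assembles the whole mesh in one call, n_chunks is clamped to n_elems, chunk_size = ceil(n_elems / n_chunks), and the element count is padded to the next multiple of chunk_size.

```python
import math
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ChunkLayout:
    n_chunks: int    # chunks actually evaluated
    chunk_size: int  # elements per padded chunk
    n_padded: int    # total elements evaluated, including padding
    n_pad: int       # padded elements that are evaluated, then discarded

def chunk_layout(n_elems: int, n_chunks: Optional[int]) -> ChunkLayout:
    """Hypothetical layout calculator following the rules described above."""
    if n_elems <= 0:
        raise ValueError("n_elems must be positive")
    if n_chunks is None:
        # Default: the entire mesh is assembled at once
        return ChunkLayout(1, n_elems, n_elems, 0)
    if n_chunks <= 0:
        raise ValueError("n_chunks must be a positive integer")
    n_chunks = min(n_chunks, n_elems)           # clamp to n_elems
    chunk_size = math.ceil(n_elems / n_chunks)
    used = math.ceil(n_elems / chunk_size)      # may be fewer than requested
    n_padded = used * chunk_size                # next multiple of chunk_size
    return ChunkLayout(used, chunk_size, n_padded, n_padded - n_elems)

for k in (None, 8, 16, 32, 5000):
    print(k, chunk_layout(1000, k))
```

For a 1000-element mesh under these rules, n_chunks=8 needs no padding, while n_chunks=16 and n_chunks=32 pad 8 and 24 elements respectively; n_chunks=5000 is clamped to 1000 single-element chunks.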

Chunked assembly is a lightweight way to improve JAX stability without rewriting your weak form or introducing new scalar-friendly kernels. When n_chunks is None (the default), the entire mesh is assembled at once, so you only need to set the option if you observe repeated JIT recompilation or trace warnings about variable-sized inputs.