Skip to content

hook_dsl Module

Domain-Specific Language for simulation hooks.

Overview

The hook_dsl module provides the core components for the declarative hook system.

Complete Module Reference

natal.hook_dsl

Backward-compatible hook DSL imports.

The implementation now lives under natal.hooks. This module remains as a stable compatibility layer for existing imports.

CompiledEventHooks

CompiledEventHooks()

Container for event-wise combined hook callables.

Kernel code expects one callable per event name. This class stores those callables and optionally the declarative HookProgram registry. When hooks are present and Numba is enabled, lifecycle wrappers are pre-compiled with hooks as globals so Numba caching survives restarts.

Source code in src/natal/hooks/compiler.py
def __init__(self) -> None:
    """Start with every event slot on the no-op hook and no wrappers built."""
    # Per-event combined callables; all begin as the shared no-op hook.
    self.first = self.early = self.late = self.finish = _noop_hook
    # Optional declarative HookProgram attached by from_compiled_hooks.
    self.registry = None
    # Panmictic lifecycle wrappers (continuous, then discrete); only
    # populated by from_compiled_hooks when Numba is enabled.
    self.run_tick_fn = self.run_fn = None
    self.run_discrete_tick_fn = self.run_discrete_fn = None
    # Spatial lifecycle wrappers; only populated when explicitly requested.
    self.spatial_tick_fn = self.spatial_run_fn = None
    self.spatial_discrete_tick_fn = self.spatial_discrete_run_fn = None
    # Event-name -> callable table covering every known event.
    self._event_hooks = dict.fromkeys(EVENT_NAMES, _noop_hook)
from_compiled_hooks staticmethod
from_compiled_hooks(compiled_hooks: List[CompiledHookDescriptor], registry: Optional[HookProgram] = None, include_spatial_wrappers: bool = False) -> CompiledEventHooks

Build event-wise combined callables and lifecycle wrappers.

Unlike the previous Jinja2-codegen approach, this method generates only the necessary lifecycle wrapper per hook combination using compile_lifecycle_wrapper, which produces a uniquely-named njit function with hooks as globals. This ensures Numba cache=True works across process restarts.

Source code in src/natal/hooks/compiler.py
@staticmethod
def from_compiled_hooks(
    compiled_hooks: List[CompiledHookDescriptor],
    registry: Optional[HookProgram] = None,
    include_spatial_wrappers: bool = False,
) -> CompiledEventHooks:
    """Assemble per-event combined hooks and, under Numba, lifecycle wrappers.

    Each event's njit hooks are ordered by ascending priority (ties keep
    their input order) and fused with ``compile_combined_hook``.  When
    Numba is enabled, ``compile_lifecycle_wrapper`` (and, on request,
    ``compile_spatial_lifecycle_wrapper``) then bakes the combined
    ``first``/``early``/``late`` hooks into uniquely-named njit wrappers
    as globals, so Numba ``cache=True`` keeps working across process
    restarts.  This replaces the earlier Jinja2-codegen approach.

    Args:
        compiled_hooks: Descriptors to install.  Only descriptors that
            carry an ``njit_fn`` and name a known event contribute.
        registry: Optional declarative program stored on the result.
        include_spatial_wrappers: Also build the spatial lifecycle wrappers.

    Returns:
        A populated ``CompiledEventHooks`` instance.

    Raises:
        TypeError: If Numba is enabled and any descriptor has a ``py_wrapper``.
    """
    from ..numba_utils import NUMBA_ENABLED

    if NUMBA_ENABLED:
        offender = next((d for d in compiled_hooks if d.py_wrapper is not None), None)
        if offender is not None:
            raise TypeError(
                f"Hook '{offender.name}' uses py_wrapper, which is not allowed when Numba is enabled."
            )

    result = CompiledEventHooks()
    result.registry = registry

    # Bucket njit-backed descriptors by event; unknown events are dropped.
    buckets: Dict[str, List[CompiledHookDescriptor]] = {name: [] for name in EVENT_NAMES}
    for desc in compiled_hooks:
        if desc.njit_fn is None:
            continue
        bucket = buckets.get(desc.event)
        if bucket is not None:
            bucket.append(desc)

    for event_name, descs in buckets.items():
        if not descs:
            continue
        # list.sort is stable, so equal priorities keep registration order.
        descs.sort(key=lambda d: d.priority)
        fns = cast("List[HookCallable]", [d.njit_fn for d in descs])
        selectors = cast("List[DemeSelector]", [d.deme_selector for d in descs])
        result.set_hook(event_name, compile_combined_hook(fns, selectors))

    hooks = (result.first, result.early, result.late)

    if not NUMBA_ENABLED:
        return result

    # The wrappers embed the combined hooks as module-level globals rather
    # than parameters, so each hook combination gets its own source hash
    # and Numba can cache it.  They are built even with zero user hooks
    # (then the globals are _noop_hook), so the population model can use
    # them unconditionally.
    result.run_tick_fn, result.run_fn = compile_lifecycle_wrapper(False, *hooks)
    result.run_discrete_tick_fn, result.run_discrete_fn = compile_lifecycle_wrapper(True, *hooks)

    if include_spatial_wrappers:
        result.spatial_tick_fn, result.spatial_run_fn = compile_spatial_lifecycle_wrapper(
            False, *hooks,
        )
        result.spatial_discrete_tick_fn, result.spatial_discrete_run_fn = compile_spatial_lifecycle_wrapper(
            True, *hooks,
        )

    return result

Op

Factory helpers for building declarative operations.

The methods here only build data objects and do not touch population state. Compilation happens later in compile_declarative_hook.

scale staticmethod
scale(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', factor: float = 1.0, when: Optional[str] = None) -> HookOp

Create a scaling operation that multiplies counts by a factor.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector ("*" for all, specific genotype, or list)

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector ("*" for all, specific age, range, or list)

'*'
sex Literal['female', 'male', 'both']

Sex selector ("female", "male", or "both")

'both'
factor float

Scaling factor (e.g., 0.5 halves the count, 2.0 doubles it)

1.0
when Optional[str]

Optional condition expression (e.g., "tick >= 100")

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def scale(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    factor: float = 1.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a SCALE op that multiplies the selected counts by ``factor``.

    Args:
        genotypes: Genotype selector: "*" for all, one genotype, or a list.
        ages: Age selector: "*" for all, one age, a range, or a list.
        sex: Which sex to target: "female", "male", or "both".
        factor: Multiplier applied to the counts (0.5 halves, 2.0 doubles).
        when: Optional condition expression, e.g. "tick >= 100".

    Returns:
        HookOp: Uncompiled descriptor carrying ``factor`` as its ``param``.
    """
    return HookOp(
        op_type=OpType.SCALE,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=factor,
        condition=when,
    )
set_count staticmethod
set_count(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', value: float = 0.0, when: Optional[str] = None) -> HookOp

Create an operation that sets counts to a specific value.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
value float

Target count value (individuals will be added/removed to match)

0.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def set_count(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    value: float = 0.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a SET op that forces the selected counts to ``value``.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        value: Desired count; individuals are added or removed to reach it.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``value`` as its ``param``.
    """
    return HookOp(
        op_type=OpType.SET,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=value,
        condition=when,
    )
add staticmethod
add(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', delta: float = 0.0, when: Optional[str] = None) -> HookOp

Create an operation that adds a fixed number of individuals.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
delta float

Number of individuals to add (can be negative to remove)

0.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def add(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    delta: float = 0.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build an ADD op that adds ``delta`` individuals to the selection.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        delta: How many individuals to add; a negative value removes them.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``delta`` as its ``param``.
    """
    return HookOp(
        op_type=OpType.ADD,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=delta,
        condition=when,
    )
subtract staticmethod
subtract(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', delta: float = 0.0, when: Optional[str] = None) -> HookOp

Create an operation that subtracts a fixed number of individuals.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
delta float

Number of individuals to subtract

0.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def subtract(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    delta: float = 0.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a SUBTRACT op that removes ``delta`` individuals from the selection.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        delta: How many individuals to remove.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``delta`` as its ``param``.
    """
    return HookOp(
        op_type=OpType.SUBTRACT,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=delta,
        condition=when,
    )
kill staticmethod
kill(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', prob: float = 0.0, when: Optional[str] = None) -> HookOp

Create a probabilistic killing operation.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
prob float

Probability of killing each selected individual (0.0 to 1.0)

0.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Raises:

Type Description
ValueError

If probability is not in [0, 1]

Source code in src/natal/hooks/declarative.py
@staticmethod
def kill(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    prob: float = 0.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a KILL op that removes each selected individual with probability ``prob``.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        prob: Per-individual death probability, within [0, 1].
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``prob`` as its ``param``.

    Raises:
        ValueError: If ``prob`` lies outside [0, 1] (NaN is rejected too).
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    in_unit_interval = 0.0 <= prob <= 1.0
    if not in_unit_interval:
        raise ValueError(f"prob must be in [0, 1], got {prob}")
    return HookOp(
        op_type=OpType.KILL,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=prob,
        condition=when,
    )
sample staticmethod
sample(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', size: int = 0, when: Optional[str] = None) -> HookOp

Create a sampling operation that selects individuals without replacement.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
size int

Number of individuals to sample

0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def sample(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    size: int = 0,
    when: Optional[str] = None,
) -> HookOp:
    """Create a sampling operation that selects individuals without replacement.

    Args:
        genotypes: Genotype selector
        ages: Age selector
        sex: Sex selector
        size: Number of individuals to sample (must be non-negative)
        when: Optional condition expression

    Returns:
        HookOp: Operation descriptor for compilation; ``size`` is stored as
        a float ``param``.

    Raises:
        ValueError: If ``size`` is negative.
    """
    # Reject impossible sample sizes when the op is built, mirroring the
    # argument validation done by ``kill``, instead of deferring to the kernel.
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    return HookOp(OpType.SAMPLE, genotypes, ages, sex, float(size), when)
stop_if_zero staticmethod
stop_if_zero(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', when: Optional[str] = None) -> HookOp

Create an operation that stops the simulation if selected count reaches zero.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def stop_if_zero(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    when: Optional[str] = None,
) -> HookOp:
    """Build a STOP_IF_ZERO op that halts the run once the selected count is zero.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor; its ``param`` is unused and set to 0.0.
    """
    return HookOp(
        op_type=OpType.STOP_IF_ZERO,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=0.0,
        condition=when,
    )
stop_if_below staticmethod
stop_if_below(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', threshold: float = 1.0, when: Optional[str] = None) -> HookOp

Create an operation that stops the simulation if count falls below threshold.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
threshold float

Minimum count threshold

1.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def stop_if_below(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    threshold: float = 1.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a STOP_IF_BELOW op that halts the run when the count drops under ``threshold``.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        threshold: Lower bound on the selected count.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``threshold`` (as float) as ``param``.
    """
    bound = float(threshold)
    return HookOp(
        op_type=OpType.STOP_IF_BELOW,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=bound,
        condition=when,
    )
stop_if_above staticmethod
stop_if_above(genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', threshold: float = 1000000.0, when: Optional[str] = None) -> HookOp

Create an operation that stops the simulation if count exceeds threshold.

Parameters:

Name Type Description Default
genotypes Union[str, List[str], Literal['*']]

Genotype selector

'*'
ages Union[int, List[int], range, Literal['*']]

Age selector

'*'
sex Literal['female', 'male', 'both']

Sex selector

'both'
threshold float

Maximum count threshold

1000000.0
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def stop_if_above(
    genotypes: Union[str, List[str], Literal["*"]] = "*",
    ages: Union[int, List[int], range, Literal["*"]] = "*",
    sex: Literal["female", "male", "both"] = "both",
    threshold: float = 1_000_000.0,
    when: Optional[str] = None,
) -> HookOp:
    """Build a STOP_IF_ABOVE op that halts the run when the count exceeds ``threshold``.

    Args:
        genotypes: Genotype selector.
        ages: Age selector.
        sex: Sex selector.
        threshold: Upper bound on the selected count.
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor carrying ``threshold`` (as float) as ``param``.
    """
    bound = float(threshold)
    return HookOp(
        op_type=OpType.STOP_IF_ABOVE,
        genotypes=genotypes,
        ages=ages,
        sex=sex,
        param=bound,
        condition=when,
    )
stop_if_extinction staticmethod
stop_if_extinction(when: Optional[str] = None) -> HookOp

Create an operation that stops the simulation if total population goes extinct.

Parameters:

Name Type Description Default
when Optional[str]

Optional condition expression

None

Returns:

Name Type Description
HookOp HookOp

Operation descriptor for compilation

Source code in src/natal/hooks/declarative.py
@staticmethod
def stop_if_extinction(when: Optional[str] = None) -> HookOp:
    """Build a STOP_IF_EXTINCTION op that halts the run once the whole population is gone.

    The selectors are fixed to everything ("*" genotypes, "*" ages, both
    sexes), so only the gating condition is configurable.

    Args:
        when: Optional condition expression gating the op.

    Returns:
        HookOp: Uncompiled descriptor; its ``param`` is unused and set to 0.0.
    """
    return HookOp(
        op_type=OpType.STOP_IF_EXTINCTION,
        genotypes="*",
        ages="*",
        sex="both",
        param=0.0,
        condition=when,
    )

HookExecutor

HookExecutor(registry: HookProgram, hooks_by_event: Dict[int, List[CompiledHookDescriptor]])

Python-layer coordinator for all hook execution modes.

This class is used by population event dispatch where both njit and Python callback hooks must coexist around the declarative CSR core.

Source code in src/natal/hooks/executor.py
def __init__(
    self,
    registry: HookProgram,
    hooks_by_event: Dict[int, List[CompiledHookDescriptor]],
) -> None:
    """Keep the hook program and the per-event descriptor lists.

    Args:
        registry: Declarative ``HookProgram``, stored as ``self.registry``.
        hooks_by_event: Event id -> descriptors, expected in priority order
            (``from_compiled_hooks`` builds it that way).
    """
    self.hooks_by_event = hooks_by_event
    self.registry = registry
from_compiled_hooks staticmethod
from_compiled_hooks(registry: HookProgram, compiled_hooks: List[CompiledHookDescriptor]) -> HookExecutor

Group descriptors by event and sort by priority.

Source code in src/natal/hooks/executor.py
@staticmethod
def from_compiled_hooks(
    registry: HookProgram,
    compiled_hooks: List[CompiledHookDescriptor],
) -> HookExecutor:
    """Group descriptors by event and sort by priority.

    Descriptors whose event name is not in ``EVENT_ID_MAP``, or that carry
    no payload at all (no ``plan``, ``njit_fn`` or ``py_wrapper``), are
    dropped.  Each event's list is sorted by ascending priority; the sort
    is stable, so ties keep their input order.

    Args:
        registry: Declarative program handed to the new executor.
        compiled_hooks: Descriptors to distribute across events.

    Returns:
        A ``HookExecutor`` whose mapping only contains events that received
        at least one descriptor.
    """
    grouped: Dict[int, List[CompiledHookDescriptor]] = {}
    for desc in compiled_hooks:
        event_id = EVENT_ID_MAP.get(desc.event)
        if event_id is None:
            continue
        has_payload = (
            desc.plan is not None
            or desc.njit_fn is not None
            or desc.py_wrapper is not None
        )
        if not has_payload:
            continue
        grouped.setdefault(event_id, []).append(desc)

    for descs in grouped.values():
        descs.sort(key=lambda d: d.priority)

    return HookExecutor(registry, grouped)
execute_event
execute_event(event_id: int, population: BasePopulation[Any], tick: int, deme_id: int = 0) -> int

Run one event with priority ordering across hook types.

Source code in src/natal/hooks/executor.py
def execute_event(
    self,
    event_id: int,
    population: BasePopulation[Any],
    tick: int,
    deme_id: int = 0,
) -> int:
    """Run one event with priority ordering across hook types.

    Walks the descriptors stored for ``event_id`` in their stored order
    (sorted by priority in ``from_compiled_hooks``).  Descriptors whose deme
    selector does not match ``deme_id`` are skipped.  For each remaining
    descriptor, up to three payloads run in this order:

    1. the declarative CSR ``plan``;
    2. the ``njit_fn``;
    3. the ``py_wrapper``, but only when the descriptor has no ``njit_fn``.

    Args:
        event_id: Numeric event id; ids outside ``[0, NUM_EVENTS)`` are a no-op.
        population: Population whose ``state.individual_count`` (and, if
            present, ``state.sperm_storage``) is passed to the hooks.
        tick: Current simulation tick, forwarded to every hook.
        deme_id: Deme index used for selector matching and forwarded to hooks.

    Returns:
        ``RESULT_STOP`` as soon as a plan or njit hook reports it,
        otherwise ``RESULT_CONTINUE``.

    Raises:
        RuntimeError: If a ``py_wrapper`` hook is reached while Numba is
            enabled, or if an njit / py_wrapper hook raises.  The original
            exception is chained as ``__cause__``.
    """
    if event_id < 0 or event_id >= NUM_EVENTS:
        return RESULT_CONTINUE

    # NOTE(review): presumably a NumPy array that hooks modify in place,
    # since it is passed to every hook without being re-read — confirm.
    ind_count = population.state.individual_count

    # Prepare optional sperm-storage arrays for kernels that require them.
    # When the population has none (or it is empty), an empty 3-D float64
    # placeholder is passed instead and has_sperm_storage is False.
    sperm_store = getattr(population.state, "sperm_storage", None)
    has_sperm_storage = sperm_store is not None and sperm_store.size > 0
    if not has_sperm_storage:
        sperm_store = np.zeros((0, 0, 0), dtype=np.float64)
    # Always true at this point; narrows the Optional type for static checkers.
    assert sperm_store is not None
    # Config flags fall back to False when the population has no _config.
    is_stochastic = bool(getattr(getattr(population, "_config", None), "is_stochastic", False))
    use_continuous_sampling = bool(
        getattr(getattr(population, "_config", None), "use_continuous_sampling", False)
    )

    from ..numba_utils import NUMBA_ENABLED

    # Unified timeline: descriptors are pre-sorted by priority (stable).
    for desc in self.hooks_by_event.get(event_id, []):
        if not deme_selector_matches(desc.deme_selector, deme_id):
            continue

        if desc.plan is not None:
            # Run this single plan through the CSR kernel as a one-event,
            # one-hook program.  NOTE(review): the arguments are positional,
            # so their order must match execute_csr_event_arrays' signature;
            # the roles of the constant arrays are not visible here — confirm
            # against the kernel before changing any of them.
            result = execute_csr_event_arrays(
                np.int32(1),
                np.int32(1),
                np.array([0, 1], dtype=np.int32),
                np.array([desc.plan.n_ops], dtype=np.int32),
                np.array([0, desc.plan.n_ops], dtype=np.int32),
                desc.plan.op_types,
                desc.plan.gidx_offsets,
                desc.plan.gidx_data,
                desc.plan.age_offsets,
                desc.plan.age_data,
                desc.plan.sex_masks.flatten(),
                desc.plan.params,
                desc.plan.condition_offsets,
                desc.plan.condition_types,
                desc.plan.condition_params,
                np.array([0], dtype=np.int32),   # selector ANY
                np.array([0, 0], dtype=np.int32),
                np.array([], dtype=np.int32),
                0,
                ind_count,
                sperm_store,
                has_sperm_storage,
                tick,
                is_stochastic,
                use_continuous_sampling,
                deme_id,
            )
            if result == RESULT_STOP:
                return RESULT_STOP

        if desc.njit_fn is not None:
            try:
                result = desc.njit_fn(ind_count, tick, deme_id)
                if result == RESULT_STOP:
                    return RESULT_STOP
            except Exception as e:
                raise RuntimeError(f"Error in njit hook '{desc.name}': {e}") from e

        # py_wrapper only runs for descriptors without an njit_fn.
        # NOTE(review): unlike plan/njit_fn, its return value is discarded,
        # so a py_wrapper cannot stop the simulation — confirm this is intended.
        if desc.py_wrapper is not None and desc.njit_fn is None:
            if NUMBA_ENABLED:
                raise RuntimeError(
                    f"Python py_wrapper hook '{desc.name}' is not allowed when Numba is enabled."
                )
            try:
                # Check if it expects (population) or (ind_count, tick, deme_id)
                # NOTE(review): inspect.signature is recomputed on every call
                # of this per-event loop; consider caching it per descriptor.
                import inspect
                sig = inspect.signature(desc.py_wrapper)
                params = list(sig.parameters.values())
                if len(params) == 1:
                    # Single param - population
                    desc.py_wrapper(population)
                else:
                    # Custom hook - ind_count, tick, deme_id
                    desc.py_wrapper(ind_count, tick, deme_id)
            except Exception as e:
                raise RuntimeError(f"Error in py_wrapper hook '{desc.name}': {e}") from e

    return RESULT_CONTINUE

CompiledHookDescriptor dataclass

CompiledHookDescriptor(name: str, event: str, priority: int = 0, deme_selector: DemeSelector = '*', plan: Optional[CompiledHookPlan] = None, selectors: Dict[str, ndarray] = _empty_selector_map(), static_arrays: Tuple[ndarray, ...] = tuple(), meta: Dict[str, int] = _empty_meta_map(), njit_fn: Optional[Callable[..., object]] = None, py_wrapper: Optional[Callable[..., object]] = None, ops: Optional[List[HookOp]] = None)

Unified descriptor for all hook modes.

Exactly one of plan, njit_fn, or py_wrapper is typically used as the primary execution payload for a descriptor.

CompiledHookPlan dataclass

CompiledHookPlan(n_ops: int, op_types: ndarray, gidx_offsets: ndarray, gidx_data: ndarray, age_offsets: ndarray, age_data: ndarray, sex_masks: ndarray, params: ndarray, condition_offsets: ndarray, condition_types: ndarray, condition_params: ndarray)

Compiled declarative plan with CSR-style flattened arrays.

Variable-length fields (genotypes/ages/conditions) are represented via *_offsets + *_data to keep kernel inputs contiguous and compact.

HookOp dataclass

HookOp(op_type: OpType, genotypes: Union[str, List[str], Literal['*']] = '*', ages: Union[int, List[int], range, Literal['*']] = '*', sex: Literal['female', 'male', 'both'] = 'both', param: float = 1.0, condition: Optional[str] = None)

Single declarative operation before compilation.

Fields in this class can still be symbolic (for example genotype labels). The compiler resolves all symbolic fields into concrete integer arrays.

HookProgram

Bases: NamedTuple

Event-grouped plain-data CSR representation for declarative hooks.

OpType

Bases: IntEnum

Operation opcodes consumed by the runtime kernel.

We intentionally keep integer values stable because these values are serialized into CompiledHookPlan.op_types and interpreted in the executor hot-loop.

compile_combined_hook

compile_combined_hook(njit_fns: List[HookCallable], deme_selectors: Optional[List[DemeSelector]] = None) -> HookCallable

Combine multiple njit hooks into one generated njit function.

We generate source code instead of composing Python closures so the result remains callable from njit kernels.

When deme_selectors is provided and contains non-wildcard values, each hook call is wrapped with an if deme_id == X guard so that per-deme hooks only execute for their target deme(s) — critical for spatial simulations where all hooks share one combined function.

Parameters:

Name Type Description Default
njit_fns List[HookCallable]

List of njit-compiled hook functions.

required
deme_selectors Optional[List[DemeSelector]]

Optional per-function deme target. When None or all "*", no guards are generated (panmictic-safe).

None
Source code in src/natal/hooks/compiler.py
def _deme_guard_condition(ds: DemeSelector) -> Optional[str]:
    """Return generated-source boolean test on ``deme_id`` for ``ds``, or None for "*"."""
    if ds == "*":
        return None
    if isinstance(ds, int):
        return f"deme_id == {int(ds)}"
    if isinstance(ds, range):
        start, stop, step = ds.start, ds.stop, ds.step
        if step == 1:
            return f"{start} <= deme_id < {stop}"
        if len(ds) == 0:
            # An empty range can never match.
            return "False"
        # Honour the step: a bounds check alone would also match the
        # values the range skips (e.g. odd demes for range(0, 10, 2)).
        if step > 0:
            return f"{start} <= deme_id < {stop} and (deme_id - {start}) % {step} == 0"
        return f"{stop} < deme_id <= {start} and ({start} - deme_id) % {-step} == 0"
    # List, tuple or other iterable of ints.
    items = [str(int(x)) for x in ds]
    if not items:
        # An empty selector can never match.
        return "False"
    if len(items) == 1:
        # "(3)" is a bare int, not a one-element tuple, so ``in`` would fail.
        return f"deme_id == {items[0]}"
    return f"deme_id in ({', '.join(items)})"


def compile_combined_hook(
    njit_fns: List[HookCallable],
    deme_selectors: Optional[List[DemeSelector]] = None,
) -> HookCallable:
    """Combine multiple njit hooks into one generated njit function.

    We generate source code instead of composing Python closures so the result
    remains callable from njit kernels.  The generated function calls the
    hooks in list order and returns the first non-zero result, or 0 if every
    hook returns 0.

    When ``deme_selectors`` is provided and contains non-wildcard values,
    each hook call is wrapped with a ``deme_id`` guard so that per-deme
    hooks only execute for their target deme(s) — critical for spatial
    simulations where all hooks share one combined function.  A selector
    may be ``"*"`` (no guard), an ``int``, a ``range`` (any step), or an
    iterable of ints such as a list or tuple.  Empty selectors never match.

    Args:
        njit_fns: List of njit-compiled hook functions.
        deme_selectors: Optional per-function deme target.  When ``None``
            or all ``"*"``, no guards are generated (panmictic-safe).
            When any guard is needed it must have one entry per function.

    Returns:
        ``_noop_hook`` for an empty list; the function itself for a single
        unguarded hook; otherwise the generated combined function.

    Raises:
        ValueError: If guards are needed and ``deme_selectors`` does not have
            exactly one entry per function.
    """
    if len(njit_fns) == 0:
        return _noop_hook

    # Normalize to list so pyright can track the type (not Optional).
    ds_list: List[DemeSelector] = deme_selectors if deme_selectors is not None else []
    needs_guard = any(ds != "*" for ds in ds_list)

    if needs_guard and len(ds_list) != len(njit_fns):
        # Otherwise zip() would silently drop the hooks without a selector.
        raise ValueError(
            f"deme_selectors has {len(ds_list)} entries but {len(njit_fns)} hooks were given"
        )

    # Without guards, single-hook combos can return the function directly.
    if not needs_guard and len(njit_fns) == 1:
        return njit_fns[0]

    # Stable key ensures deterministic module names and cache reuse.
    if needs_guard:
        # "_v2" tags the current guard template (stepped ranges, one-element
        # and empty selectors) so it never shares a key, and therefore a
        # module name, with modules generated by the previous template.
        combined_parts = ["combined_guarded_v2"]
        for fn, ds in zip(njit_fns, ds_list):
            combined_parts.append(stable_callable_identity(fn))
            combined_parts.append(str(ds))
    else:
        combined_parts = ["combined"] + [stable_callable_identity(fn) for fn in njit_fns]
    key = hash_key(combined_parts)
    fn_name = f"_combined_hook_{key}"
    module_stem = f"combined_hook_{key}"
    placeholder_names = [f"_FN_{i}" for i in range(len(njit_fns))]

    # Generated module imports the same switch helper as the rest of hook DSL.
    lines = ["from natal.hook_dsl import njit_switch"]
    lines.extend(f"{placeholder} = None" for placeholder in placeholder_names)
    lines.extend(
        [
            "",
            "@njit_switch(cache=True)",
            f"def {fn_name}(ind_count, tick, deme_id=-1):",
        ]
    )

    for index, placeholder in enumerate(placeholder_names):
        call = f"_result = {placeholder}(ind_count, tick, deme_id)"
        condition = _deme_guard_condition(ds_list[index]) if needs_guard else None
        if condition is None:
            lines.extend([f"    {call}", "    if _result != 0:", "        return _result"])
        else:
            lines.extend(
                [
                    f"    if {condition}:",
                    f"        {call}",
                    "        if _result != 0:",
                    "            return _result",
                ]
            )
    lines.append("    return 0")
    lines.append("")

    module_path = write_codegen_module(module_stem, "\n".join(lines))
    module = load_codegen_module(module_stem, module_path)

    # The generated module declares the placeholders as None; bind the real
    # hooks here, before the combined function is ever called, because Numba
    # reads global values when it compiles.
    for placeholder, fn in zip(placeholder_names, njit_fns):
        setattr(module, placeholder, fn)

    return getattr(module, fn_name)

hook

hook(event: Optional[str] = None, selectors: Optional[Dict[str, Any]] = None, priority: int = 0, custom: bool = False, deme: DemeSelector = '*') -> Callable[[Callable[..., Any]], DecoratedHookFn]

Decorator entrypoint for all supported hook authoring styles.

The decorated function gets a register(pop, event_override=None) helper that compiles and registers a CompiledHookDescriptor.

Hook type is determined by:

- selectors specified -> Selector hook
- custom=True or has required params -> Custom hook
- otherwise -> Declarative hook (function returns List[HookOp])

For custom/selector hooks, Numba compilation is automatic — you do not need to stack @njit. If Numba is enabled, the function is wrapped with njit_switch automatically. If Numba is disabled, a pure-Python wrapper is used.

When a custom hook is called inside a spatial prange region, the deme_id parameter receives the current deme index, enabling one hook function to handle all demes with per-deme branching logic.

Parameters:

Name Type Description Default
event Optional[str]

Hook event name.

None
selectors Optional[Dict[str, Any]]

Optional symbolic selectors for selector-mode hooks.

None
priority int

Execution priority (lower values run earlier).

0
custom bool

If True, treat as custom hook (function is called directly).

False
deme DemeSelector

Target deme(s) for spatial populations. "*" (default) means all demes. Accepts a single int, list, tuple, or range.

'*'
Source code in src/natal/hooks/compiler.py
def hook(
    event: Optional[str] = None,
    selectors: Optional[Dict[str, Any]] = None,
    priority: int = 0,
    custom: bool = False,
    deme: DemeSelector = "*",
) -> Callable[[Callable[..., Any]], DecoratedHookFn]:
    """Decorator entrypoint for all supported hook authoring styles.

    The decorated function gets a
    ``register(pop, event_override=None, deme_selector_override=None)``
    helper that compiles the hook against one population, registers the
    resulting ``CompiledHookDescriptor`` via ``pop.register_compiled_hook``
    and returns it.

    Hook type is determined when ``register`` is called:
    - ``selectors`` is not ``None`` -> Selector hook
    - ``custom=True``, or the function has required parameters and is not a
      single-population Python hook -> Custom hook
    - single-population Python hook -> stored as a Python wrapper; only
      allowed when Numba is disabled
    - otherwise -> Declarative hook (the function is called with no
      arguments and must return List[HookOp])

    For custom/selector hooks, Numba compilation is automatic — you do
    **not** need to stack ``@njit``.  If Numba is enabled, the function is
    wrapped with ``njit_switch`` automatically.  If Numba is disabled, a
    pure-Python wrapper is used.

    When a custom hook is called inside a spatial ``prange`` region, the
    ``deme_id`` parameter receives the current deme index, enabling one
    hook function to handle all demes with per-deme branching logic.

    Args:
        event: Hook event name.  May be omitted here and supplied later via
            ``register(..., event_override=...)``.
        selectors: Optional symbolic selectors for selector-mode hooks.
        priority: Execution priority (lower values run earlier).
        custom: If True, treat as custom hook (function is called directly).
        deme: Target deme(s) for spatial populations.  ``"*"`` (default)
            means all demes.  Accepts a single int, list, tuple, or range.

    Returns:
        A decorator that attaches hook metadata (``meta``, ``event``,
        ``selectors``, ``priority``, ``custom``, ``deme_selector``,
        ``compiled``) and the ``register`` helper to the function, then
        returns that same function object.
    """
    def decorator(func: Callable[..., Any]) -> DecoratedHookFn:
        # Metadata is attached to the user's own function object, which is
        # returned unchanged, so it remains directly callable.
        hook_func = cast(DecoratedHookFn, func)
        hook_func.meta = {
            "event": event,
            "selectors": selectors or {},
            "priority": priority,
            "custom": custom,
            "deme_selector": deme,
        }
        hook_func.compiled = None
        hook_func.event = event
        hook_func.selectors = selectors or {}
        hook_func.priority = priority
        hook_func.custom = custom
        hook_func.deme_selector = deme

        def register(
            pop: BasePopulation[Any],
            event_override: Optional[str] = None,
            deme_selector_override: Optional[DemeSelector] = None,
        ) -> CompiledHookDescriptor:
            """Compile this hook against one population instance.

            Args:
                pop: Population the hook is compiled for.  Its genotype count
                    and ``config.n_ages`` are recorded in the descriptor's
                    ``meta``.
                event_override: Event name to use instead of the decorator's
                    ``event`` (a falsy value falls back to ``event``).
                deme_selector_override: Deme selector to use instead of the
                    decorator's ``deme`` (``None`` falls back to ``deme``).

            Returns:
                The registered ``CompiledHookDescriptor``, which is also
                stored on ``hook_func.compiled``.

            Raises:
                ValueError: If no event was given to the decorator or here.
                TypeError: If a single-population Python hook is registered
                    while Numba is enabled, or a declarative hook does not
                    return ``List[HookOp]``.
            """
            from ..numba_utils import NUMBA_ENABLED
            from .types import CompiledHookDescriptor

            actual_event = event_override or event
            actual_deme_selector: DemeSelector = deme if deme_selector_override is None else deme_selector_override
            if actual_event is None:
                raise ValueError(
                    f"Event not specified for hook '{func.__name__}'. "
                    "Specify in decorator @hook(event='...') or call pop.set_hook('event', hook)"
                )

            has_required_params = _has_required_parameters(func)
            is_declarative_pop_hook = _is_declarative_population_hook(func)
            # ``selectors is not None`` (not truthiness): an empty selector
            # dict still routes to the selector compiler.
            is_custom_or_selector = custom or selectors is not None or (has_required_params and not is_declarative_pop_hook)

            if is_custom_or_selector:
                if selectors is not None:
                    desc = compile_selector_hook(
                        func,
                        pop,
                        actual_event,
                        selectors,
                        priority,
                        deme_selector=actual_deme_selector,
                    )
                else:
                    if is_njit_function(func):
                        # Already njit-decorated
                        # NOTE(review): unlike the auto-wrapped branch below,
                        # this path does not run ``_normalize_njit_fn``;
                        # confirm that pre-decorated hooks already match the
                        # call signature the runtime uses.
                        norm_fn = func
                        desc = CompiledHookDescriptor(
                            name=func.__name__,
                            event=actual_event,
                            priority=priority,
                            deme_selector=actual_deme_selector,
                            njit_fn=norm_fn,
                            meta={"n_genotypes": pop.index_registry.num_genotypes(), "n_ages": pop.config.n_ages},
                        )
                    else:
                        # Try to use njit_switch (on-disk caching is disabled
                        # for user hooks via cache=False).
                        try:
                            decorated_func = njit_switch(cache=False)(func)
                            # Check if it's a valid compiled function
                            if NUMBA_ENABLED and is_njit_function(decorated_func):
                                norm_fn = _normalize_njit_fn(decorated_func)
                                desc = CompiledHookDescriptor(
                                    name=func.__name__,
                                    event=actual_event,
                                    priority=priority,
                                    deme_selector=actual_deme_selector,
                                    njit_fn=norm_fn,
                                    meta={"n_genotypes": pop.index_registry.num_genotypes(), "n_ages": pop.config.n_ages},
                                )
                            else:
                                # NUMBA_ENABLED is False, use py wrapper
                                wrapped_func = _normalize_py_hook(func)
                                desc = CompiledHookDescriptor(
                                    name=func.__name__,
                                    event=actual_event,
                                    priority=priority,
                                    deme_selector=actual_deme_selector,
                                    njit_fn=None,
                                    py_wrapper=wrapped_func,
                                    meta={"n_genotypes": pop.index_registry.num_genotypes(), "n_ages": pop.config.n_ages},
                                )
                        except Exception:
                            # Fall back to py wrapper: any failure while
                            # wrapping, normalizing or building the njit
                            # descriptor ends up on the pure-Python path.
                            wrapped_func = _normalize_py_hook(func)
                            desc = CompiledHookDescriptor(
                                name=func.__name__,
                                event=actual_event,
                                priority=priority,
                                deme_selector=actual_deme_selector,
                                njit_fn=None,
                                py_wrapper=wrapped_func,
                                meta={"n_genotypes": pop.index_registry.num_genotypes(), "n_ages": pop.config.n_ages},
                            )
            elif is_declarative_pop_hook:
                # Single population parameter - use as py_wrapper, but check numba enabled
                if NUMBA_ENABLED:
                    raise TypeError(
                        f"Python hook '{func.__name__}' is not allowed when Numba is enabled. "
                        "Please convert it to @njit or use declarative Op hooks."
                    )
                desc = CompiledHookDescriptor(
                    name=func.__name__,
                    event=actual_event,
                    priority=priority,
                    deme_selector=actual_deme_selector,
                    py_wrapper=func,
                    meta={"n_genotypes": pop.index_registry.num_genotypes(), "n_ages": pop.config.n_ages},
                )
            else:
                # Declarative mode: the function takes no required arguments
                # and builds the op list when called.
                result = func()
                if isinstance(result, list):
                    result_ops = cast(List[object], result)
                    if not all(isinstance(op, HookOp) for op in result_ops):
                        raise TypeError(
                            f"Declarative hook '{func.__name__}' must return List[HookOp], "
                            "or use custom=True for custom mode."
                        )
                    ops = cast(List[HookOp], result_ops)
                    desc = compile_declarative_hook(
                        ops,
                        pop,
                        actual_event,
                        priority,
                        deme_selector=actual_deme_selector,
                        name=func.__name__,
                    )
                else:
                    raise TypeError(
                        f"Hook '{func.__name__}' must return List[HookOp] for declarative mode, "
                        "or use custom=True for custom mode."
                    )

            # Remember the latest compilation on the function and register it.
            hook_func.compiled = desc  # type: ignore
            pop.register_compiled_hook(desc)
            return desc

        hook_func.register = register  # type: ignore
        return hook_func

    return decorator

compile_declarative_hook

compile_declarative_hook(ops: List[HookOp], pop: BasePopulation[Any], event: str, priority: int = 0, deme_selector: DemeSelector = '*', name: str = 'declarative_hook') -> CompiledHookDescriptor

Compile declarative ops into a CompiledHookDescriptor.

The compiler packs all per-op fields into parallel arrays. Offsets arrays (*_offsets) define CSR spans for variable-length selector/condition data and avoid Python object usage in runtime kernels.

Source code in src/natal/hooks/declarative.py
def compile_declarative_hook(
    ops: List[HookOp],
    pop: BasePopulation[Any],
    event: str,
    priority: int = 0,
    deme_selector: DemeSelector = "*",
    name: str = "declarative_hook",
) -> CompiledHookDescriptor:
    """Compile declarative ops into a ``CompiledHookDescriptor``.

    The compiler packs all per-op fields into parallel arrays. Offsets arrays
    (``*_offsets``) define CSR spans for variable-length selector/condition
    data and avoid Python object usage in runtime kernels: the entries for
    op ``i`` live in ``data[offsets[i]:offsets[i + 1]]``.

    Args:
        ops: Declarative operations, compiled in the given order.
        pop: Population whose ``index_registry`` and ``config.n_ages`` are
            used to resolve genotype and age selectors.
        event: Simulation event the hook runs on.
        priority: Execution priority (lower values run earlier).
        deme_selector: Target deme(s) for spatial populations.
        name: Human-readable descriptor name, used for debugging.

    Returns:
        A descriptor holding the packed ``CompiledHookPlan``, population
        metadata (``n_genotypes``, ``n_ages``) and the original ``ops``.
    """
    # Population registry/config used to resolve genotype and age selectors.
    index_registry = pop.index_registry
    diploid_genotypes = index_registry.index_to_genotype
    n_genotypes = index_registry.num_genotypes()
    n_ages = pop.config.n_ages

    # Per-op accumulators, packed into NumPy arrays after the loop.
    # Each ``*_offsets`` list starts at 0 and gains one end offset per op.
    op_types_list: List[int] = []

    gidx_offsets: List[int] = [0]
    gidx_data_list: List[int] = []

    age_offsets: List[int] = [0]
    age_data_list: List[int] = []

    # One two-flag sex mask per op. The executor reads flag 0 as "female"
    # (sex axis 0) and flag 1 as "male" (sex axis 1).
    # NOTE(review): confirm ``_resolve_sex`` returns flags in that order.
    sex_masks_list: List[NDArray[np.bool_]] = []
    params_list: List[float] = []

    # Condition token streams. Offsets are taken from the type stream;
    # presumably ``_parse_condition`` returns equal-length type/param
    # arrays — confirm.
    condition_offsets: List[int] = [0]
    condition_types_list: List[int] = []
    condition_params_list: List[int] = []

    for op in ops:
        # Operation code as a plain int for the runtime kernels.
        op_types_list.append(int(op.op_type))

        # Genotype selector -> concrete genotype indices.
        gidx_array = _resolve_genotypes(op.genotypes, index_registry, diploid_genotypes, n_genotypes)
        gidx_data_list.extend(gidx_array.tolist())
        gidx_offsets.append(len(gidx_data_list))

        # Age selector -> concrete age indices.
        age_array = _resolve_ages(op.ages, n_ages)
        age_data_list.extend(age_array.tolist())
        age_offsets.append(len(age_data_list))

        # Sex mask and the op's numeric parameter.
        sex_masks_list.append(_resolve_sex(op.sex))
        params_list.append(float(op.param))

        # Condition expression -> (type, param) token arrays.
        cond_types, cond_params = _parse_condition(op.condition)
        condition_types_list.extend(cond_types.tolist())
        condition_params_list.extend(cond_params.tolist())
        condition_offsets.append(len(condition_types_list))

    plan = CompiledHookPlan(
        n_ops=len(ops),
        op_types=np.array(op_types_list, dtype=np.int32),
        # Genotype span of op i: gidx_data[gidx_offsets[i]:gidx_offsets[i + 1]].
        gidx_offsets=np.array(gidx_offsets, dtype=np.int32),
        gidx_data=np.array(gidx_data_list, dtype=np.int32),
        # Age span of op i: age_data[age_offsets[i]:age_offsets[i + 1]].
        age_offsets=np.array(age_offsets, dtype=np.int32),
        age_data=np.array(age_data_list, dtype=np.int32),
        # np.vstack rejects an empty list, so zero ops needs an explicit
        # (0, 2) boolean array.
        sex_masks=np.vstack(sex_masks_list) if sex_masks_list else np.zeros((0, 2), dtype=np.bool_),
        params=np.array(params_list, dtype=np.float64),
        # Condition tokens of op i:
        # condition_*[condition_offsets[i]:condition_offsets[i + 1]].
        condition_offsets=np.array(condition_offsets, dtype=np.int32),
        condition_types=np.array(condition_types_list, dtype=np.int32),
        condition_params=np.array(condition_params_list, dtype=np.int32),
    )

    return CompiledHookDescriptor(
        name=name,
        event=event,
        priority=priority,  # lower values run earlier
        deme_selector=deme_selector,
        plan=plan,
        meta={"n_genotypes": n_genotypes, "n_ages": n_ages},
        ops=ops,  # original ops kept for reference/debugging
    )

build_hook_program

build_hook_program(program: HookProgram) -> HookProgram

Compatibility hook for future HookProgram validation/migration.

Source code in src/natal/hooks/executor.py
def build_hook_program(program: HookProgram) -> HookProgram:
    """Validate/migrate a ``HookProgram``; currently a pass-through.

    Reserved as a compatibility seam for future validation or migration.
    It performs no checks today and returns the very object it received.
    """
    migrated = program
    return migrated

deme_selector_matches

deme_selector_matches(selector: DemeSelector, deme_id: int) -> bool

Return whether one deme id should execute under selector.

Supported forms: "*" for all demes; an int for one deme; a list, tuple, or range for a set of demes.

Source code in src/natal/hooks/executor.py
def deme_selector_matches(selector: "DemeSelector", deme_id: int) -> bool:
    """Return whether one deme id should execute under ``selector``.

    Supported forms:
    - ``"*"`` for all demes
    - an integer (Python ``int`` or NumPy integer) for one deme
    - any container supporting ``in`` (list, tuple, range, NumPy array, ...)
      for a set of demes

    Args:
        selector: Deme selector as accepted by ``@hook(deme=...)``.
        deme_id: Index of the deme being executed.

    Returns:
        ``True`` if a hook with this selector should run for ``deme_id``.

    Raises:
        ValueError: If ``selector`` is a string other than ``"*"``.
    """
    # Strings first: comparing a NumPy-array selector to "*" would be an
    # element-wise comparison whose truth value is ambiguous.
    if isinstance(selector, str):
        if selector == "*":
            return True
        raise ValueError(
            f"Unsupported deme selector string {selector!r}; only '*' is allowed."
        )
    # np.integer is not a subclass of int; without this check a NumPy scalar
    # would fall through to the membership test and raise TypeError.
    if isinstance(selector, (int, np.integer)):
        return int(selector) == deme_id
    # list / tuple / range / array: plain membership.
    return deme_id in selector

execute_csr_event_arrays

execute_csr_event_arrays(n_events: int | integer[Any], n_hooks: int | integer[Any], hook_offsets: ndarray, n_ops_list: ndarray, op_offsets: ndarray, op_types_data: ndarray, gidx_offsets_data: ndarray, gidx_data: ndarray, age_offsets_data: ndarray, age_data: ndarray, sex_masks_data: ndarray, params_data: ndarray, condition_offsets_data: ndarray, condition_types_data: ndarray, condition_params_data: ndarray, deme_selector_types: ndarray, deme_selector_offsets: ndarray, deme_selector_data: ndarray, event_id: int, individual_count: ndarray, sperm_storage: ndarray, has_sperm_storage: bool, tick: int, is_stochastic: bool, use_continuous_sampling: bool, deme_id: int) -> int

Execute one event from flattened CSR arrays.

Inputs are plain arrays extracted from HookProgram. This function is the hottest part of declarative hook runtime.

Source code in src/natal/hooks/executor.py
@njit_switch(cache=True)
def execute_csr_event_arrays(
    n_events: int | np.integer[Any],
    n_hooks: int | np.integer[Any],
    hook_offsets: np.ndarray,
    n_ops_list: np.ndarray,
    op_offsets: np.ndarray,
    op_types_data: np.ndarray,
    gidx_offsets_data: np.ndarray,
    gidx_data: np.ndarray,
    age_offsets_data: np.ndarray,
    age_data: np.ndarray,
    sex_masks_data: np.ndarray,
    params_data: np.ndarray,
    condition_offsets_data: np.ndarray,
    condition_types_data: np.ndarray,
    condition_params_data: np.ndarray,
    deme_selector_types: np.ndarray,
    deme_selector_offsets: np.ndarray,
    deme_selector_data: np.ndarray,
    event_id: int,
    individual_count: np.ndarray,
    sperm_storage: np.ndarray,
    has_sperm_storage: bool,
    tick: int,
    is_stochastic: bool,
    use_continuous_sampling: bool,
    deme_id: int,
) -> int:
    """Execute one event from flattened CSR arrays.

    Inputs are plain arrays extracted from ``HookProgram``. This function is
    the hottest part of declarative hook runtime.

    Traversal is three-level CSR: ``hook_offsets`` maps the event to a span
    of hooks, ``op_offsets`` maps each hook to a span of ops, and the per-op
    ``*_offsets`` arrays map each op to its genotype, age and condition
    token spans. Hooks whose deme selector excludes ``deme_id`` and ops
    whose condition evaluates false for ``tick`` are skipped.

    Op codes 0-5 (scale, set, add, subtract, kill, sample) rewrite the
    selected cells of ``individual_count`` (indexed ``[sex, age, genotype]``)
    in place. Op codes 6-9 are STOP_IF_* checks that only read counts.
    ``n_ops_list`` is not read here.

    Returns:
        ``0`` for an out-of-range ``event_id``; ``RESULT_STOP`` as soon as a
        STOP_IF_* op fires; otherwise ``RESULT_CONTINUE``.
    """
    if event_id < 0 or event_id >= n_events:
        return 0

    # Event span -> hook span -> op span (three-level CSR traversal)
    hook_start = hook_offsets[event_id]
    hook_end = hook_offsets[event_id + 1]

    for hook_idx in range(hook_start, hook_end):
        if hook_idx < 0 or hook_idx >= n_hooks:
            continue

        # Filtering by deme_id using serialized selector data
        if not njit_deme_selector_matches(
            deme_selector_types[hook_idx],
            deme_selector_offsets[hook_idx],
            deme_selector_offsets[hook_idx + 1],
            deme_selector_data,
            deme_id,
        ):
            continue

        op_start = op_offsets[hook_idx]
        op_end = op_offsets[hook_idx + 1]

        for op_idx in range(op_start, op_end):
            cond_start = condition_offsets_data[op_idx]
            cond_end = condition_offsets_data[op_idx + 1]

            if not _eval_csr_condition_program(
                condition_types_data,
                condition_params_data,
                cond_start,
                cond_end,
                tick,
            ):
                continue

            op_type = op_types_data[op_idx]
            param = params_data[op_idx]

            gidx_start = gidx_offsets_data[op_idx]
            gidx_end = gidx_offsets_data[op_idx + 1]
            age_start = age_offsets_data[op_idx]
            age_end = age_offsets_data[op_idx + 1]

            # Two flags per op, flattened: [2*i] gates sex axis 0 (female),
            # [2*i + 1] gates sex axis 1 (male).
            sex_mask_idx = op_idx * 2
            sex_female = sex_masks_data[sex_mask_idx]
            sex_male = sex_masks_data[sex_mask_idx + 1]

            # Only mutating op codes walk the selected cells for writing.
            # STOP_IF_* codes (> 5) never write here, so skipping this walk
            # for them saves a full pass over the selection.
            if op_type <= 5:   # Op.scale, Op.set, Op.add, Op.subtract, Op.kill, Op.sample
                for sex_idx in range(2):
                    if sex_idx == 0 and not sex_female:
                        continue
                    if sex_idx == 1 and not sex_male:
                        continue

                    for age_idx_ptr in range(age_start, age_end):
                        age = age_data[age_idx_ptr]

                        for gidx_ptr in range(gidx_start, gidx_end):
                            gidx = gidx_data[gidx_ptr]
                            current = individual_count[sex_idx, age, gidx]

                            # Convert each operation to a target count first, then
                            # route through one unified update function so survival
                            # semantics are consistent across operators.
                            if op_type == 0:     # Op.scale
                                target = max(0.0, current * param)
                            elif op_type == 1:   # Op.set
                                target = max(0.0, param)
                            elif op_type == 2:   # Op.add
                                target = max(0.0, current + param)
                            elif op_type == 3:   # Op.subtract
                                target = max(0.0, current - param)
                            elif op_type == 4:   # Op.kill
                                target = max(0.0, current * (1.0 - param))
                            elif op_type == 5:   # Op.sample
                                target = min(current, max(0.0, param))
                            else:
                                target = current

                            # Sperm storage is only consulted for sex axis 0.
                            if sex_idx == 0 and has_sperm_storage:
                                individual_count[sex_idx, age, gidx] = _apply_target_with_sperm(
                                    current,
                                    target,
                                    sperm_storage[age, gidx, :],
                                    is_stochastic,
                                    use_continuous_sampling,
                                )
                            else:
                                individual_count[sex_idx, age, gidx] = _apply_target_without_sperm(
                                    current,
                                    target,
                                    is_stochastic,
                                    use_continuous_sampling,
                                )

            # STOP_IF_* operators aggregate selected cells and may short-circuit
            # event execution with RESULT_STOP.
            if op_type == 6 or op_type == 7 or op_type == 8:   # Op.stop_if_zero, Op.stop_if_below, Op.stop_if_above
                selected_total = 0.0
                for sex_idx in range(2):
                    if sex_idx == 0 and not sex_female:
                        continue
                    if sex_idx == 1 and not sex_male:
                        continue

                    for age_idx_ptr in range(age_start, age_end):
                        age = age_data[age_idx_ptr]
                        for gidx_ptr in range(gidx_start, gidx_end):
                            gidx = gidx_data[gidx_ptr]
                            selected_total += individual_count[sex_idx, age, gidx]

                if op_type == 6 and selected_total <= 0.0:
                    return RESULT_STOP
                if op_type == 7 and selected_total < param:
                    return RESULT_STOP
                if op_type == 8 and selected_total > param:
                    return RESULT_STOP
            elif op_type == 9:   # Op.stop_if_extinction
                # Checks the whole count array, ignoring the op's selectors.
                if individual_count.sum() <= 0.0:
                    return RESULT_STOP

    return RESULT_CONTINUE

execute_csr_event_program

execute_csr_event_program(program: HookProgram, event_id: int, individual_count: ndarray, tick: int) -> int

Compatibility wrapper with deterministic defaults and no sperm storage.

Source code in src/natal/hooks/executor.py
@njit_switch(cache=True)
def execute_csr_event_program(
    program: HookProgram,
    event_id: int,
    individual_count: np.ndarray,
    tick: int,
) -> int:
    """Run one event of ``program`` with the simplest runtime state.

    Compatibility shim over ``execute_csr_event_program_with_state`` that
    fixes the state flags: deterministic updates, no sperm storage (an empty
    ``(0, 0, 0)`` float64 buffer), no continuous sampling, and deme 0.

    Returns:
        The result code produced by ``execute_csr_event_program_with_state``.
    """
    empty_sperm_buffer = np.zeros((0, 0, 0), dtype=np.float64)
    stochastic = False
    sperm_enabled = False
    continuous_sampling = False
    target_deme = 0
    return execute_csr_event_program_with_state(
        program,
        event_id,
        individual_count,
        empty_sperm_buffer,
        tick,
        stochastic,
        sperm_enabled,
        continuous_sampling,
        target_deme,
    )

execute_csr_event_program_with_state

execute_csr_event_program_with_state(program: HookProgram, event_id: int, individual_count: ndarray, sperm_storage: ndarray, tick: int, is_stochastic: bool, has_sperm_storage: bool, use_continuous_sampling: bool, deme_id: int = 0) -> int

Execute event directly from HookProgram while exposing state flags.

Source code in src/natal/hooks/executor.py
@njit_switch(cache=True)
def execute_csr_event_program_with_state(
    program: HookProgram,
    event_id: int,
    individual_count: np.ndarray,
    sperm_storage: np.ndarray,
    tick: int,
    is_stochastic: bool,
    has_sperm_storage: bool,
    use_continuous_sampling: bool,
    deme_id: int = 0,
) -> int:
    """Execute event directly from ``HookProgram`` while exposing state flags.

    Unpacks the flattened array fields of ``program`` and forwards them,
    together with the runtime state, to ``execute_csr_event_arrays``.

    Args:
        program: Compiled hook program whose CSR arrays are executed.
        event_id: Event index; an out-of-range id makes the kernel return 0.
        individual_count: Counts indexed ``[sex, age, genotype]``; modified
            in place by mutating ops.
        sperm_storage: Indexed ``[age, genotype, :]``; consulted only when
            ``has_sperm_storage`` is true, and only for sex axis 0.
        tick: Current tick, used when evaluating op conditions.
        is_stochastic: Forwarded to the count-update helpers.
        has_sperm_storage: Whether ``sperm_storage`` holds real data.
        use_continuous_sampling: Forwarded to the count-update helpers.
        deme_id: Deme being executed; hooks whose deme selector excludes it
            are skipped.

    Returns:
        The kernel's result code.
    """
    return execute_csr_event_arrays(
        program.n_events,
        program.n_hooks,
        program.hook_offsets,
        program.n_ops_list,
        program.op_offsets,
        program.op_types_data,
        program.gidx_offsets_data,
        program.gidx_data,
        program.age_offsets_data,
        program.age_data,
        program.sex_masks_data,
        program.params_data,
        program.condition_offsets_data,
        program.condition_types_data,
        program.condition_params_data,
        program.deme_selector_types,
        program.deme_selector_offsets,
        program.deme_selector_data,
        event_id,
        individual_count,
        sperm_storage,
        # The kernel takes has_sperm_storage *before* tick, unlike this
        # wrapper's own parameter order; keep these positions as they are.
        has_sperm_storage,
        tick,
        is_stochastic,
        use_continuous_sampling,
        deme_id,
    )

compile_selector_hook

compile_selector_hook(func: Callable[..., Any], pop: BasePopulation[Any], event: str, selectors_spec: Dict[str, SelectorSpec], priority: int = 0, deme_selector: DemeSelector = '*') -> CompiledHookDescriptor

Compile selector hook into njit or python descriptor.

The resolved mapping (selector name to canonical selector array) is computed once and reused by both the Numba and the pure-Python execution paths.

For selector hooks, Numba compilation works as follows: if the function is already @njit-decorated, it is used directly; otherwise the global NUMBA_ENABLED setting decides whether it is wrapped automatically.

Source code in src/natal/hooks/selector.py
def compile_selector_hook(
    func: Callable[..., Any],
    pop: BasePopulation[Any],
    event: str,
    selectors_spec: Dict[str, SelectorSpec],
    priority: int = 0,
    deme_selector: DemeSelector = "*",
) -> CompiledHookDescriptor:
    """Compile selector hook into njit or python descriptor.

    ``resolved`` stores canonical selector arrays and is reused by both
    execution paths.

    For selector hooks, Numba compilation depends on:
    - If function is @njit decorated, use it directly
    - Otherwise, use global NUMBA_ENABLED setting (auto-wrap if enabled)

    Args:
        func: User hook function (plain Python or a Numba dispatcher).
        pop: Population used to resolve selectors and record metadata.
        event: Simulation event the hook runs on.
        selectors_spec: Mapping of selector parameter name to symbolic spec.
        priority: Execution priority (lower values run earlier).
        deme_selector: Target deme(s) for spatial populations.

    Returns:
        A descriptor with ``njit_fn`` set on the Numba path, or
        ``py_wrapper`` set on the pure-Python path.
    """
    # NOTE(review): the other hook compilers read ``pop.index_registry``;
    # this one reads ``pop.registry``. Presumably both expose the same
    # interface (``index_to_genotype``, ``num_genotypes()``) — confirm.
    index_registry = pop.registry
    diploid_genotypes = index_registry.index_to_genotype

    resolved = {
        name: _resolve_selector_to_array(spec, index_registry, diploid_genotypes)
        for name, spec in selectors_spec.items()
    }

    meta = {
        "n_genotypes": index_registry.num_genotypes(),
        "n_ages": pop.config.n_ages,
    }

    from ..numba_utils import NUMBA_ENABLED

    is_njit_fn = is_numba_dispatcher(func)

    if is_njit_fn or NUMBA_ENABLED:
        # Numba path: generate a thin wrapper with literal selector args.
        # The wrapper will call the user function (whether @njit or not)

        # The user function takes 2 or 3 leading positional args
        # (ind_count, tick[, deme_id]) before its selector parameters.
        # Parameters named after a selector are not counted. Otherwise
        # ``def f(ind_count, tick, target)`` with selector ``target`` would be
        # misdetected as accepting ``deme_id``.
        py_func = getattr(func, "py_func", func)
        sig = inspect.signature(py_func)
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        n_leading_positional = sum(
            1
            for param in sig.parameters.values()
            if param.kind in positional_kinds and param.name not in selectors_spec
        )
        has_deme_id = n_leading_positional >= 3

        njit_fn = _compile_selector_njit_wrapper(func, resolved, has_deme_id)
        return CompiledHookDescriptor(
            name=func.__name__,
            event=event,
            priority=priority,
            deme_selector=deme_selector,
            selectors=resolved,
            meta=meta,
            njit_fn=njit_fn,
        )

    # Python path: pass scalar for length-1 selectors, full array otherwise.
    # Keyword arguments are rebuilt on each call rather than hoisted, so
    # every call receives fresh values.
    def py_wrapper(population: BasePopulation[Any]) -> None:
        kwargs = _build_selector_python_kwargs(resolved)
        func(population, **kwargs)

    return CompiledHookDescriptor(
        name=func.__name__,
        event=event,
        priority=priority,
        deme_selector=deme_selector,
        selectors=resolved,
        meta=meta,
        py_wrapper=py_wrapper,
    )

njit_switch

njit_switch(func: Callable[P, R], *, cache: bool = True, parallel: bool = False, fastmath: bool = False, **njit_kwargs: Any) -> Callable[P, R]
njit_switch(func: None = None, *, cache: bool = True, parallel: bool = False, fastmath: bool = False, **njit_kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]
njit_switch(func: Optional[Callable[P, R]] = None, *, cache: bool = True, parallel: bool = False, fastmath: bool = False, **njit_kwargs: Any) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]

Numba @njit decorator for functions (controlled by global NUMBA_ENABLED flag).

Parameters:

Name Type Description Default
func Optional[Callable[P, R]]

Function to decorate

None
cache bool

Cache compiled functions (default: True)

True
parallel bool

Enable automatic parallelization (default: False)

False
fastmath bool

Enable fast math optimizations (default: False)

False
**njit_kwargs Any

Additional arguments for numba.njit

{}

Examples:

@njit_switch
def my_func(x):
    ...

@njit_switch(parallel=True, fastmath=True)
def my_parallel_func(x):
    ...
Source code in src/natal/numba_utils.py
def njit_switch(
    func: Optional[Callable[P, R]] = None,
    *,
    cache: bool = True,
    parallel: bool = False,
    fastmath: bool = False,
    **njit_kwargs: Any,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """
    Numba @njit decorator for functions (controlled by global NUMBA_ENABLED flag).

    When ``NUMBA_ENABLED`` is false, or Numba cannot be imported, the function
    is returned unchanged. If Numba is importable but configuring it or
    creating the njit dispatcher raises, a ``RuntimeWarning`` is emitted and
    the original function is returned, so decoration itself never fails.
    Errors raised later, when the returned dispatcher is first called, are not
    handled here.

    Args:
        func: Function to decorate
        cache: Cache compiled functions (default: True)
        parallel: Enable automatic parallelization (default: False)
        fastmath: Enable fast math optimizations (default: False)
        **njit_kwargs: Additional arguments for numba.njit

    Returns:
        The decorated function when ``func`` is given (``@njit_switch``), or a
        decorator when called with keyword arguments only
        (``@njit_switch(...)``).

    Examples:
        ```python
        @njit_switch
        def my_func(x):
            ...

        @njit_switch(parallel=True, fastmath=True)
        def my_parallel_func(x):
            ...
        ```
    """

    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        if not NUMBA_ENABLED:
            return fn

        try:
            from numba import njit  # pyright: ignore
            from numba.core import config as numba_config  # pyright: ignore
        except ImportError:
            # Numba is optional: without it, run the function as plain Python.
            return fn

        try:
            _apply_numba_cache_dir()

            # Keep all JIT/cache output under one project switch.
            config_obj: Any = numba_config
            setattr(config_obj, "DEBUG_CACHE", 1 if NUMBA_LOG_ENABLED else 0)  # noqa: B010
            _install_cache_log_formatter()
            _install_dispatcher_compile_formatter()

            return cast(
                Callable[P, R],
                njit(
                    fn,
                    cache=cache,
                    parallel=parallel,
                    fastmath=fastmath,
                    **njit_kwargs,
                ),
            )
        except Exception as exc:
            # Best-effort: fall back to the Python function, but make the
            # failure visible instead of silently losing JIT acceleration.
            import warnings

            name = getattr(fn, "__qualname__", repr(fn))
            warnings.warn(
                f"njit_switch: could not JIT-wrap {name!r}; "
                f"falling back to pure Python ({exc!r})",
                RuntimeWarning,
                stacklevel=2,
            )
            return fn

    # Support both @njit_switch and @njit_switch(...)
    if func is not None:
        return decorator(func)
    return decorator