Module Sarek_ppx_lib.Sarek_convergence

type exec_mode =
  1. | Converged
    (*

    All threads in workgroup execute this code

    *)
  2. | Diverged
    (*

    Threads may have taken different branches

    *)

Execution mode for convergence analysis

type ctx = {
  1. mode : exec_mode;
}

Analysis context

val init_ctx : ctx

Initial context: analysis starts in converged mode

val diverge : 'a -> ctx

Enter diverged mode

val is_thread_varying_intrinsic : Sarek_env.intrinsic_ref -> bool

Check if an intrinsic ref is thread-varying using core primitives

val is_barrier_ref : Sarek_env.intrinsic_ref -> bool

Check if an intrinsic ref is a barrier/convergence point

val is_thread_varying : Sarek_typed_ast.texpr -> bool

Check if an expression's value varies per-thread.

Thread-varying (different per thread):

  • thread_idx_x, thread_idx_y, thread_idx_z
  • global_thread_id
  • Expressions derived from above
  • Array accesses with thread-varying index

Uniform (same for all threads in workgroup):

  • Constants and literals
  • block_idx_x/y/z (same for whole block)
  • block_dim_x/y/z, grid_dim_x/y/z
  • Expressions derived only from uniform values

val is_barrier_intrinsic : Sarek_env.intrinsic_ref -> bool

Check if an intrinsic reference is a barrier

val is_warp_convergence_ref : Sarek_env.intrinsic_ref -> bool

Check if an intrinsic ref requires warp convergence

val check_expr : ctx -> Sarek_typed_ast.texpr -> Sarek_error.error list

Collect errors from convergence analysis

val contains_diverging_control_flow : Sarek_typed_ast.texpr -> bool

Check if an expression contains any control flow with thread-varying conditions. This is used for superstep analysis: the implicit barrier at the end of a superstep requires that no divergence occurs within the body.

val check_module_item : ctx -> Sarek_typed_ast.tmodule_item -> Sarek_error.error list

Check a module item

val check_kernel : Sarek_typed_ast.tkernel -> (unit, Sarek_error.error list) Stdlib.result

Check a kernel for convergence safety

val expr_uses_barriers : Sarek_typed_ast.texpr -> bool

Check if an expression uses any barriers (explicit or implicit). Used for compile-time optimization of native kernel execution.

val kernel_uses_barriers : Sarek_typed_ast.tkernel -> bool

Check if a kernel uses barriers

Kernel Dimensionality Analysis

Detect which thread/block dimensions a kernel uses to enable optimized native CPU execution. Simple kernels that only use global_idx_x can be run with a simple parallel for loop without the thread_state overhead.

type dim_usage = {
  1. uses_x : bool;
  2. uses_y : bool;
  3. uses_z : bool;
  4. uses_block_dim : bool;
    (*

    Uses block_dim_x/y/z

    *)
  5. uses_grid_dim : bool;
    (*

    Uses grid_dim_x/y/z

    *)
  6. uses_thread_idx : bool;
    (*

    Uses thread_idx_x/y/z directly

    *)
  7. uses_block_idx : bool;
    (*

    Uses block_idx_x/y/z directly

    *)
  8. uses_shared_mem : bool;
    (*

    Uses shared memory

    *)
}

Dimension usage record

val empty_dim_usage : dim_usage
val merge_dim_usage : dim_usage -> dim_usage -> dim_usage
val dim_usage_of_name : string -> dim_usage

Get dimension usage from an intrinsic name

val dim_usage_of_intrinsic_ref : Sarek_env.intrinsic_ref -> dim_usage

Get dimension usage from an intrinsic reference

val expr_dim_usage : Sarek_typed_ast.texpr -> dim_usage

Analyze expression for dimension usage

type exec_strategy =
  1. | Simple1D
    (*

    Only uses global_idx_x - can use simple parallel for

    *)
  2. | Simple2D
    (*

    Only uses global_idx_x/y - can use nested loops

    *)
  3. | Simple3D
    (*

    Uses all three dimensions with simple loops

    *)
  4. | FullState
    (*

    Uses block/thread indices or shared memory - needs full state

    *)

Kernel execution strategy based on dimension analysis

val kernel_exec_strategy : Sarek_typed_ast.tkernel -> exec_strategy

Determine the optimal execution strategy for a kernel