Module Sarek_ppx_lib.Sarek_core_primitives

type variance =
  1. | Uniform
    (*

    Same value for all threads in the grid

    *)
  2. | BlockVarying
    (*

    Uniform within a block; may vary between blocks

    *)
  3. | WarpVarying
    (*

    Uniform within a warp; may vary between warps

    *)
  4. | ThreadVarying
    (*

    Varies per thread

    *)

Variance levels in the GPU execution model. They form a lattice ordered as: Uniform ≤ BlockVarying ≤ WarpVarying ≤ ThreadVarying.

type convergence =
  1. | NoEffect
    (*

    Does not affect convergence

    *)
  2. | ConvergencePoint
    (*

    All threads in the workgroup must reach this point together

    *)
  3. | WarpConvergence
    (*

    All threads in the warp must reach this point together

    *)

Convergence requirements for synchronization primitives.

type purity =
  1. | Pure
    (*

    No side effects; eligible for common-subexpression elimination (CSE) and dead-code elimination (DCE)

    *)
  2. | Impure
    (*

    Has side effects

    *)
  3. | Atomic
    (*

    Atomic memory operation

    *)

Purity classification used for optimization.

type primitive = {
  1. name : string;
  2. typ : Sarek_types.typ;
  3. variance : variance;
  4. convergence : convergence;
  5. purity : purity;
  6. category : string;
    (*

    For documentation/grouping

    *)
}

A core primitive definition with compile-time semantics

val primitives : primitive list

All registered core primitives

val find : string -> primitive option

Look up a primitive by name.

val find_exn : string -> primitive
val is_core_primitive : string -> bool

Predicates

val is_thread_varying : string -> bool
val is_warp_varying : string -> bool

Checks whether a primitive has warp-level or finer variance (WarpVarying or ThreadVarying).

val is_convergence_point : string -> bool
val is_warp_convergence_point : string -> bool

Checks whether a primitive requires warp-level convergence.

val requires_convergence : string -> bool

Checks whether a primitive requires any convergence (block or warp level).

val is_pure : string -> bool
val is_atomic : string -> bool

Checks whether a primitive is an atomic memory operation.

val variance_of : string -> variance option

Gets the variance of a named primitive.

val join_variance : variance -> variance -> variance

Variance lattice operations

val variance_leq : variance -> variance -> bool
val primitives_in_category : string -> primitive list

Gets all primitives in a category.

val pp_variance : Stdlib.Format.formatter -> variance -> unit

Pretty-printing

val pp_convergence : Stdlib.Format.formatter -> convergence -> unit
val pp_purity : Stdlib.Format.formatter -> purity -> unit
val pp_primitive : Stdlib.Format.formatter -> primitive -> unit