Module Sarek_ir_types

Sarek_ir_types - Pure type definitions for GPU kernel IR

This module contains type definitions, plus a few typed helper functions for NA_Vec arguments, with no external dependencies. It is used by spoc_framework for the typed generate_source signature.

type memspace =
  1. | Global
  2. | Shared
  3. | Local

Memory spaces

type elttype =
  1. | TInt32
  2. | TInt64
  3. | TFloat32
  4. | TFloat64
  5. | TBool
  6. | TUnit
  7. | TRecord of string * (string * elttype) list
    (*

    Record type: name and field list

    *)
  8. | TVariant of string * (string * elttype list) list
    (*

    Variant type: name and constructor list with arg types

    *)
  9. | TArray of elttype * memspace
    (*

    Array type with element type and memory space

    *)
  10. | TVec of elttype
    (*

    Vector (GPU array parameter)

    *)

Element types

type var = {
  1. var_name : string;
  2. var_id : int;
  3. var_type : elttype;
  4. var_mutable : bool;
}

Variables with type info

type const =
  1. | CInt32 of int32
  2. | CInt64 of int64
  3. | CFloat32 of float
  4. | CFloat64 of float
  5. | CBool of bool
  6. | CUnit

Constants

type binop =
  1. | Add
  2. | Sub
  3. | Mul
  4. | Div
  5. | Mod
  6. | Eq
  7. | Ne
  8. | Lt
  9. | Le
  10. | Gt
  11. | Ge
  12. | And
  13. | Or
  14. | Shl
  15. | Shr
  16. | BitAnd
  17. | BitOr
  18. | BitXor

Binary operators

type unop =
  1. | Neg
  2. | Not
  3. | BitNot

Unary operators

type for_dir =
  1. | Upto
  2. | Downto

Loop direction

type pattern =
  1. | PConstr of string * string list
  2. | PWild

Match pattern

type expr =
  1. | EConst of const
  2. | EVar of var
  3. | EBinop of binop * expr * expr
  4. | EUnop of unop * expr
  5. | EArrayRead of string * expr
    (*

    arr[idx]

    *)
  6. | EArrayReadExpr of expr * expr
    (*

    base_expr[idx] for complex bases

    *)
  7. | ERecordField of expr * string
    (*

    r.field

    *)
  8. | EIntrinsic of string list * string * expr list
    (*

    module path, name, args

    *)
  9. | ECast of elttype * expr
  10. | ETuple of expr list
  11. | EApp of expr * expr list
  12. | ERecord of string * (string * expr) list
    (*

    Record construction: type name, field values

    *)
  13. | EVariant of string * string * expr list
    (*

    Variant construction: type name, constructor, args

    *)
  14. | EArrayLen of string
    (*

    Array length intrinsic

    *)
  15. | EArrayCreate of elttype * expr * memspace
    (*

    elem type, size, memspace

    *)
  16. | EIf of expr * expr * expr
    (*

    condition, then, else - value-returning if

    *)
  17. | EMatch of expr * (pattern * expr) list
    (*

    scrutinee, cases - value-returning match

    *)

Expressions (pure, no side effects)

type lvalue =
  1. | LVar of var
  2. | LArrayElem of string * expr
  3. | LArrayElemExpr of expr * expr
  4. | LRecordField of lvalue * string

L-values (assignable locations)

type stmt =
  1. | SAssign of lvalue * expr
  2. | SSeq of stmt list
  3. | SIf of expr * stmt * stmt option
  4. | SWhile of expr * stmt
  5. | SFor of var * expr * expr * for_dir * stmt
  6. | SMatch of expr * (pattern * stmt) list
  7. | SReturn of expr
  8. | SBarrier
    (*

    Block-level barrier (__syncthreads)

    *)
  9. | SWarpBarrier
    (*

    Warp-level sync (__syncwarp)

    *)
  10. | SExpr of expr
    (*

    Side-effecting expression

    *)
  11. | SEmpty
  12. | SLet of var * expr * stmt
    (*

    Let binding: let v = e in body

    *)
  13. | SLetMut of var * expr * stmt
    (*

    Mutable let: let v = ref e in body

    *)
  14. | SPragma of string list * stmt
    (*

    Pragma hints wrapping a statement

    *)
  15. | SMemFence
    (*

    Memory fence (threadfence)

    *)
  16. | SBlock of stmt
    (*

    Scoped block - creates a C scope for variable isolation

    *)
  17. | SNative of {
    1. gpu : framework:string -> string;
      (*

      Generate GPU code for framework

      *)
    2. ocaml : ocaml_closure;
      (*

      Typed OCaml fallback

      *)
    }
    (*

    Inline native GPU code with OCaml fallback

    *)

Statements (imperative, side effects)

and decl =
  1. | DParam of var * array_info option
  2. | DLocal of var * expr option
  3. | DShared of string * elttype * expr option

Declarations

and array_info = {
  1. arr_elttype : elttype;
  2. arr_memspace : memspace;
}
and helper_func = {
  1. hf_name : string;
  2. hf_params : var list;
  3. hf_ret_type : elttype;
  4. hf_body : stmt;
}

Helper function (device function called from kernel)

and native_arg =
  1. | NA_Int32 of int32
  2. | NA_Int64 of int64
  3. | NA_Float32 of float
  4. | NA_Float64 of float
  5. | NA_Vec of {
    1. length : int;
    2. elem_size : int;
    3. type_name : string;
    4. get_f32 : int -> float;
    5. set_f32 : int -> float -> unit;
    6. get_f64 : int -> float;
    7. set_f64 : int -> float -> unit;
    8. get_i32 : int -> int32;
    9. set_i32 : int -> int32 -> unit;
    10. get_i64 : int -> int64;
    11. set_i64 : int -> int64 -> unit;
    12. get_any : int -> Stdlib.Obj.t;
    13. set_any : int -> Stdlib.Obj.t -> unit;
    14. get_vec : unit -> Stdlib.Obj.t;
    }

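Native argument type for kernel execution. Scalar arguments and the per-type NA_Vec accessors (get_f32, set_f32, get_f64, set_f64, get_i32, set_i32, get_i64, set_i64) are typed without Obj.t; only get_any, set_any and get_vec expose Obj.t. Used by PPX-generated native functions.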

and ocaml_closure = {
  1. run : block:(int * int * int) -> grid:(int * int * int) -> native_arg array -> unit;
}

Typed Helpers for Custom Types

These functions encapsulate the Obj operations so that PPX-generated code doesn't need to use Obj directly. The type parameter is inferred from context.

val vec_get_custom : 'a. native_arg -> int -> 'a

Get element from NA_Vec as custom type. Type is inferred from usage.

val vec_set_custom : 'a. native_arg -> int -> 'a -> unit

Set element in NA_Vec from custom type. Type is inferred from usage.

val vec_length : native_arg -> int

Get length from NA_Vec

val vec_as_vector : 'a. native_arg -> 'a

Get the underlying vector. Used when passing vectors to functions/intrinsics that need the actual Vector.t type. Returns a type-erased value that the caller casts to the appropriate Vector.t type.

type native_fn_t =
  1. | NativeFn of parallel:bool -> block:(int * int * int) -> grid:(int * int * int) -> native_arg array -> unit

Native function type for V2 execution. Uses typed native_arg.

type kernel = {
  1. kern_name : string;
  2. kern_params : decl list;
  3. kern_locals : decl list;
  4. kern_body : stmt;
  5. kern_types : (string * (string * elttype) list) list;
    (*

    Record type definitions: (type_name, [(field_name, field_type); ...])

    *)
  6. kern_variants : (string * (string * elttype list) list) list;
    (*

    Variant type definitions: (type_name, [(constructor_name, payload_types); ...])

    *)
  7. kern_funcs : helper_func list;
    (*

    Helper functions defined in kernel scope

    *)
  8. kern_native_fn : native_fn_t option;
    (*

    Optional pre-compiled native function for CPU execution

    *)
}

Kernel representation