Module Sarek_typed_ast

type tparam = {
  1. tparam_name : string;
  2. tparam_type : Sarek_types.typ;
  3. tparam_index : int;
  4. tparam_is_vec : bool;
    (*

    Is this a vector parameter?

    *)
  5. tparam_id : int;
    (*

    Variable ID for this parameter

    *)
}

Typed kernel parameter — defined before `texpr` so that `TELetRec` can reference it.

type texpr = {
  1. te : texpr_desc;
  2. ty : Sarek_types.typ;
    (*

    Always resolved, never contains unbound TVar

    *)
  3. te_loc : Sarek_ast.loc;
}

Typed expression - every node has its type

and texpr_desc =
  1. | TEUnit
  2. | TEBool of bool
  3. | TEInt of int
  4. | TEInt32 of int32
  5. | TEInt64 of int64
  6. | TEFloat of float
  7. | TEDouble of float
  8. | TEVar of string * int
    (*

    name, variable id

    *)
  9. | TEVecGet of texpr * texpr
  10. | TEVecSet of texpr * texpr * texpr
  11. | TEArrGet of texpr * texpr
  12. | TEArrSet of texpr * texpr * texpr
  13. | TEFieldGet of texpr * string * int
    (*

    expr, field name, field index

    *)
  14. | TEFieldSet of texpr * string * int * texpr
  15. | TEBinop of Sarek_ast.binop * texpr * texpr
  16. | TEUnop of Sarek_ast.unop * texpr
  17. | TEApp of texpr * texpr list
  18. | TEAssign of string * int * texpr
    (*

    name, var_id, value

    *)
  19. | TELet of string * int * texpr * texpr
    (*

    name, var_id, value, body

    *)
  20. | TELetRec of string * int * tparam list * texpr * texpr
    (*

    name, fn_id, params, fn_body, continuation

    *)
  21. | TELetMut of string * int * texpr * texpr
  22. | TEIf of texpr * texpr * texpr option
  23. | TEFor of string * int * texpr * texpr * Sarek_ast.for_dir * texpr
    (*

    var, id, lo, hi, dir, body

    *)
  24. | TEWhile of texpr * texpr
  25. | TESeq of texpr list
    (*

    Flattened sequence

    *)
  26. | TEMatch of texpr * (tpattern * texpr) list
  27. | TERecord of string * (string * texpr) list
    (*

    type_name, fields

    *)
  28. | TEConstr of string * string * texpr option
    (*

    type_name, constr_name, arg

    *)
  29. | TETuple of texpr list
  30. | TEReturn of texpr
  31. | TECreateArray of texpr * Sarek_types.typ * Sarek_types.memspace
  32. | TEGlobalRef of string * Sarek_types.typ
    (*

    External ref with its type

    *)
  33. | TENative of {
    1. gpu : Ppxlib.expression;
      (*

      fun dev -> "cuda/opencl code"

      *)
    2. ocaml : Ppxlib.expression;
      (*

      OCaml fallback for interpreter/native

      *)
    }
    (*

    Native code with GPU and OCaml expressions

    *)
  34. | TEPragma of string list * texpr
    (*

    pragma options, body

    *)
  35. | TEIntrinsicConst of Sarek_env.intrinsic_ref
    (*

    Reference to intrinsic constant

    *)
  36. | TEIntrinsicFun of Sarek_env.intrinsic_ref * Sarek_core_primitives.convergence option * texpr list
    (*

    Reference to intrinsic function, convergence requirement, args

    *)
  37. | TELetShared of string * int * Sarek_types.typ * texpr option * texpr
    (*

    name, var_id, elem_type, size_opt, body

    *)
  38. | TESuperstep of string * bool * texpr * texpr
    (*

    name, divergent, body, continuation

    *)
  39. | TEOpen of string list * texpr
    (*

    module path, body (as in `let open M.N in e`)

    *)
and tpattern = {
  1. tpat : tpattern_desc;
  2. tpat_ty : Sarek_types.typ;
  3. tpat_loc : Sarek_ast.loc;
}
and tpattern_desc =
  1. | TPAny
  2. | TPVar of string * int
    (*

    name, var_id

    *)
  3. | TPConstr of string * string * tpattern option
    (*

    type_name, constr, arg

    *)
  4. | TPTuple of tpattern list
type ttype_decl =
  1. | TTypeRecord of {
    1. tdecl_name : string;
    2. tdecl_module : string option;
    3. tdecl_fields : (string * Sarek_types.typ * bool) list;
      (*

      name, type, mutable

      *)
    4. tdecl_loc : Sarek_ast.loc;
    }
  2. | TTypeVariant of {
    1. tdecl_name : string;
    2. tdecl_module : string option;
    3. tdecl_constructors : (string * Sarek_types.typ option) list;
    4. tdecl_loc : Sarek_ast.loc;
    }

Typed type declarations

type tmodule_item =
  1. | TMConst of string * int * Sarek_types.typ * texpr
    (*

    name, var_id, type, value (as in `let name : ty = expr`)

    *)
  2. | TMFun of string * bool * tparam list * texpr
    (*

    TMFun(name, is_recursive, params, body)

    *)

Typed module item

type tkernel = {
  1. tkern_name : string option;
  2. tkern_type_decls : ttype_decl list;
  3. tkern_module_items : tmodule_item list;
  4. tkern_external_item_count : int;
    (*

    Number of module items that are external (from @sarek.module). The first N items in tkern_module_items are external; the rest are inline (defined within the kernel payload). Native code generation should skip external items because they are already available via the OCaml module system.

    *)
  5. tkern_params : tparam list;
  6. tkern_body : texpr;
  7. tkern_return_type : Sarek_types.typ;
  8. tkern_loc : Sarek_ast.loc;
}

Typed kernel definition

type tdecl_item =
  1. | TDType of ttype_decl
  2. | TDFun of string * tparam list * texpr
  3. | TDConst of string * int * Sarek_types.typ * texpr
val var_id_counter : int Stdlib.Atomic.t

Variable ID generator (thread-safe)

val fresh_var_id : unit -> int
val reset_var_id_counter : unit -> unit

Create a simple typed expression

val resolve_type : Sarek_types.typ -> Sarek_types.typ

Get the fully resolved type (follow all links)

val pp_texpr : Stdlib.Format.formatter -> texpr -> unit

Pretty printing

val pp_tpattern : Stdlib.Format.formatter -> tpattern -> unit
val pp_tparam : Stdlib.Format.formatter -> tparam -> unit
val pp_tkernel : Stdlib.Format.formatter -> tkernel -> unit