Polymorphic Matrix Transpose
One of Sarek’s most powerful features is polymorphism. In C or CUDA you often end up writing separate kernels for float, double, and int; in Sarek you write one generic function and reuse it across element types.
This example demonstrates a generic transpose function that works on int32, float32, and even custom record types.
1. Generic Helper Function
First, we define a generic helper function using [@sarek.module]. The attribute marks the function as GPU code, but Sarek compiles it only when a kernel calls it.
open Sarek

(* Generic transpose logic that works for any element type 'a.
   width, height and tid are int32 to match the kernels that call it. *)
let[@sarek.module] do_transpose (input : 'a vector) (output : 'a vector)
    (width : int32) (height : int32) (tid : int32) =
  let n = width * height in
  (* Guard: threads past the last element do nothing *)
  if tid < n then begin
    let col = tid mod width in
    let row = tid / width in
    let in_idx = (row * width) + col in
    let out_idx = (col * height) + row in
    (* Reads and writes 'a - works for floats, ints, structs... *)
    output.(out_idx) <- input.(in_idx)
  end
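The index arithmetic can be checked on the CPU before anything is launched. The following is a minimal plain-OCaml sketch, not Sarek code: transpose_ref is a hypothetical helper name that applies the same mapping to an ordinary 'a array, looping over every tid the way the GPU threads would.

```ocaml
(* Plain OCaml reference, not Sarek code: applies the same index mapping
   as do_transpose to an ordinary array. transpose_ref is a hypothetical
   helper name introduced for this sketch. *)
let transpose_ref (input : 'a array) (width : int) (height : int) : 'a array =
  let n = width * height in
  if n = 0 then [||]
  else begin
    let output = Array.make n input.(0) in
    for tid = 0 to n - 1 do
      let col = tid mod width in
      let row = tid / width in
      (* Element (row, col) of the input becomes element (col, row). *)
      output.((col * height) + row) <- input.((row * width) + col)
    done;
    output
  end

(* A 2-row by 3-column matrix stored row-major: [1 2 3; 4 5 6]. *)
let ints = transpose_ref [| 1; 2; 3; 4; 5; 6 |] 3 2
(* ints = [| 1; 4; 2; 5; 3; 6 |], i.e. 3 rows by 2 columns. *)

(* The same function works unchanged on floats. *)
let floats = transpose_ref [| 1.0; 2.0; 3.0; 4.0 |] 2 2
```

Because the function is polymorphic in 'a, one reference covers every element type, which mirrors how the Sarek helper is reused by each kernel.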
2. Defining Concrete Kernels
We can now define concrete kernels that “monomorphize” (specialize) the generic helper for particular element types.
For Basic Types (float32, int32)
(* Kernel specialized for Float32 *)
let%kernel transpose_float32 (input : float32 vector) (output : float32 vector)
    (width : int32) (height : int32) =
  let tid = get_global_id 0 in
  do_transpose input output width height tid

(* Kernel specialized for Int32 *)
let%kernel transpose_int32 (input : int32 vector) (output : int32 vector)
    (width : int32) (height : int32) =
  let tid = get_global_id 0 in
  do_transpose input output width height tid
For Custom Records (structs)
Sarek also supports custom record types. We define the type with [@@sarek.type] and can immediately use our generic transpose logic on it.
(* Define a custom GPU-compatible record *)
type point3d = {
  x : float32;
  y : float32;
  z : float32
} [@@sarek.type]

(* Kernel specialized for Point3D records *)
let%kernel transpose_point3d (input : point3d vector) (output : point3d vector)
    (width : int32) (height : int32) =
  let tid = get_global_id 0 in
  (* Sarek automatically handles the structure layout and memory copying *)
  do_transpose input output width height tid
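To make "structure layout" concrete, here is a plain-OCaml sketch of one plausible encoding: records packed back to back, each field stored as a 4-byte little-endian float32, giving 12 bytes per element. This is an assumption for illustration only; the text above does not specify the layout Sarek actually generates. The type point and the helpers write_point and read_point are hypothetical names, not part of Sarek.

```ocaml
(* Plain OCaml sketch, not Sarek code. Assumes a packed layout of three
   little-endian float32 fields per record; the real layout may differ. *)
type point = { x : float; y : float; z : float }

let bytes_per_point = 12 (* 3 fields * 4 bytes each *)

(* Store record p as element i of the buffer. *)
let write_point (buf : Bytes.t) (i : int) (p : point) : unit =
  let base = i * bytes_per_point in
  Bytes.set_int32_le buf base (Int32.bits_of_float p.x);
  Bytes.set_int32_le buf (base + 4) (Int32.bits_of_float p.y);
  Bytes.set_int32_le buf (base + 8) (Int32.bits_of_float p.z)

(* Load element i of the buffer back into a record. *)
let read_point (buf : Bytes.t) (i : int) : point =
  let base = i * bytes_per_point in
  let field off = Int32.float_of_bits (Bytes.get_int32_le buf (base + off)) in
  { x = field 0; y = field 4; z = field 8 }

(* Two elements occupy 24 bytes; element 1 starts at byte offset 12. *)
let buf = Bytes.create (2 * bytes_per_point)
let () = write_point buf 0 { x = 0.0; y = 0.0; z = 0.0 }
let () = write_point buf 1 { x = 1.0; y = 2.0; z = 3.0 }
let p = read_point buf 1
```

Under this assumed layout, copying one point3d element is a 12-byte move, which is the work the generic helper's single assignment stands for when 'a is a record.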
3. Host Code
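The host code looks standard for the float32 case. The custom point3d vector is the part to notice: it is created with Vector.create_custom, which takes a module giving the element type and its size.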
let run_polymorphic_tests () =
  let width, height = 1024, 1024 in
  let n = width * height in
  let device = Device.get_default () in
  let block = (256, 1, 1) in
  (* Round up so that every element gets a thread *)
  let grid = ((n + 255) / 256, 1, 1) in

  (* 1. Run Float32 Transpose *)
  let a = Vector.create Float32 n in
  let b = Vector.create Float32 n in
  (* ... init a ... *)
  Execute.run transpose_float32 ~device ~grid ~block
    [Vec a; Vec b; Int32 width; Int32 height];

  (* 2. Run Custom Struct Transpose *)
  (* Create a vector for our custom type.
     size = 12: three 4-byte float32 fields *)
  let points_in = Vector.create_custom
    (module struct type t = point3d let size = 12 end) n in
  let points_out = Vector.create_custom
    (module struct type t = point3d let size = 12 end) n in

  (* Initialize with OCaml records *)
  Vector.set points_in 0 { x=1.0; y=2.0; z=3.0 };

  (* Run the same logic on structs! *)
  Execute.run transpose_point3d ~device ~grid ~block
    [Vec points_in; Vec points_out; Int32 width; Int32 height]
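The grid size in the host code is a ceiling division: enough 256-thread blocks to cover all n elements. A small plain-OCaml sketch of that arithmetic follows; blocks_for is a hypothetical helper name, not part of Sarek. When n is not a multiple of the block size, the last block contains surplus threads, which is why do_transpose guards its body with tid < n.

```ocaml
(* Plain OCaml sketch of the launch arithmetic; blocks_for is a
   hypothetical helper, not part of the Sarek API. *)
let blocks_for ~(block_size : int) (n : int) : int =
  (* Integer ceiling of n / block_size *)
  (n + block_size - 1) / block_size

(* 1024 * 1024 elements divide evenly into 256-thread blocks. *)
let exact = blocks_for ~block_size:256 (1024 * 1024)   (* 4096 *)

(* 1000 elements need 4 blocks, so 1024 threads are launched... *)
let padded = blocks_for ~block_size:256 1000           (* 4 *)

(* ...and the surplus threads are the ones the tid < n guard rejects. *)
let surplus = (padded * 256) - 1000                    (* 24 *)
```

For the 1024 by 1024 matrix used here the division is exact, so the guard only matters once the dimensions stop being multiples of the block size.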
Why this matters
In OpenCL C, supporting point3d would mean writing another kernel, for example transpose_point3d_kernel, and handling the struct fields by hand; in CUDA the alternative is a C++ template. In Sarek, one helper covers all three cases: the compiler specializes it for each element type and takes care of the structure layout and the copying of each element.