
CuTe DSL Basics: A Practical Introduction

CuTe DSL Basics — From Hello to Tiled Kernels

This tutorial turns a collection of CuTe DSL snippets into a connected story: we start with a first GPU kernel, learn how dynamic printing and data types work, build and slice tensors, then graduate to vectorized kernels with tiling, thread–value layouts, and finally CUDA Graphs. Code blocks are ready to run as-is.

1) First kernel and launch (hello world)

  • Define a GPU kernel with @cute.kernel; get thread index via cute.arch.thread_idx().
  • Launch kernels from a @cute.jit host function using .launch(grid=..., block=...).
  • Initialize CUDA with cutlass.cuda.initialize_cuda_context() (for clarity/control).
  • Compile-ahead and re-use with cute.compile.

import cutlass
import cutlass.cute as cute

@cute.kernel
def hello_kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Hello from GPU")

@cute.jit
def hello_world():
    cutlass.cuda.initialize_cuda_context()
    hello_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1))

compiled = cute.compile(hello_world)
compiled()
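
For comparison, a @cute.jit function can also be invoked directly, as most later examples in this post do; it is then JIT-compiled on first call, while cute.compile makes the compile step explicit so the result can be reused. A minimal sketch:

# Direct call: traces and JIT-compiles on first use (most later examples use this form).
hello_world()

# The compiled object from cute.compile above can be reused without re-tracing.
for _ in range(3):
    compiled()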

Next we’ll see how values show up in logs, and why cute.printf is the right tool inside JIT code.

2) Printing and numeric types: static vs dynamic

  • Python print is static (compile-time) and shows ? for dynamic values.
  • cute.printf prints dynamic (runtime) values for CuTe types/layouts.
  • Device-side printing is supported inside @cute.kernel/@cute.jit.

import cutlass
import cutlass.cute as cute

@cute.jit
def print_demo(a: cutlass.Int32, b: cutlass.Constexpr[int]):
    print("static a:", a)   # => ? (dynamic)
    print("static b:", b)   # => 2
    cute.printf("dynamic a: {}", a)
    cute.printf("dynamic b: {}", b)
    layout = cute.make_layout((a, b))
    print("static layout:", layout)       # (?,2):(1,?)
    cute.printf("dynamic layout: {}", layout)  # (8,2):(1,8)

print_demo(cutlass.Int32(8), 2)

Numeric types in JIT code should be explicit and can be converted at runtime.

import cutlass
import cutlass.cute as cute

@cute.jit
def dtypes():
    a = cutlass.Int32(42)
    b = a.to(cutlass.Float32)
    c = b + 0.5
    d = c.to(cutlass.Int32)
    cute.printf("a={}, b={}, c={}, d={}", a, b, c, d)

dtypes()

With printing and types in hand, let’s talk about tensors and their layouts.

3) Tensors and layouts: build, index, slice

  • Tensor = Engine ∘ Layout: build with cute.make_tensor(ptr_or_engine, layout).
  • Create layouts with cute.make_layout(shape, stride=...).
  • Interop with frameworks via DLPack; print via cute.print_tensor.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def tensor_demo(t: cute.Tensor):
    cute.printf("t[0,0] = {}", t[0, 0])
    sub = t[(0, None)]   # First row view: fix the row index (mode-0), keep the columns (mode-1)
    frag = cute.make_fragment(sub.layout, sub.element_type)
    frag.store(sub.load())
    cute.print_tensor(frag)

arr = torch.arange(0, 12, dtype=torch.float32).reshape(3, 4)
tensor_demo(from_dlpack(arr))
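
The demo above views external memory through DLPack. A tensor can also live entirely in registers: the sketch below pairs a custom layout with a register-backed engine from cute.make_fragment (the fill call is assumed to behave as in the CUTLASS example kernels).

import cutlass
import cutlass.cute as cute

@cute.jit
def fragment_demo():
    # Row-major (2,3) layout: offsets 0..5 laid out row by row.
    layout = cute.make_layout((2, 3), stride=(3, 1))
    # Register-backed tensor: fragment engine composed with the layout.
    frag = cute.make_fragment(layout, cutlass.Float32)
    frag.fill(1.0)   # assumed scalar-fill helper, as used in CUTLASS examples
    cute.print_tensor(frag)

fragment_demo()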

Layout:Stride notation (shape:stride)

  • CuTe prints layouts using (shape):(stride) to show how logical indices map to linear offsets.
  • Dynamic values print with their concrete runtime values via cute.printf; in static prints (Python print) they show as ?.
  • Example row-major for (M, N) is (M, N):(N, 1); column-major is (M, N):(1, M).

Shape :  (4,2)
Stride:  (1,4)
  0   4
  1   5
  2   6
  3   7

is a 4x2 column-major layout with stride-1 down the columns and stride-4 across the rows, and

Shape :  (4,2)
Stride:  (2,1)
  0   1
  2   3
  4   5
  6   7

is a 4x2 row-major layout with stride-2 down the columns and stride-1 across the rows. Majorness is simply which mode has stride-1.
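
The offset grids above follow directly from the mapping offset(i, j) = i*stride0 + j*stride1. A quick plain-Python check (outside the DSL) for the column-major (4,2):(1,4) case:

shape, stride = (4, 2), (1, 4)   # column-major: stride-1 down columns, stride-4 across rows

def offset(coord):
    return sum(c * s for c, s in zip(coord, stride))

for i in range(shape[0]):
    print([offset((i, j)) for j in range(shape[1])])
# [0, 4]
# [1, 5]
# [2, 6]
# [3, 7]

The same layouts can be built and printed with the DSL: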

import cutlass
import cutlass.cute as cute

@cute.jit
def layout_stride_demo(M: cutlass.Int32, N: cutlass.Int32):
    row_major = cute.make_layout((M, N), stride=(N, cutlass.Int32(1)))
    col_major = cute.make_layout((M, N), stride=(cutlass.Int32(1), M))
    print("static row-major:", row_major)
    print("static col-major:", col_major)
    cute.printf("dynamic row-major: {}", row_major)
    cute.printf("dynamic col-major: {}", col_major)

layout_stride_demo(cutlass.Int32(4), cutlass.Int32(3))

Additional slicing follows the same ideas: views are created first, and .load() materializes data.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def slicing_examples(t: cute.Tensor):
    # Scalar access
    cute.printf("t[1,2] = {}", t[1, 2])

    # Entire second row (shape: (N,)) using (row_index, None)
    row = t[(1, None)]
    row_frag = cute.make_fragment(row.layout, row.element_type)
    row_frag.store(row.load())
    print("Second row:")
    cute.print_tensor(row_frag)

    # Entire third column (shape: (M,)) using (None, col_index)
    col = t[(None, 2)]
    col_frag = cute.make_fragment(col.layout, col.element_type)
    col_frag.store(col.load())
    print("Third column:")
    cute.print_tensor(col_frag)

    # Linear (1-D) indexing: t[2] is the element at natural coordinate (2, 0)
    cute.printf(
        "t[2] = {} (equivalent to t[{}])",
        t[2],
        cute.make_identity_tensor(t.layout.shape)[2]
    )

# 4x3 example tensor
arr = torch.arange(12, dtype=torch.float32).reshape(4, 3)
slicing_examples(from_dlpack(arr))

4) Register tensors (TensorSSA) and reductions

  • TensorSSA is a register-level value enabling vectorized elementwise ops.
  • Load/store between memory tensors and SSA: vec = tensor.load(), tensor.store(vec).
  • Reductions via .reduce(op, init, reduction_profile=...).

import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def ssa_add(dst: cute.Tensor, x: cute.Tensor, y: cute.Tensor):
    xv = x.load()
    yv = y.load()
    dst.store(xv + yv)
    cute.print_tensor(dst)

X = np.ones((2, 3), dtype=np.float32)
Y = np.full((2, 3), 2.0, dtype=np.float32)
Z = np.zeros((2, 3), dtype=np.float32)
ssa_add(from_dlpack(Z), from_dlpack(X), from_dlpack(Y))

Reductions operate on the same loaded TensorSSA values, with reduction_profile selecting which modes collapse:

import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def ssa_reduce(a: cute.Tensor):
    v = a.load()
    # Sum of all elements
    total = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=0)
    cute.printf("total sum = {}", total)

    # Row-wise sum -> shape (rows,)
    row_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1))
    row_frag = cute.make_fragment(row_sum.shape, cutlass.Float32)
    row_frag.store(row_sum)
    print("Row-wise sum:")
    cute.print_tensor(row_frag)

    # Column-wise sum -> shape (cols,)
    col_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(1, None))
    col_frag = cute.make_fragment(col_sum.shape, cutlass.Float32)
    col_frag.store(col_sum)
    print("Column-wise sum:")
    cute.print_tensor(col_frag)

A = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
ssa_reduce(from_dlpack(A))

Reduce arguments

  • op: A cute.ReductionOp enum specifying the operation (ADD, MUL, MAX, MIN, …)
  • init: Initial accumulator value (also sets accumulator dtype)
  • reduction_profile: Axes to reduce; 0 reduces all axes, while a tuple uses 1 to reduce a mode and None to keep it

# Example: column-wise reduction of an (M, N) TensorSSA; (1, None) collapses mode-0
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def reduce_cols(a: cute.Tensor):
    v = a.load()
    col_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(1, None))
    frag = cute.make_fragment(col_sum.shape, cutlass.Float32)
    frag.store(col_sum)
    cute.print_tensor(frag)

A = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
reduce_cols(from_dlpack(A))

With SSA basics covered, let’s scale up: partition tensors for vectorized per-thread work.

5) Vectorized add with zipped_divide

Naively mapping one thread per element is simple but not fast. We partition tensors into per-thread tiles with cute.zipped_divide, slice a specific tile (None, (mi, ni)), load vectors, compute, and store.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def vadd_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    idx = bidx * bdim + tidx
    m, n = gA.shape[1]          # tile-index domain: one (1,4) tile per thread
    mi = idx // n
    ni = idx % n
    gC[(None, (mi, ni))] = gA[(None, (mi, ni))].load() + gB[(None, (mi, ni))].load()

@cute.jit
def vadd(A: cute.Tensor, B: cute.Tensor, C: cute.Tensor):
    gA = cute.zipped_divide(A, (1, 4))
    gB = cute.zipped_divide(B, (1, 4))
    gC = cute.zipped_divide(C, (1, 4))
    threads = 256
    vadd_kernel(gA, gB, gC).launch(
        grid=(cute.size(gC, mode=[1]) // threads, 1, 1),
        block=(threads, 1, 1),
    )

M, N = 1024, 1024
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)
vadd_compiled = cute.compile(vadd, from_dlpack(a), from_dlpack(b), from_dlpack(c))
vadd_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))

Tile semantics (zipped_divide)

  • cute.zipped_divide(tensor, (1, 4)) produces a 2-mode tiled tensor: mode-0 is the per-thread tile (1,4) and mode-1 indexes tiles across the original tensor.
  • Printed as (shape):(stride) by mode; for a 2048×2048 row-major tensor tiled by (1, 4), the tiled tensor may show: ((1,4),(2048,512)):((0,1),(2048,4))
  • Slicing with (None, (mi, ni)) selects a single per-thread tile, yielding a (1,4) vector on the right-hand side.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def zdiv_demo(mA: cute.Tensor):
    # Partition into per-thread tiles of (1,4)
    gA = cute.zipped_divide(mA, (1, 4))
    print("Tiled tensor gA:", gA)

    # Inspect a specific tile (mi, ni)
    mi = cutlass.Int32(0)
    ni = cutlass.Int32(0)
    tile = gA[(None, (mi, ni))]
    print("Per-thread tile slice:", tile)

    # Materialize tile for printing
    frag = cute.make_fragment(tile.layout, tile.element_type)
    frag.store(tile.load())
    cute.print_tensor(frag)

A = torch.arange(0, 8*8, dtype=torch.float32).reshape(8, 8)
zdiv_demo(from_dlpack(A))
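
To make the tile bookkeeping concrete, here is a small plain-Python sketch of the counts behind the vadd launch above, using its 1024×1024 problem size and assuming, as there, exact divisibility:

# Tile/grid arithmetic for the vadd example (assumes exact divisibility).
M, N = 1024, 1024
tile_m, tile_n = 1, 4
threads_per_block = 256

num_tiles = (M // tile_m) * (N // tile_n)   # size of mode-1 of the zipped tensor
grid_x = num_tiles // threads_per_block     # one (1,4) tile per thread
print(num_tiles, grid_x)                    # 262144 tiles -> 1024 thread blocks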

6) TV layout and per-thread fragments

cute.make_layout_tv(thread_layout, value_layout) maps (tid, vid) into a (TileM, TileN) tile. Composing a block-local tensor with this layout yields per-thread fragments.

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def tv_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor, tv_layout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Select the thread-block tile
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]
    blkB = gB[blk_coord]
    blkC = gC[blk_coord]

    # Compose TV layout to map (tid, vid) -> physical address
    tidfrgA = cute.composition(blkA, tv_layout)
    tidfrgB = cute.composition(blkB, tv_layout)
    tidfrgC = cute.composition(blkC, tv_layout)

    # Slice per-thread vector
    thr_coord = (tidx, None)
    thrA = tidfrgA[thr_coord]
    thrB = tidfrgB[thr_coord]
    thrC = tidfrgC[thr_coord]

    thrC[None] = thrA.load() + thrB.load()

@cute.jit
def tv_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    # Thread (4,32): 4 groups along M (row), 32 contiguous threads along N (col)
    # Value (4,8): each thread handles 4 rows x 8 contiguous values
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 8), stride=(8, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

    # Tile tensors into thread-block tiles
    gA = cute.zipped_divide(mA, tiler_mn)
    gB = cute.zipped_divide(mB, tiler_mn)
    gC = cute.zipped_divide(mC, tiler_mn)

    # One block per tile in mode-1; threads per block = TV threads
    tv_add_kernel(gA, gB, gC, tv_layout).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )

M, N = 2048, 2048
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)
tv_add_compiled = cute.compile(tv_add, from_dlpack(a), from_dlpack(b), from_dlpack(c))
tv_add_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))

Explanation

  • make_layout_tv(thr_layout, val_layout) maps (tid, vid) to a (TileM, TileN) tile; composing a block-local tensor yields per-thread fragments.
  • Threads load/store vector fragments defined by val_layout; mapping along N ensures coalesced access; row-grouping across M is handled by the thread layout.
  • Grid uses one block per tile (mode-1 of zipped tensors); block size equals number of TV threads size(tv_layout, mode=[0]).

TV layout mapping (diagram)

Tiled to Thread Block:

    ((16,256),(128,8))  : ((2048,1),(32768,256))
     ~~~~~~~~  ~~~~~~      ~~~~~~~~
        |        |            |
        |        |            |
        |        `------------------------> Number of Thread Blocks
        |                     |
        |                     |
        `--------------------'
                  |
                  V
             Thread Block
                 Tile

Sliced to Thread-Block local sub-tensor (16,256):  gA[((None, None), bidx)]
Sliced to Thread local sub-tensor (4,8):           tidfrgA[(tidx, None)]

Notes

  • Thread-block tile (TileM,TileN) = (thr_M×val_M, thr_N×val_N)
  • Threads per block = thr_M×thr_N; values per thread = val_M×val_N; products match TileM×TileN
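
The numbers above can be checked on the host. A minimal sketch that rebuilds the same thread/value layouts, prints the tiler and TV layout, and confirms that threads × values covers the (16, 256) thread-block tile (static print, as in Section 2):

import cutlass
import cutlass.cute as cute

@cute.jit
def tv_inspect():
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))   # 128 threads
    val_layout = cute.make_layout((4, 8), stride=(8, 1))     # 32 values per thread
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    print("tiler_mn:", tiler_mn)    # expected: the (16, 256) thread-block tile
    print("tv_layout:", tv_layout)
    print("threads:", cute.size(tv_layout, mode=[0]),
          "values per thread:", cute.size(tv_layout, mode=[1]))  # 128 * 32 == 16 * 256

tv_inspect()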

7) Layout algebra: coalesce, composition, divide, product

Layouts map coordinates to indices. Compose layouts to reshape/reorder; divide to tile; coalesce to flatten compatible modes; and replicate via products.

import cutlass
import cutlass.cute as cute

@cute.jit
def layout_demo():
    # Composition and coalesce
    A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2))
    B = cute.make_layout((4, 3), stride=(3, 1))
    R = cute.composition(A, B)
    C = cute.coalesce(R)

    # Logical divide with tiler
    L = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))
    T = (cute.make_layout(3, stride=3),
         cute.make_layout((2, 4), stride=(1, 8)))
    D = cute.logical_divide(L, tiler=T)

    # Logical product/repetition
    P = cute.logical_product(
        cute.make_layout((2, 2), stride=(4, 1)),
        cute.make_layout(6, stride=1),
    )

    cute.printf("A={}, B={}, R={}, C={}", A, B, R, C)
    cute.printf("Divide: {}", D)
    cute.printf("Product: {}", P)

layout_demo()
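
For intuition, composition can be reproduced in plain Python outside the DSL: R = A ∘ B means R(c) = A(B(c)), with an integer coordinate unflattened colexicographically (leftmost mode fastest), which is CuTe's convention. A sketch using the same A and B as above:

def layout_fn(shape, stride):
    """1-D evaluation of a (shape):(stride) layout: unflatten colexicographically, then dot with strides."""
    def f(i):
        idx = 0
        for s, d in zip(shape, stride):
            idx += (i % s) * d
            i //= s
        return idx
    return f

A = layout_fn((6, 2), (8, 2))
B = layout_fn((4, 3), (3, 1))
print([A(B(i)) for i in range(12)])   # offsets of R = A o B, one per coordinate of B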

8) PyTorch CUDA Graphs integration

CUDA Graphs remove CPU launch overhead. JIT-compile functions that accept a CUstream before capture, then replay without re-JITing.

import torch
import cutlass
import cutlass.cute as cute
from cuda.bindings.driver import CUstream
from torch.cuda import current_stream

@cute.kernel
def hello_kernel():
    cute.printf("Hello")

@cute.jit
def hello_launch(stream: CUstream):
    hello_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)

s = current_stream()
compiled = cute.compile(hello_launch, CUstream(s.cuda_stream))

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    compiled(CUstream(current_stream().cuda_stream))

g.replay()
torch.cuda.synchronize()

9) Patterns and checklist

  • Compile before CUDA graph capture; pass CUstream explicitly when needed
  • Use cute.printf for dynamic values; Python print is static
  • Prefer explicit numeric types (e.g., cutlass.Int32) in JIT code; convert with .to(...)
  • Vectorize with zipped_divide + slice (None, (mi, ni)) + .load()/.store()
  • Use make_layout_tv to decouple thread/value tiling; compose with block-local tensors
  • Size grids/blocks with cute.size(layout_or_tensor, mode=[...])
  • Dynamic layouts may not coalesce; beware differences vs static

Minimal checklist

  • Initialize CUDA context when needed
  • Define kernels with @cute.kernel; launch from @cute.jit functions
  • Convert external arrays/tensors via from_dlpack
  • Inspect/print tensors and layouts (cute.print_tensor, cute.printf)
  • Build layouts/tilers (make_layout, make_layout_tv), and use divide/composition/product as required
  • Use cute.compile to AOT-compile and reuse
  • Benchmark with torch.cuda.Event when measuring kernels or graphs
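
As a closing example, a minimal timing sketch with torch.cuda.Event; it assumes the graph g (and, alternatively, the compiled callable) from Section 8 is already in scope:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
start.record()
for _ in range(100):
    g.replay()                     # or invoke a compiled CuTe function directly
end.record()
torch.cuda.synchronize()           # wait for all replays to finish

print(f"avg time per replay: {start.elapsed_time(end) / 100:.3f} ms")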

10) Key APIs quick reference

  • @cute.kernel, @cute.jit, cute.compile, kernel().launch(...)
  • cute.printf, cute.print_tensor, cutlass.Constexpr
  • Numeric types and conversion: cutlass.Int32/Float32/..., .to(...)
  • Tensor and layout construction: cute.make_layout, cute.make_tensor, from_dlpack, cute.make_identity_tensor
  • Vectorization and tiling: cute.zipped_divide, cute.size, cute.make_layout_tv, cute.composition
  • Register ops: TensorSSA, .load(), .store(...), .reduce(...), cute.ReductionOp
  • CUDA Graphs interop: cuda.bindings.driver.CUstream, torch.cuda.current_stream, torch.cuda.CUDAGraph

Glossary

  • Layout: Mapping from coordinates to offsets (shape + stride), supports static/dynamic params
  • Engine: Pointer-like object supporting offset/deref; composes with layout to form a tensor
  • TensorSSA: Register-level tensor value enabling vectorized elementwise ops and reductions
  • TV layout: Thread/value layout mapping (tid, vid) into logical tensor tile coordinates
  • Tiler: Layout/tuple guiding by-mode composition/division/product
  • Constexpr: Compile-time constant distinct from runtime numeric types
  • DLPack interop: Zero-copy exchange with frameworks (from_dlpack)