CuTe DSL Basics: A Practical Introduction
CuTe DSL Basics — From Hello to Tiled Kernels
This tutorial turns the CuTe DSL script snippets into a connected story: we start with a first GPU kernel, learn how dynamic printing and data types work, build and slice tensors, then graduate to vectorized kernels with tiling, thread–value layouts, and finally CUDA Graphs. Code blocks are ready to run as-is.
1) First kernel and launch (hello world)
- Define a GPU kernel with `@cute.kernel`; get the thread index via `cute.arch.thread_idx()`.
- Launch kernels from a `@cute.jit` host function using `.launch(grid=..., block=...)`.
- Initialize CUDA explicitly with `cutlass.cuda.initialize_cuda_context()` (for clarity/control).
- Compile ahead of time and reuse with `cute.compile`.
```python
import cutlass
import cutlass.cute as cute

@cute.kernel
def hello_kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Hello from GPU")

@cute.jit
def hello_world():
    cutlass.cuda.initialize_cuda_context()
    hello_kernel().launch(grid=(1, 1, 1), block=(32, 1, 1))

compiled = cute.compile(hello_world)
compiled()
```
Next we'll see how values show up in logs, and why `cute.printf` is the right tool inside JIT code.
2) Printing and numeric types: static vs dynamic
- Python `print` is static (compile-time) and shows `?` for dynamic values; `cute.printf` prints dynamic (runtime) values for CuTe types/layouts.
- Device-side printing is supported inside `@cute.kernel`/`@cute.jit`.
```python
import cutlass
import cutlass.cute as cute

@cute.jit
def print_demo(a: cutlass.Int32, b: cutlass.Constexpr[int]):
    print("static a:", a)                      # => ? (dynamic)
    print("static b:", b)                      # => 2
    cute.printf("dynamic a: {}", a)
    cute.printf("dynamic b: {}", b)
    layout = cute.make_layout((a, b))
    print("static layout:", layout)            # (?,2):(1,?)
    cute.printf("dynamic layout: {}", layout)  # (8,2):(1,8)

print_demo(cutlass.Int32(8), 2)
```
Numeric types in JIT code should be explicit and can be converted at runtime.
```python
import cutlass
import cutlass.cute as cute

@cute.jit
def dtypes():
    a = cutlass.Int32(42)
    b = a.to(cutlass.Float32)
    c = b + 0.5
    d = c.to(cutlass.Int32)
    cute.printf("a={}, b={}, c={}, d={}", a, b, c, d)

dtypes()
```
With printing and types in hand, let’s talk about tensors and their layouts.
3) Tensors and layouts: build, index, slice
- Tensor = Engine ∘ Layout: build with `cute.make_tensor(ptr_or_engine, layout)`.
- Create layouts with `cute.make_layout(shape, stride=...)`.
- Interop with frameworks via DLPack; print via `cute.print_tensor`.
```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def tensor_demo(t: cute.Tensor):
    cute.printf("t[0,0] = {}", t[0, 0])
    sub = t[(None, 0)]  # view of column 0: mode-0 kept, mode-1 fixed to 0
    frag = cute.make_fragment(sub.layout, sub.element_type)
    frag.store(sub.load())
    cute.print_tensor(frag)

arr = torch.arange(0, 12, dtype=torch.float32).reshape(3, 4)
tensor_demo(from_dlpack(arr))
```
Layout:Stride notation (shape:stride)
- CuTe prints layouts as `(shape):(stride)` to show how logical coordinates map to linear offsets.
- Dynamic values appear as `cutlass.Int32(...)` at runtime; statically unknown values may show as `?` in static prints.
- Example: row-major for `(M, N)` is `(M, N):(N, 1)`; column-major is `(M, N):(1, M)`.
For example,
```
Shape : (4,2)
Stride: (1,4)
  0  4
  1  5
  2  6
  3  7
```
is a 4x2 column-major layout with stride-1 down the columns and stride-4 across the rows, and
```
Shape : (4,2)
Stride: (2,1)
  0  1
  2  3
  4  5
  6  7
```
is a 4x2 row-major layout with stride-2 down the columns and stride-1 across the rows. Majorness is simply which mode has stride-1.
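To see where those printed numbers come from, here is a quick sanity check in plain Python (no CuTe API involved) that computes the offset `i*stride[0] + j*stride[1]` for every coordinate of the two 4x2 layouts above:
```python
# Offset of coordinate (i, j) under a (shape, stride) layout: i*s0 + j*s1
def offsets(shape, stride):
    M, N = shape
    return [[i * stride[0] + j * stride[1] for j in range(N)] for i in range(M)]

print(offsets((4, 2), (1, 4)))  # column-major: [[0, 4], [1, 5], [2, 6], [3, 7]]
print(offsets((4, 2), (2, 1)))  # row-major:    [[0, 1], [2, 3], [4, 5], [6, 7]]
```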
```python
import cutlass
import cutlass.cute as cute

@cute.jit
def layout_stride_demo(M: cutlass.Int32, N: cutlass.Int32):
    row_major = cute.make_layout((M, N), stride=(N, cutlass.Int32(1)))
    col_major = cute.make_layout((M, N), stride=(cutlass.Int32(1), M))
    print("static row-major:", row_major)
    print("static col-major:", col_major)
    cute.printf("dynamic row-major: {}", row_major)
    cute.printf("dynamic col-major: {}", col_major)

layout_stride_demo(cutlass.Int32(4), cutlass.Int32(3))
```
Additional slicing follows the same ideas: views are created first, and `.load()` materializes the data.
```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def slicing_examples(t: cute.Tensor):
    # Scalar access
    cute.printf("t[1,2] = {}", t[1, 2])

    # Entire second column (shape (M,)) using (None, col_index)
    col = t[(None, 1)]
    col_frag = cute.make_fragment(col.layout, col.element_type)
    col_frag.store(col.load())
    print("Second column:")
    cute.print_tensor(col_frag)

    # Entire third row (shape (N,)) using (row_index, None)
    row = t[(2, None)]
    row_frag = cute.make_fragment(row.layout, row.element_type)
    row_frag.store(row.load())
    print("Third row:")
    cute.print_tensor(row_frag)

    # 1-D (linear) indexing: t[2] is the same element as t[2, 0]
    cute.printf(
        "t[2] = {} (equivalent to t[{}])",
        t[2],
        cute.make_identity_tensor(t.layout.shape)[2],
    )

# 4x3 example tensor
arr = torch.arange(12, dtype=torch.float32).reshape(4, 3)
slicing_examples(from_dlpack(arr))
```
4) Register tensors (TensorSSA) and reductions
- `TensorSSA` is a register-level value enabling vectorized elementwise ops.
- Load/store between memory tensors and SSA values: `vec = tensor.load()`, `tensor.store(vec)`.
- Reductions via `.reduce(op, init, reduction_profile=...)`.
```python
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def ssa_add(dst: cute.Tensor, x: cute.Tensor, y: cute.Tensor):
    xv = x.load()
    yv = y.load()
    dst.store(xv + yv)
    cute.print_tensor(dst)

X = np.ones((2, 3), dtype=np.float32)
Y = np.full((2, 3), 2.0, dtype=np.float32)
Z = np.zeros((2, 3), dtype=np.float32)
ssa_add(from_dlpack(Z), from_dlpack(X), from_dlpack(Y))
```
```python
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def ssa_reduce(a: cute.Tensor):
    v = a.load()

    # Sum of all elements
    total = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=0)
    cute.printf("total sum = {}", total)

    # Row-wise sum -> shape (rows,)
    row_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1))
    row_frag = cute.make_fragment(row_sum.shape, cutlass.Float32)
    row_frag.store(row_sum)
    print("Row-wise sum:")
    cute.print_tensor(row_frag)

    # Column-wise sum -> shape (cols,)
    col_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(1, None))
    col_frag = cute.make_fragment(col_sum.shape, cutlass.Float32)
    col_frag.store(col_sum)
    print("Column-wise sum:")
    cute.print_tensor(col_frag)

A = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
ssa_reduce(from_dlpack(A))
```
Reduce arguments
- op: a `cute.ReductionOp` enum specifying the operation (`ADD`, `MUL`, `MAX`, `MIN`, ...)
- init: initial accumulator value (also sets the accumulator dtype)
- reduction_profile: axes to reduce; `0` reduces all axes, or pass a tuple with `1` marking an axis to reduce and `None` marking an axis to keep
```python
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

# Example: column-wise sums of an (M,N) TensorSSA (the `1` reduces the row axis)
@cute.jit
def reduce_cols(a: cute.Tensor):
    v = a.load()
    col_sum = v.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(1, None))
    frag = cute.make_fragment(col_sum.shape, cutlass.Float32)
    frag.store(col_sum)
    cute.print_tensor(frag)

reduce_cols(from_dlpack(np.arange(6, dtype=np.float32).reshape(2, 3)))
```
With SSA basics covered, let’s scale up: partition tensors for vectorized per-thread work.
5) Vectorized add with zipped_divide
Naively mapping one thread per element is simple but not fast. Instead, we partition tensors into per-thread tiles with `cute.zipped_divide`, slice a specific tile with `(None, (mi, ni))`, load vectors, compute, and store.
```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def vadd_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    idx = bidx * bdim + tidx
    m, n = gA.shape[1]  # thread domain: number of tiles along each mode
    mi = idx // n
    ni = idx % n
    gC[(None, (mi, ni))] = gA[(None, (mi, ni))].load() + gB[(None, (mi, ni))].load()

@cute.jit
def vadd(A: cute.Tensor, B: cute.Tensor, C: cute.Tensor):
    gA = cute.zipped_divide(A, (1, 4))
    gB = cute.zipped_divide(B, (1, 4))
    gC = cute.zipped_divide(C, (1, 4))
    threads = 256
    vadd_kernel(gA, gB, gC).launch(
        grid=(cute.size(gC, mode=[1]) // threads, 1, 1),
        block=(threads, 1, 1),
    )

M, N = 1024, 1024
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)
vadd_compiled = cute.compile(vadd, from_dlpack(a), from_dlpack(b), from_dlpack(c))
vadd_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))
```
Tile semantics (zipped_divide)
- `cute.zipped_divide(tensor, (1, 4))` produces a 2-mode tiled tensor: mode-0 is the per-thread tile `(1,4)` and mode-1 indexes the tiles across the original tensor.
- It is printed as `(shape):(stride)` by mode, e.g. the tiled tensor may show `((1,4),(2048,512)):((0,1),(2048,4))`.
- Slicing with `(None, (mi, ni))` selects a single per-thread tile, yielding a `(1,4)` vector on the right-hand side.
```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def zdiv_demo(mA: cute.Tensor):
    # Partition into per-thread tiles of (1,4)
    gA = cute.zipped_divide(mA, (1, 4))
    print("Tiled tensor gA:", gA)
    # Inspect a specific tile (mi, ni)
    mi = cutlass.Int32(0)
    ni = cutlass.Int32(0)
    tile = gA[(None, (mi, ni))]
    print("Per-thread tile slice:", tile)
    # Materialize the tile for printing
    frag = cute.make_fragment(tile.layout, tile.element_type)
    frag.store(tile.load())
    cute.print_tensor(frag)

A = torch.arange(0, 8 * 8, dtype=torch.float32).reshape(8, 8)
zdiv_demo(from_dlpack(A))
```
6) TV layout and per-thread fragments
`cute.make_layout_tv(thread_layout, value_layout)` maps `(tid, vid)` into a `(TileM, TileN)` tile. Composing a block-local tensor with this layout yields per-thread fragments.
```python
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def tv_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor, tv_layout: cute.Layout):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    # Select the thread-block tile
    blk_coord = ((None, None), bidx)
    blkA = gA[blk_coord]
    blkB = gB[blk_coord]
    blkC = gC[blk_coord]
    # Compose with the TV layout to map (tid, vid) -> physical address
    tidfrgA = cute.composition(blkA, tv_layout)
    tidfrgB = cute.composition(blkB, tv_layout)
    tidfrgC = cute.composition(blkC, tv_layout)
    # Slice out this thread's vector
    thr_coord = (tidx, None)
    thrA = tidfrgA[thr_coord]
    thrB = tidfrgB[thr_coord]
    thrC = tidfrgC[thr_coord]
    thrC[None] = thrA.load() + thrB.load()

@cute.jit
def tv_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    # Thread layout (4,32): 4 groups along M (rows), 32 contiguous threads along N (cols)
    # Value layout (4,8): each thread handles 4 rows x 8 contiguous values
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 8), stride=(8, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    # Tile tensors into thread-block tiles
    gA = cute.zipped_divide(mA, tiler_mn)
    gB = cute.zipped_divide(mB, tiler_mn)
    gC = cute.zipped_divide(mC, tiler_mn)
    # One block per tile in mode-1; threads per block = number of TV threads
    tv_add_kernel(gA, gB, gC, tv_layout).launch(
        grid=[cute.size(gC, mode=[1]), 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )

M, N = 2048, 2048
a = torch.randn(M, N, device="cuda", dtype=torch.float16)
b = torch.randn(M, N, device="cuda", dtype=torch.float16)
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)
tv_add_compiled = cute.compile(tv_add, from_dlpack(a), from_dlpack(b), from_dlpack(c))
tv_add_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))
```
Explanation
- `make_layout_tv(thr_layout, val_layout)` maps `(tid, vid)` to a `(TileM, TileN)` tile; composing a block-local tensor with it yields per-thread fragments (see the inspection sketch after this list).
- Threads load/store vector fragments defined by `val_layout`; mapping along N ensures coalesced access, while row grouping across M is handled by the thread layout.
- The grid uses one block per tile (mode-1 of the zipped tensors); the block size equals the number of TV threads, `size(tv_layout, mode=[0])`.
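To make the mapping concrete, here is a minimal host-side sketch (assuming the same `cute.make_layout_tv` API used in `tv_add` above) that simply prints the tiler and TV layout produced from the thread and value layouts:
```python
import cutlass
import cutlass.cute as cute

@cute.jit
def tv_layout_inspect():
    # Same thread/value layouts as tv_add: 128 threads, 32 values per thread
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 8), stride=(8, 1))
    tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
    # Static prints; the diagram and notes below expect a (16, 256) block tile
    print("tiler_mn :", tiler_mn)
    print("tv_layout:", tv_layout)

tv_layout_inspect()
```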
TV layout mapping (diagram)
```
Tiled to Thread Block:
((16,256),(128,8)) : ((2048,1),(32768,256))
 ~~~~~~~~  ~~~~~~      ~~~~~~~~
    |         |            |
    |         |            |
    |         `------------------------> Number of Thread Blocks
    |                      |
    |                      |
    `----------------------'
               |
               V
         Thread Block
             Tile
```
Sliced to the thread-block-local sub-tensor (16,256): `gA[((None, None), bidx)]`
Sliced to the thread-local sub-tensor (4,8): `tidfrgA[(tidx, None)]`
Notes
- Thread-block tile: `(TileM, TileN) = (thr_M×val_M, thr_N×val_N)`
- Threads per block = `thr_M×thr_N`; values per thread = `val_M×val_N`; the products match `TileM×TileN` (worked through in the snippet below)
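For the layouts used above this is simple arithmetic, checked here in ordinary Python (no CuTe API involved):
```python
# Bookkeeping for thr_layout (4, 32) and val_layout (4, 8)
thr_M, thr_N = 4, 32
val_M, val_N = 4, 8

TileM, TileN = thr_M * val_M, thr_N * val_N   # (16, 256) thread-block tile
threads_per_block = thr_M * thr_N             # 128
values_per_thread = val_M * val_N             # 32

# 128 threads x 32 values = 4096 = 16 x 256 elements per tile
assert threads_per_block * values_per_thread == TileM * TileN
print(TileM, TileN, threads_per_block, values_per_thread)
```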
7) Layout algebra: coalesce, composition, divide, product
Layouts map coordinates to indices. Compose layouts to reshape/reorder; divide to tile; coalesce to flatten compatible modes; and replicate via products.
```python
import cutlass
import cutlass.cute as cute

@cute.jit
def layout_demo():
    # Composition and coalesce
    A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2))
    B = cute.make_layout((4, 3), stride=(3, 1))
    R = cute.composition(A, B)
    C = cute.coalesce(R)
    # Logical divide with a tiler
    L = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))
    T = (cute.make_layout(3, stride=3),
         cute.make_layout((2, 4), stride=(1, 8)))
    D = cute.logical_divide(L, tiler=T)
    # Logical product / repetition
    P = cute.logical_product(
        cute.make_layout((2, 2), stride=(4, 1)),
        cute.make_layout(6, stride=1),
    )
    cute.printf("A={}, B={}, R={}, C={}", A, B, R, C)
    cute.printf("Divide: {}", D)
    cute.printf("Product: {}", P)

layout_demo()
```
8) PyTorch CUDA Graphs integration
CUDA Graphs remove CPU launch overhead. JIT-compile functions that accept a `CUstream` before capture, then replay without re-JITing.
```python
import torch
import cutlass
import cutlass.cute as cute
from cuda.bindings.driver import CUstream
from torch.cuda import current_stream

@cute.kernel
def hello_kernel():
    cute.printf("Hello")

@cute.jit
def hello_launch(stream: CUstream):
    hello_kernel().launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream)

# Compile ahead of capture, passing the current stream explicitly
s = current_stream()
compiled = cute.compile(hello_launch, CUstream(s.cuda_stream))

# Capture the compiled launch into a CUDA graph, then replay it
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    compiled(CUstream(current_stream().cuda_stream))

g.replay()
torch.cuda.synchronize()
```
9) Patterns and checklist
- Compile before CUDA graph capture; pass `CUstream` explicitly when needed
- Use `cute.printf` for dynamic values; Python `print` is static
- Prefer explicit numeric types (e.g., `cutlass.Int32`) in JIT code; convert with `.to(...)`
- Vectorize with `zipped_divide` + slice `(None, (mi, ni))` + `.load()`/`.store()`
- Use `make_layout_tv` to decouple thread/value tiling; compose it with block-local tensors
- Size grids/blocks with `cute.size(layout_or_tensor, mode=[...])`
- Dynamic layouts may not coalesce; beware of differences vs. static layouts
Minimal checklist
- Initialize the CUDA context when needed
- Define kernels with `@cute.kernel`; launch from `@cute.jit` functions
- Convert external arrays/tensors via `from_dlpack`
- Inspect/print tensors and layouts (`cute.print_tensor`, `cute.printf`)
- Build layouts/tilers (`make_layout`, `make_layout_tv`), and use divide/composition/product as required
- Use `cute.compile` to AOT-compile and reuse
- Benchmark with `torch.cuda.Event` when measuring kernels or graphs (see the sketch below)
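As a sketch of that last item, the snippet below times the compiled vectorized add from section 5 with `torch.cuda.Event`; it assumes `vadd_compiled`, `a`, `b`, and `c` from that section are already in scope.
```python
import torch
from cutlass.cute.runtime import from_dlpack

start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)

# Warm up once so one-time costs do not skew the measurement
vadd_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))
torch.cuda.synchronize()

start.record()
for _ in range(100):
    vadd_compiled(from_dlpack(a), from_dlpack(b), from_dlpack(c))
stop.record()
torch.cuda.synchronize()

print(f"avg time per launch: {start.elapsed_time(stop) / 100:.3f} ms")
```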
10) Key APIs quick reference
- `@cute.kernel`, `@cute.jit`, `cute.compile`, `kernel().launch(...)`
- `cute.printf`, `cute.print_tensor`, `cutlass.Constexpr`
- Numeric types and conversion: `cutlass.Int32/Float32/...`, `.to(...)`
- Tensor and layout construction: `cute.make_layout`, `cute.make_tensor`, `from_dlpack`, `cute.make_identity_tensor`
- Vectorization and tiling: `cute.zipped_divide`, `cute.size`, `cute.make_layout_tv`, `cute.composition`
- Register ops: `TensorSSA`, `.load()`, `.store(...)`, `.reduce(...)`, `cute.ReductionOp`
- CUDA Graphs interop: `cuda.bindings.driver.CUstream`, `torch.cuda.current_stream`, `torch.cuda.CUDAGraph`
Glossary
- Layout: mapping from coordinates to offsets (shape + stride); supports static and dynamic parameters
- Engine: pointer-like object supporting offset/dereference; composes with a layout to form a tensor
- TensorSSA: register-level tensor value enabling vectorized elementwise ops and reductions
- TV layout: thread/value layout mapping `(tid, vid)` into logical tensor tile coordinates
- Tiler: layout or tuple guiding by-mode composition/division/product
- Constexpr: compile-time constant, distinct from runtime numeric types
- DLPack interop: zero-copy exchange with frameworks (`from_dlpack`)