先回顾Naive 矩阵乘法长什么样假设我们要算 C A × B矩阵都是 N×N。C 的每个元素C[row][col]是 A 的第row行和 B 的第col列做点积C[row][col] A[row][0]*B[0][col] A[row][1]*B[1][col] ... A[row][N-1]*B[N-1][col] └─── 沿 K 维度求和 ──────────────────────────────────┘Naive 版本每个线程算 C 的一个元素从 Global Memory 读 A 的一整行和 B 的一整列。问题是 A 和 B 的数据被大量重复读取——计算C[0][0]读了 A 的第 0 行计算C[0][1]又要再读一遍 A 的第 0 行。Tiled 的核心思想分块加载到 Shared Memory把大矩阵切成小块tile每次只处理一小块这样数据可以被 Block 内的线程复用。画个图最直观。假设 N8TILE_SIZE4矩阵 A (8×8) 矩阵 B (8×8) 矩阵 C (8×8) ┌────┬────┐ ┌────┬────┐ ┌────┬────┐ │A00 │A01 │ │B00 │B01 │ │C00 │C01 │ │4×4 │4×4 │ │4×4 │4×4 │ │4×4 │4×4 │ ├────┼────┤ ├────┼────┤ ├────┼────┤ │A10 │A11 │ │B10 │B11 │ │C10 │C11 │ │4×4 │4×4 │ │4×4 │4×4 │ │4×4 │4×4 │ └────┴────┘ └────┴────┘ └────┴────┘要算C00左上角的 4×4 子矩阵需要C00 A00 × B00 A01 × B10 ───────── ───────── 第 1 个 tile 第 2 个 tile关键洞察C00 这块 4×4 16 个元素每个元素都需要读 A 的第 0-3 行和 B 的第 0-3 列。如果不用 Shared Memory16 个线程各读各的A 的同一行被读 4 次。用了 Shared Memory整个 Block 协作把 A00 这块 4×4 一次性搬到 Shared Memory然后 16 个线程都从 Shared Memory 读——Global Memory 访问量减少到 1/4。逐行解读代码__global__ void matMulShared(float *A, float *B, float *C, int N) { // ① 声明 Shared Memory每个 Block 分配两块 TILE_SIZE × TILE_SIZE 的共享空间 __shared__ float sA[TILE_SIZE][TILE_SIZE]; __shared__ float sB[TILE_SIZE][TILE_SIZE];__shared__关键字告诉编译器这块内存在 SM 的 Shared Memory 上Block 内所有线程共享。// ② 计算当前线程负责 C 的哪个元素 int row threadIdx.y blockIdx.y * TILE_SIZE; int col threadIdx.x blockIdx.x * TILE_SIZE;假设 TILE_SIZE4blockIdx (1, 2)threadIdx (3, 1)row 1 2 × 4 9 → 第 9 行col 3 1 × 4 7 → 第 7 列这个线程负责计算 C[9][7]这里threadIdx.x对应列、threadIdx.y对应行——正好满足上一课讲的合并访问原则threadIdx.x相邻 → col 相邻 → 内存连续。float sum 0.0f; // ③ 沿 K 维度公共维度分块遍历 for (int t 0; t N / TILE_SIZE; t) {t是 tile 的编号。N8, TILE_SIZE4 时t 从 0 到 1循环两次——对应前面图里的A00×B00和A01×B10。// ④ 协作加载每个线程从 Global Memory 读一个元素存入 Shared Memory sA[threadIdx.y][threadIdx.x] A[row * N (t * TILE_SIZE threadIdx.x)]; sB[threadIdx.y][threadIdx.x] B[(t * TILE_SIZE threadIdx.y) * N col];这是最关键的一步。Block 内有 TILE_SIZE × TILE_SIZE 个线程比如 4×4 16 个每个线程负责搬一个元素线程 (ty0, tx0) 搬 A[row][t*4 0] → sA[0][0] 线程 (ty0, tx1) 搬 A[row][t*4 1] → sA[0][1] ... 线程 (ty3, tx3) 搬 A[row][t*4 3] → sA[3][3] 16 个线程各搬 1 个 → 一次协作就把 4×4 的 tile 从 Global Memory 搬到了 Shared Memory注意 sA 和 sB 的索引方式sA[ty][tx]ty 是行tx 是列。每个线程搬 A 中当前行、第 t 块的一个元素sB[ty][tx]ty 是行tx 是列。每个线程搬 B 中第 t 块、当前列的一个元素__syncthreads(); // ⑤ 屏障同步等所有线程都搬完了再继续为什么必须有这个如果没有可能线程 0 已经搬完了开始计算但线程 15 还没搬完——线程 0 就会读到 sA 中的垃圾数据。__syncthreads()确保 Block 内所有线程都走到这里之后才一起继续。// ⑥ 在 Shared Memory 内做计算 for (int k 0; k TILE_SIZE; k) { sum sA[threadIdx.y][k] * sB[k][threadIdx.x]; }现在数据在 Shared Memory 里了做点积。对于计算C[row][col]的线程sum sA[ty][0] * sB[0][tx] ← 从 Shared Memory 读~20 cycles sum sA[ty][1] * sB[1][tx] sum sA[ty][2] * sB[2][tx] sum sA[ty][3] * sB[3][tx]对比 naive 版直接从 Global Memory 读~300 cycles/次每次访问快了15 倍以上。__syncthreads(); // ⑦ 确保所有线程算完了再加载下一个 tile覆盖 sA/sB } // ⑧ 写回结果 C[row * N col] sum; }第二个__syncthreads()也不能省下一轮循环会用新的数据覆盖 sA 和 sB如果还有线程没算完就被覆盖了结果就错了。一图总结整个流程K 维度分块遍历 ┌──────────┐ │ t 0 │ └────┬─────┘ ▼ ┌────────────────────────────────────┐ │ 16 个线程协作加载 A_tile, B_tile │ ← Global → Shared │ __syncthreads() │ │ 每个线程在 Shared Memory 内做部分求和 │ ← 快速读取 │ __syncthreads() │ └────────────────┬───────────────────┘ ▼ ┌──────────┐ │ t 1 │ ← 加载下一块 tile覆盖 sA/sB └────┬─────┘ ▼ ... 重复 ... ▼ 所有 tile 算完 ▼ 写 sum 到 Global Memory核心收益Naive 版本中 A 的每一行被读 N 次C 的每一列都要用tiled 版本中只从 Global Memory 读 N/TILE_SIZE 次每次搬一个 tileBlock 内复用 TILE_SIZE 次。当 TILE_SIZE32 时Global Memory 带宽需求降低到1/32。