Skip to content

第三十章 · 计算着色器 (Compute Shader)

30.1 计算着色器概述

计算着色器 (Compute Shader) = GPU 上的通用计算

GPU 计算 vs CPU 计算:
─────────────────────────────

GPU (Compute Shader):
┌──────────────────────┐
│ 1000+ 并行线程         │
│ 适合: 并行计算、大规模数据 │
│ 带宽: 300-1000 GB/s  │
│ 延迟: 高 (ms 级)       │
└───────────────────┘

CPU:
┌──────────────────────┐
│ 4-64 个核心            │
│ 适合: 复杂逻辑、顺序计算 │
│ 带宽: 50-100 GB/s    │
│ 延迟: 低 (ns 级)       │
└───────────────────┘

30.2 计算着色器基础

30.2.1 基本结构

glsl
#version 450

// 工作组大小(必须在编译时确定)
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// 输入数据
layout(set = 0, binding = 0) buffer InputBuffer {
    float data[];
} input;

// 输出数据
layout(set = 0, binding = 1) buffer OutputBuffer {
    float result[];
} output;

void main() {
    // 全局工作索引
    uint idx = gl_GlobalInvocationID.x;
    
    // 工作组索引
    uint gid = gl_GroupID.x;
    
    // 工作组内局部索引
    uint lid = gl_LocalInvocationID.x;
    
    // 工作组成员总数
    uint localSize = gl_WorkGroupSize.x;
    
    // 工作组总数
    uint numGroups = gl_NumWorkGroups.x;
    
    // 计算
    result[idx] = compute(data[idx]);
}

30.2.2 内置变量

变量类型说明
gl_GlobalInvocationIDuvec3全局线程 ID
gl_LocalInvocationIDuvec3工作组内线程 ID
gl_WorkGroupIDuvec3工作组 ID
gl_LocalInvocationIndexuint工作组内线性索引
gl_WorkGroupSizeuvec3工作组大小
gl_NumWorkGroupsuvec3工作组总数
gl_NumWorkGroupsEXTuvec3工作组总数(扩展)
gl_SampleIDuint采样 ID(MSAA)

30.3 常用 Compute Shader 模式

30.3.1 点乘

glsl
#version 450

layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

layout(set = 0, binding = 0) readonly buffer A {
    float a[];
} buf_a;

layout(set = 0, binding = 1) readonly buffer B {
    float b[];
} buf_b;

layout(set = 0, binding = 2) writeonly buffer Result {
    float result[];
} buf_result;

void main() {
    uint idx = gl_GlobalInvocationID.x;
    result[idx] = a[idx] * b[idx];
}

30.3.2 矩阵乘法

glsl
#version 450

layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

layout(set = 0, binding = 0) readonly buffer MatA {
    mat4 matrix[];
} mat_a;

layout(set = 0, binding = 1) readonly buffer MatB {
    mat4 matrix[];
} mat_b;

layout(set = 0, binding = 2) writeonly buffer Result {
    mat4 matrix[];
} result;

void main() {
    uint x = gl_GlobalInvocationID.x;
    uint y = gl_GlobalInvocationID.y;
    
    // 计算矩阵乘法
    result.matrix[x + y * 4] = mat_a.matrix[x] * mat_b.matrix[y];
}

30.3.3 并行归约

glsl
#version 450

layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

shared float shared_data[256];

layout(set = 0, binding = 0) readonly buffer Input {
    float data[];
} input;

layout(set = 0, binding = 1) writeonly buffer Output {
    float result[];
} output;

void main() {
    uint idx = gl_GlobalInvocationID.x;
    shared_data[lid] = input.data[idx];
    
    // 归约
    for (uint stride = 128; stride > 0; stride >>= 1) {
        barrier();
        if (lid < stride) {
            shared_data[lid] += shared_data[lid + stride];
        }
        barrier();
    }
    
    // 每个工作组写一个结果
    if (lid == 0) {
        output.result[gid] = shared_data[0];
    }
}

30.3.4 粒子系统

glsl
#version 450

layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

struct Particle {
    vec3 position;
    vec3 velocity;
    float life;
};

layout(set = 0, binding = 0) buffer ParticlesIn {
    Particle particles[];
} particles_in;

layout(set = 0, binding = 1) buffer ParticlesOut {
    Particle particles[];
} particles_out;

uniform float dt;
uniform vec3 gravity;

void main() {
    uint idx = gl_GlobalInvocationID.x;
    
    Particle p = particles_in.particles[idx];
    
    // 更新物理
    p.velocity += gravity * dt;
    p.position += p.velocity * dt;
    p.life -= dt;
    
    particles_out.particles[idx] = p;
}

30.4 工作组大小优化

30.4.1 最佳工作组大小

python
# 查询设备支持的工作组大小
props = device.get_physical_device().get_properties()
max_work_group = props.maxComputeWorkGroupSize
max_work_group_invocations = props.maxComputeWorkGroupInvocations
max_work_group_size = props.maxComputeWorkGroupSize

print(f"Max work group: ({max_work_group})")
print(f"Max invocations: {max_work_group_invocations}")

# 推荐的工作组大小:
# local_size_x = 256/512 (向量/矩阵计算)
# local_size_x = 16, local_size_y = 16 (2D 计算)
# local_size_x = 64, local_size_y = 16 (纹理处理)

30.4.2 工作组大小选择

工作负载推荐工作组大小
向量计算local_size_x = 256
矩阵计算local_size_x = 16, local_size_y = 16
纹理处理local_size_x = 16, local_size_y = 16
粒子系统local_size_x = 256
图像滤波local_size_x = 16, local_size_y = 16
通用计算local_size_x = 64-256

30.5 内存屏障

30.5.1 工作组内屏障

glsl
// 工作组内屏障
barrier();  // 等待同组所有线程
memoryBarrierShared();  // 确保共享内存可见

30.5.2 全局内存屏障

glsl
// 确保全局内存写入完成
memoryBarrier();  // 所有共享内存
memoryBarrierBuffer();  // 确保 Buffer 写入完成
memoryBarrierImage();  // 确保 Image 写入完成

30.6 计算着色器 vs 图形着色器

特性ComputeGraphics
用途通用计算渲染
工作组可自定义固定
输入Buffer/ImageVertex/Fragment
输出Buffer/ImageFramebuffer
同步barrier/memoryBarrierPipelineBarrier
并行度1000+ 线程数千+ 线程
延迟
精度可高受显示限制

30.7 计算着色器速查

概念说明
local_size_x/y/z工作组大小
gl_GlobalInvocationID全局线程 ID
gl_LocalInvocationID工作组内线程 ID
gl_WorkGroupID工作组 ID
barrier()工作组内同步
memoryBarrier()全局内存屏障
shared工作组内共享内存
dispatch(x,y,z)执行计算着色器
groupMemoryBarrier()工作组间同步
sharedBarrier()工作组内屏障