跳转至

Automatic differentiation

这一部分主要阐述前向与反向梯度传播、自定义类型上的自动微分实现、自定义导数的定义以及梯度如何传递到缓冲区。

考虑一个简单的函数

float square(float x, float y)
{
    return x * x + y * y;
}
void main()
{
    float x = 3.0f;
    float y = 4.0f;
    float result = square(x, y);
    printf("Result: %f\n", result);
}

对该函数加上 [Differentiable] 属性标记:Slang 就能通过 fwd_diffbwd_diff 内置函数生成并调用其导数版本。


前向模式用于计算函数输出对某个输入变量的导数。

  • 使用 diffPair(primal, derivative) 包装输入。
  • 调用 fwd_diff(func) 来执行前向传播。
  • 结果也是一个 diffPair,其中 .d 是导数值。
void main() {
    // x=3, ∂x/∂θ = 1;y=4, ∂y/∂γ = 0
    let x = diffPair(3.0, 1.0);
    let y = diffPair(4.0, 0.0);

    let result = fwd_diff(square)(x, y);
    printf("dResult: %f\n", result.d); // 输出 6.0 (∂square/∂x = 2*3)
}

反向模式用于从损失函数回传梯度到所有可训练参数(最常用在训练中)。

  • 输入变量需声明为 var(因为梯度会写回)。
  • 调用 bwd_diff(func)(..., dL/dOutput),其中最后一个参数是上游梯度。
  • 梯度会自动写回到输入变量的 .d 字段。
void main() {
    var x = diffPair(3.0, 0.0); // 初始梯度为0
    var y = diffPair(4.0, 0.0);

    let dL_dSquare = 1.0f; // 假设损失对 square 的导数为1
    bwd_diff(square)(x, y, dL_dSquare);

    printf("dL/dx: %f, dL/dy: %f\n", x.d, y.d); // 输出 6.0, 8.0
}

默认情况下,结构体(struct不是可微的。要使其支持 AutoDiff,需实现 IDifferentiable 接口。

只需继承 IDifferentiable,编译器会自动生成所需方法:

struct Point : IDifferentiable {
    float x;
    float y;
}

[Differentiable]
float square(Point p) {
    return p.x * p.x + p.y * p.y;
}

half, float, double

vector and matrix of floating point types

T[] (if T is differentiable) 这些类型都是可微的


有时自动微分效率低或不适用(如包含不可导操作),可手动提供导数实现。

前向自定义导数:

[ForwardDerivativeOf(square)]
DifferentialPair<float> squareFwd(
    DifferentialPair<float> x,
    DifferentialPair<float> y
) {
    return diffPair(
        square(x.p, y.p),
        2.0f * (x.p * x.d + y.p * y.d)
    );
}

反向自定义导数:

[BackwardDerivativeOf(square)]
void squareBwd(
    inout DifferentialPair<float> x,
    inout DifferentialPair<float> y,
    float dOut  // 上游梯度 dL/d(square)
) {
    x.d += 2.0f * x.p * dOut;
    y.d += 2.0f * y.p * dOut;
}

当函数从缓冲区(如 RWStructuredBuffer)读取数据时,需手动处理梯度写回。

  • 将缓冲区访问封装为函数。
  • 为该函数提供 [BackwardDerivativeOf(...)] 实现。
  • 在反向函数中将梯度累加到梯度缓冲区(通常用原子操作)。
RWStructuredBuffer<float> paramBuffer;
RWStructuredBuffer<Atomic<float>> gradBuffer;

float getParam(int idx) {
    return paramBuffer[idx];
}

[BackwardDerivativeOf(getParam)]
void getParamBwd(int idx, float dOut) {
    gradBuffer[idx] += dOut; // 原子累加梯度
}

[Differentiable]
float compute(int idx) {
    float w = getParam(idx);
    return w * w;
}

评论区

对你有帮助的话请给我个赞和 star => GitHub stars
欢迎跟我探讨!!!