Automatic differentiation¶
这一部分主要阐述前向与反向梯度传播、自定义类型上的自动微分实现、自定义导数的定义以及梯度如何传递到缓冲区。
考虑一个简单的函数
float square(float x, float y)
{
return x * x + y * y;
}
void main()
{
float x = 3.0f;
float y = 4.0f;
float result = square(x, y);
printf("Result: %f\n", result);
}
对该函数加上 [Differentiable] 属性标记:Slang 就能通过 fwd_diff 和 bwd_diff 内置函数生成并调用其导数版本。
前向模式用于计算函数输出对某个输入变量的导数。
- 使用
diffPair(primal, derivative)包装输入。 - 调用
fwd_diff(func)来执行前向传播。 - 结果也是一个
diffPair,其中.d是导数值。
void main() {
// x=3, ∂x/∂θ = 1;y=4, ∂y/∂γ = 0
let x = diffPair(3.0, 1.0);
let y = diffPair(4.0, 0.0);
let result = fwd_diff(square)(x, y);
printf("dResult: %f\n", result.d); // 输出 6.0 (∂square/∂x = 2*3)
}
反向模式用于从损失函数回传梯度到所有可训练参数(最常用在训练中)。
- 输入变量需声明为
var(因为梯度会写回)。 - 调用
bwd_diff(func)(..., dL/dOutput),其中最后一个参数是上游梯度。 - 梯度会自动写回到输入变量的
.d字段。
void main() {
var x = diffPair(3.0, 0.0); // 初始梯度为0
var y = diffPair(4.0, 0.0);
let dL_dSquare = 1.0f; // 假设损失对 square 的导数为1
bwd_diff(square)(x, y, dL_dSquare);
printf("dL/dx: %f, dL/dy: %f\n", x.d, y.d); // 输出 6.0, 8.0
}
默认情况下,结构体(struct)不是可微的。要使其支持 AutoDiff,需实现 IDifferentiable 接口。
只需继承 IDifferentiable,编译器会自动生成所需方法:
struct Point : IDifferentiable {
float x;
float y;
}
[Differentiable]
float square(Point p) {
return p.x * p.x + p.y * p.y;
}
half, float, double
vector and matrix of floating point types
T[] (if T is differentiable) 这些类型都是可微的
有时自动微分效率低或不适用(如包含不可导操作),可手动提供导数实现。
前向自定义导数:
[ForwardDerivativeOf(square)]
DifferentialPair<float> squareFwd(
DifferentialPair<float> x,
DifferentialPair<float> y
) {
return diffPair(
square(x.p, y.p),
2.0f * (x.p * x.d + y.p * y.d)
);
}
反向自定义导数:
[BackwardDerivativeOf(square)]
void squareBwd(
inout DifferentialPair<float> x,
inout DifferentialPair<float> y,
float dOut // 上游梯度 dL/d(square)
) {
x.d += 2.0f * x.p * dOut;
y.d += 2.0f * y.p * dOut;
}
当函数从缓冲区(如 RWStructuredBuffer)读取数据时,需手动处理梯度写回。
- 将缓冲区访问封装为函数。
- 为该函数提供
[BackwardDerivativeOf(...)]实现。 - 在反向函数中将梯度累加到梯度缓冲区(通常用原子操作)。
RWStructuredBuffer<float> paramBuffer;
RWStructuredBuffer<Atomic<float>> gradBuffer;
float getParam(int idx) {
return paramBuffer[idx];
}
[BackwardDerivativeOf(getParam)]
void getParamBwd(int idx, float dOut) {
gradBuffer[idx] += dOut; // 原子累加梯度
}
[Differentiable]
float compute(int idx) {
float w = getParam(idx);
return w * w;
}