CUDA: backwards pass for misc. ops, add tests (#11257)

* CUDA: backwards pass for misc. ops, add tests

* remove restrict from pointers
This commit is contained in:
Johannes Gäßler 2025-01-16 16:43:38 +01:00 committed by GitHub
parent 681149ced2
commit 9c8dcefe17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 930 additions and 332 deletions

View file

@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];
const struct ggml_tensor * grad = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
assert(ggml_is_contiguous_1(grad));
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(src1));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
assert(ggml_are_same_shape(src0, grad));
assert(ggml_are_same_shape(src1, dst));
assert(ggml_are_same_shape(src1, grad));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
const int nc = src1->ne[0];
const int nr = ggml_nrows(src1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])),
(float *) ((char *) src1->data + i1*(src1->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1])));
#ifndef NDEBUG
@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps > 0.0f);
GGML_ASSERT(eps >= 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
const int64_t i12 = i02;
const int64_t i13 = i03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_float sum_xx = 0.0;
ggml_float sum_xdz = 0.0;
@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
{
// z = rms_norm(x)
//
// rms_norm(src0) =
// rms_norm(src1) =
// scale(
// src0,
// src1,
// div(
// 1,
// sqrt(
@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
// scale(
// sum(
// sqr(
// src0)),
// src1)),
// (1.0/N)),
// eps))));
// postorder:
// ## op args grad
// 00 param src0 grad[#00]
// 00 param src1 grad[#00]
// 01 const 1
// 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03]
@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne2 % ne02 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float));
@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16;
// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
}
// ggml_compute_forward_soft_max_back
// ggml_compute_forward_soft_max_ext_back
static void ggml_compute_forward_soft_max_back_f32(
static void ggml_compute_forward_soft_max_ext_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
GGML_ASSERT(max_bias == 0.0f);
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(
// linear runtime, no additional memory
float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_scale_f32(nc, dx, scale);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
}
}
static void ggml_compute_forward_soft_max_back(
static void ggml_compute_forward_soft_max_ext_back(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_soft_max_back_f32(params, dst);
ggml_compute_forward_soft_max_ext_back_f32(params, dst);
} break;
default:
{
@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
const int64_t IH = is_2D ? ne1 : 1;
const int64_t IW = ne0;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
const int64_t KH = is_2D ? ne11 : 1;
const int64_t KW = ne10;
const int64_t OH = is_2D ? ne12 : 1;
const int64_t OW = ne11;
const int64_t OH = is_2D ? ne02 : 1;
const int64_t OW = ne01;
int ofs0 = is_2D ? nb3 : nb2;
int ofs1 = is_2D ? nb2 : nb1;
@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
continue;
}
const float * const src_data = (const float *) src1->data
const float * const grad_in = (const float *) src0->data
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
}
}
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * opt0 = dst->src[2];
const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(opt0));
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(src0f));
GGML_ASSERT(ggml_is_contiguous(src1f));
GGML_ASSERT(ggml_is_contiguous(grad));
GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
const int64_t ith = params->ith;
const int64_t nth = params->nth;
// TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0];
const int64_t nr = ggml_nrows(src0);
const int64_t nc = src0f->ne[0];
const int64_t nr = ggml_nrows(src0f);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
#ifndef NDEBUG
for (int64_t i = 0; i < nc; ++i) {
@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
// soft_max
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, s0);
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
assert(sum > 0.0);
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
// grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
ggml_vec_sub_f32(nc, ds0, ds0, s1);
ggml_vec_scale_f32(nc, ds0, d_by_nr);
@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX_BACK:
{
ggml_compute_forward_soft_max_back(params, tensor);
ggml_compute_forward_soft_max_ext_back(params, tensor);
} break;
case GGML_OP_ROPE:
{