CUDA: backwards pass for misc. ops, add tests (#11257)
* CUDA: backwards pass for misc. ops, add tests
* remove restrict from pointers
This commit is contained in:
parent
681149ced2
commit
9c8dcefe17
18 changed files with 930 additions and 332 deletions
|
@ -3450,12 +3450,14 @@ struct ggml_tensor * ggml_soft_max_ext(
|
|||
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
||||
}
|
||||
|
||||
// ggml_soft_max_back
|
||||
// ggml_soft_max_ext_back
|
||||
|
||||
static struct ggml_tensor * ggml_soft_max_back_impl(
|
||||
static struct ggml_tensor * ggml_soft_max_ext_back_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
float scale,
|
||||
float max_bias,
|
||||
bool inplace) {
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
|
@ -3463,21 +3465,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
|
|||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
memcpy((float *) result->op_params + 0, &scale, sizeof(float));
|
||||
memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_back(
|
||||
struct ggml_tensor * ggml_soft_max_ext_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
return ggml_soft_max_back_impl(ctx, a, b, false);
|
||||
struct ggml_tensor * b,
|
||||
float scale,
|
||||
float max_bias) {
|
||||
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||
struct ggml_tensor * ggml_soft_max_ext_back_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
return ggml_soft_max_back_impl(ctx, a, b, true);
|
||||
struct ggml_tensor * b,
|
||||
float scale,
|
||||
float max_bias) {
|
||||
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
|
||||
}
|
||||
|
||||
// ggml_rope
|
||||
|
@ -5076,10 +5085,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c) {
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
GGML_ASSERT(ggml_is_scalar(c));
|
||||
GGML_ASSERT(ggml_is_scalar(a));
|
||||
GGML_ASSERT(ggml_are_same_shape(b, c));
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
|
||||
|
||||
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
|
||||
result->src[0] = a;
|
||||
|
@ -5258,7 +5267,7 @@ static void ggml_sub_or_set(
|
|||
}
|
||||
|
||||
static void ggml_compute_backward(
|
||||
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
|
||||
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
|
||||
struct ggml_tensor * tensor = cgraph->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
|
||||
|
||||
|
@ -5402,7 +5411,7 @@ static void ggml_compute_backward(
|
|||
if (src0_needs_grads) {
|
||||
float eps;
|
||||
memcpy(&eps, tensor->op_params, sizeof(float));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT: {
|
||||
|
@ -5585,7 +5594,13 @@ static void ggml_compute_backward(
|
|||
} break;
|
||||
case GGML_OP_SOFT_MAX: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
|
||||
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
|
||||
}
|
||||
GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
|
||||
} break;
|
||||
|
@ -5626,7 +5641,7 @@ static void ggml_compute_backward(
|
|||
const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
|
||||
const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
|
||||
|
||||
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
|
||||
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_POOL_2D: {
|
||||
|
@ -5669,7 +5684,7 @@ static void ggml_compute_backward(
|
|||
} break;
|
||||
case GGML_UNARY_OP_SILU: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXP: {
|
||||
|
@ -5686,7 +5701,7 @@ static void ggml_compute_backward(
|
|||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
|
||||
}
|
||||
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
|
||||
} break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue