From 1a80e9a0faf95102a49337021e6b1640a2291378 Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 27 Apr 2023 00:13:43 +0200 Subject: [PATCH] correctly implement softmax backward pass using new operation ggml_diag ggml_diag constructs diagonal matrices with entries. ggml_diag(shape[a,1,c,d]) -> shape[a,a,c,d] --- ggml.c | 146 +++++++++++++++++++++++++++++++++++++++++++++++++++------ ggml.h | 5 ++ 2 files changed, 137 insertions(+), 14 deletions(-) diff --git a/ggml.c b/ggml.c index 9894e13c5..5e0725931 100644 --- a/ggml.c +++ b/ggml.c @@ -3991,6 +3991,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "TRANSPOSE", "GET_ROWS", "GET_ROWS_BACK", + "DIAG", "DIAG_MASK_INF", "DIAG_MASK_ZERO", "SOFT_MAX", @@ -4007,7 +4008,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "MAP_BINARY", }; -static_assert(GGML_OP_COUNT == 45, "GGML_OP_COUNT != 45"); +static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4047,6 +4048,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "transpose(x)", "get_rows(x)", "get_rows_back(x)", + "diag(x)", "diag_mask_inf(x)", "diag_mask_zero(x)", "soft_max(x)", @@ -4063,7 +4065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "f(x,y)", }; -static_assert(GGML_OP_COUNT == 45, "GGML_OP_COUNT != 45"); +static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -6175,6 +6177,30 @@ struct ggml_tensor * ggml_get_rows_back( return result; } +// ggml_diag: a[n,1,c,d] -> diagonal matrices [n,n,c,d] (requires a->ne[1] == 1; result has at least 2 dims) + +struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(a->ne[1] == 1); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct 
ggml_tensor * result = ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); + + result->op = GGML_OP_DIAG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + + // ggml_diag_mask_inf struct ggml_tensor * ggml_diag_mask_inf_impl( @@ -10269,6 +10295,79 @@ static void ggml_compute_forward_get_rows_back( //} } +// ggml_compute_forward_diag: dst[i0,i1,i2,i3] = (i0 == i1) ? src0[i0,0,i2,i3] : 0 (f32 only, single task) + +static void ggml_compute_forward_diag_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + assert(ne00 == ne0); + assert(ne00 == ne1); + assert(ne01 == 1); + assert(ne02 == ne2); + assert(ne03 == ne3); + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + assert(nb00 == sizeof(float)); + assert(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void ggml_compute_forward_diag( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + 
case GGML_TYPE_F32: + { + ggml_compute_forward_diag_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_diag_mask_inf static void ggml_compute_forward_diag_mask_f32( @@ -10392,7 +10491,7 @@ static void ggml_compute_forward_soft_max_f32( if (sp[i] == -INFINITY) { dp[i] = 0.0f; } else { - //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); + // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); @@ -12443,6 +12542,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_DIAG: + { + ggml_compute_forward_diag(params, tensor->src0, tensor); + } break; case GGML_OP_DIAG_MASK_INF: { ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor); @@ -12906,6 +13009,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // noop } } break; + case GGML_OP_DIAG: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_DIAG_MASK_INF: { // necessary for llama @@ -12943,20 +13050,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { // y = softmax(x) - // dx = dy * y - sum(dy * y) * y - // dx = y * (dy - sum(dy * y)) + // + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y*y^T (outer product, not element-wise) + // dx = J * dy + // dxk = sum_j(Jkj * dyj) + + struct ggml_tensor * tensor_t = ggml_cont(ctx, + ggml_permute(ctx, + ggml_reshape(ctx, + tensor, + ggml_new_tensor(ctx, + tensor->type, + 4, tensor->ne)), + 1, 0, 2, 3)); + src0->grad = ggml_add_impl(ctx, src0->grad, - ggml_mul(ctx, - tensor, - ggml_add1(ctx, - tensor->grad, - ggml_neg(ctx, - ggml_sum(ctx, - ggml_mul(ctx, - tensor->grad, - tensor))))), + 
ggml_mul_mat(ctx, + ggml_sub(ctx, + ggml_diag(ctx, tensor), + ggml_mul_mat(ctx, tensor_t, tensor_t)), + tensor->grad), inplace); } } break; @@ -13480,6 +13597,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_TRANSPOSE: case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: case GGML_OP_DIAG_MASK_INF: { node->n_tasks = 1; diff --git a/ggml.h b/ggml.h index 1677ea533..e93c6bfac 100644 --- a/ggml.h +++ b/ggml.h @@ -285,6 +285,7 @@ extern "C" { GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_GET_ROWS_BACK, + GGML_OP_DIAG, GGML_OP_DIAG_MASK_INF, GGML_OP_DIAG_MASK_ZERO, GGML_OP_SOFT_MAX, @@ -700,6 +701,10 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a); + // set elements above the diagonal to -INF GGML_API struct ggml_tensor * ggml_diag_mask_inf( struct ggml_context * ctx,