diff --git a/ggml.c b/ggml.c
index 3e17dfeb6..2cc51fcc0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3325,6 +3325,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3338,7 +3339,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3385,6 +3386,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3398,7 +3400,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -5927,6 +5929,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO: implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -10482,6 +10522,103 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("dy[%d] = %f\n", i, dy[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+
+        // quadratic runtime, linear memory
+        for (int k = 0; k < nc; k++) {
+
+            ggml_float sum = 0.0;
+
+            for (int i = 0; i < k; i++) {
+                float Jki = -y[k]*y[i];
+                sum += (ggml_float) Jki * dy[i];
+            }
+
+            float Jkk = y[k] - y[k]*y[k];
+            sum += (ggml_float) Jkk * dy[k];
+
+            for (int i = k+1; i < nc; i++) {
+                float Jki = -y[k]*y[i];
+                sum += (ggml_float) Jki * dy[i];
+            }
+
+            dx[k] = (float) sum;
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_alibi
 
 static void ggml_compute_forward_alibi_f32(
@@ -12529,6 +12666,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13146,50 +13287,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,                   // [ne0,ne1,ne2,ne3]
-                            ggml_reshape(ctx,             // [ne0,ne1,ne2,ne3]
-                                ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),     // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,    // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
-                                    grad2),               // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                            inplace);
                 }
+
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ROPE:
             {
@@ -13718,6 +13825,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
+                case GGML_OP_SOFT_MAX_BACK:
                 case GGML_OP_ROPE:
                 case GGML_OP_ROPE_BACK:
                     {
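A note on the kernel above: ggml_compute_forward_soft_max_back_f32 applies the softmax Jacobian (Jii = yi - yi*yi, Jij = -yi*yj) to each row's incoming gradient dy with an explicit O(n^2) loop, splitting out the diagonal term so the two inner loops stay branch-free. Because J*dy = y .* dy - y * dot(y, dy), the same Jacobian-vector product can also be formed in O(n) per row. The following standalone sketch (illustration only, not part of the patch; plain C with no ggml dependency, variable names are my own) checks the kernel's quadratic form against that linear closed form:

// sketch: compare the O(n^2) Jacobian loop from the kernel above
// with the O(n) closed form dx = y .* (dy - dot(y, dy))
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define N 8

int main(void) {
    float y[N], dy[N], dx_quad[N], dx_lin[N];

    // build a valid softmax output y (positive, sums to 1) and a random dy
    float norm = 0.0f;
    for (int i = 0; i < N; i++) {
        y[i]  = expf((float) rand() / (float) RAND_MAX);
        dy[i] = (float) rand() / (float) RAND_MAX - 0.5f;
        norm += y[i];
    }
    for (int i = 0; i < N; i++) {
        y[i] /= norm;
    }

    // quadratic form, as in the kernel: dx_k = sum_i J_ki * dy_i
    for (int k = 0; k < N; k++) {
        double sum = 0.0;
        for (int i = 0; i < N; i++) {
            const float Jki = (i == k) ? y[k] - y[k]*y[k] : -y[k]*y[i];
            sum += (double) Jki * dy[i];
        }
        dx_quad[k] = (float) sum;
    }

    // linear form: J*dy = y .* dy - y * dot(y, dy)
    double dot = 0.0;
    for (int i = 0; i < N; i++) {
        dot += (double) y[i] * dy[i];
    }
    for (int k = 0; k < N; k++) {
        dx_lin[k] = y[k] * (dy[k] - (float) dot);
    }

    for (int k = 0; k < N; k++) {
        assert(fabsf(dx_quad[k] - dx_lin[k]) < 1e-5f);
    }
    printf("quadratic and linear soft_max_back agree\n");
    return 0;
}

The two forms are numerically equivalent up to rounding; the quadratic loop in the kernel mirrors the Jacobian derivation in its comments line by line, while the linear form would be an O(n) alternative per row.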
diff --git a/ggml.h b/ggml.h
index 967ef72d0..0a0989516 100644
--- a/ggml.h
+++ b/ggml.h
@@ -307,6 +307,7 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
@@ -860,6 +861,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
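The backward product wired up in ggml_compute_backward can also be sanity-checked against the forward softmax with a finite-difference gradient check: with L(x) = dot(dy, softmax(x)) and J symmetric, the analytic directional derivative dot(dx, v) should match the central difference (L(x + eps*v) - L(x - eps*v)) / (2*eps) for any direction v. A minimal standalone sketch follows (illustration only, not part of the patch; helper names are hypothetical, and dx is computed with the O(n) form shown above to agree with the kernel):

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define N 6

// forward: y = softmax(x), with the usual max-subtraction for stability
static void softmax(const float * x, float * y) {
    float mx = x[0];
    for (int i = 1; i < N; i++) {
        if (x[i] > mx) mx = x[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        y[i] = expf(x[i] - mx);
        sum += y[i];
    }
    for (int i = 0; i < N; i++) {
        y[i] /= sum;
    }
}

// backward: dx = (diag(y) - y*y^T) * dy, written in the O(n) form
static void softmax_back(const float * y, const float * dy, float * dx) {
    double dot = 0.0;
    for (int i = 0; i < N; i++) {
        dot += (double) y[i] * dy[i];
    }
    for (int k = 0; k < N; k++) {
        dx[k] = y[k] * (dy[k] - (float) dot);
    }
}

int main(void) {
    float x[N], dy[N], v[N], y[N], dx[N];
    for (int i = 0; i < N; i++) {
        x[i]  = (float) rand() / (float) RAND_MAX - 0.5f;
        dy[i] = (float) rand() / (float) RAND_MAX - 0.5f;
        v[i]  = (float) rand() / (float) RAND_MAX - 0.5f;
    }

    softmax(x, y);
    softmax_back(y, dy, dx);

    // analytic directional derivative: dot(dx, v)
    double analytic = 0.0;
    for (int i = 0; i < N; i++) {
        analytic += (double) dx[i] * v[i];
    }

    // numeric: central difference of L(x) = dot(dy, softmax(x)) along v
    const float eps = 1e-3f;
    float xp[N], xm[N], yp[N], ym[N];
    for (int i = 0; i < N; i++) {
        xp[i] = x[i] + eps*v[i];
        xm[i] = x[i] - eps*v[i];
    }
    softmax(xp, yp);
    softmax(xm, ym);
    double lp = 0.0, lm = 0.0;
    for (int i = 0; i < N; i++) {
        lp += (double) dy[i] * yp[i];
        lm += (double) dy[i] * ym[i];
    }
    const double numeric = (lp - lm) / (2.0 * eps);

    // loose tolerance to absorb float roundoff in the forward passes
    assert(fabs(analytic - numeric) < 1e-3);
    printf("directional derivative: analytic %g, numeric %g\n", analytic, numeric);
    return 0;
}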