implement ggml_soft_max_back for more performant backward pass of soft_max

avoids creating big intermediate matrices of size n_embd x n_embd for llama layers and n_vocab x n_vocab for cross entropy loss
xaedes 2023-05-14 17:16:26 +02:00
parent f89c278d83
commit ec1aea09ec
2 changed files with 164 additions and 44 deletions
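
For reference, the math behind the fused op: for y = softmax(x), the Jacobian of y with respect to x and its product with the incoming gradient dy are

$$J_{ij} = \frac{\partial y_i}{\partial x_j} = y_i(\delta_{ij} - y_j) \quad\Longleftrightarrow\quad J = \mathrm{diag}(y) - y\,y^{\top}$$

$$dx_k = \sum_i J_{ki}\,dy_i = y_k\,dy_k - y_k\sum_i y_i\,dy_i$$

so each row of dx follows from y and dy alone. Materializing J costs O(n^2) memory per row (with n = n_vocab for the cross-entropy loss), while the fused op only ever holds the O(n) vectors y, dy, and dx.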

ggml.c (196 changed lines)

@@ -3325,6 +3325,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"DIAG_MASK_INF",
"DIAG_MASK_ZERO",
"SOFT_MAX",
"SOFT_MAX_BACK",
"ROPE",
"ROPE_BACK",
"ALIBI",
@@ -3338,7 +3339,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3385,6 +3386,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"diag_mask_inf(x)",
"diag_mask_zero(x)",
"soft_max(x)",
"soft_max_back(x)",
"rope(x)",
"rope_back(x)",
"alibi(x)",
@@ -3398,7 +3400,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -5927,6 +5929,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
return ggml_soft_max_impl(ctx, a, true);
}
// ggml_soft_max_back
struct ggml_tensor * ggml_soft_max_back_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool inplace) {
bool is_node = false;
if (a->grad || b->grad) {
is_node = true; // TODO: implement backward pass
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_SOFT_MAX_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
return result;
}
struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, false);
}
struct ggml_tensor * ggml_soft_max_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, true);
}
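
A minimal usage sketch of the new API (an illustrative standalone program, not part of this commit; per the backward pass further down, the first tensor argument is the incoming gradient dy and the second is the soft_max output y):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * dy = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x,  0.5f); // dummy input
    ggml_set_f32(dy, 1.0f); // dummy upstream gradient

    struct ggml_tensor * y  = ggml_soft_max(ctx, x);          // forward: y = softmax(x)
    struct ggml_tensor * dx = ggml_soft_max_back(ctx, dy, y); // fused backward: dx = (diag(y) - y*y^T)*dy

    struct ggml_cgraph gf = ggml_build_forward(dx);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}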
// ggml_rope
struct ggml_tensor * ggml_rope_impl(
@@ -10482,6 +10522,103 @@ static void ggml_compute_forward_soft_max(
}
}
// ggml_compute_forward_soft_max_back
static void ggml_compute_forward_soft_max_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("dy[%d] = %f\n", i, dy[i]);
assert(!isnan(dy[i]));
assert(!isnan(y[i]));
}
#endif
// Jii = yi - yi*yi
// Jij = -yi*yj
// J = diag(y)-y.T*y
// dx = J * dy
// dxk = sum_i(Jki * dyi)
// quadratic runtime, linear memory
for (int k = 0; k < nc; k++) {
ggml_float sum = 0.0;
for (int i = 0; i < k; i++) {
float Jki = -y[k]*y[i];
sum += (ggml_float) Jki * dy[i];
}
float Jkk = y[k] - y[k]*y[k];
sum += (ggml_float) Jkk * dy[k];
for (int i = k+1; i < nc; i++) {
float Jki = -y[k]*y[i];
sum += (ggml_float) Jki * dy[i];
}
dx[k] = (float) sum;
}
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(dx[i]));
assert(!isinf(dx[i]));
}
#endif
}
}
static void ggml_compute_forward_soft_max_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
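
The inner double loop above applies J = diag(y) - y*y^T row by row without ever storing J: quadratic runtime, linear memory, exactly as the comment says. Algebraically the product collapses further to dx_k = y_k*(dy_k - dot(y, dy)); a sketch of that equivalent linear-time form per row (an alternative formulation, not what this commit ships):

static void soft_max_back_row_linear(const int nc, const float * dy, const float * y, float * dx) {
    // dot = sum_i y[i]*dy[i], accumulated in ggml_float like the kernel above
    ggml_float dot = 0.0;
    for (int i = 0; i < nc; i++) {
        dot += (ggml_float) y[i] * dy[i];
    }
    // dx_k = y_k*(dy_k - dot) == sum_i (diag(y) - y*y^T)_{ki} * dy_i
    for (int k = 0; k < nc; k++) {
        dx[k] = (float) ((ggml_float) y[k] * ((ggml_float) dy[k] - dot));
    }
}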
// ggml_compute_forward_alibi
static void ggml_compute_forward_alibi_f32(
@@ -12529,6 +12666,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_soft_max(params, tensor->src0, tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
} break;
case GGML_OP_ROPE:
{
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13146,50 +13287,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
-// y = softmax(x)
-//
-// Jii = yi - yi*yi
-// Jij = -yi*yj
-// J = diag(y)-y.*y
-// dx = J * dy
-// dxk = sum(Jkj * dyk)
-int64_t ne2[4] = {
-tensor->ne[0],
-1,
-tensor->ne[1]*tensor->ne[2],
-tensor->ne[3]
-};
-struct ggml_tensor * tensor2 = ggml_cont(ctx,
-ggml_reshape_4d(ctx,
-ggml_cont(ctx, tensor),
-ne2[0], ne2[1], ne2[2], ne2[3]));
-struct ggml_tensor * grad2 = ggml_cont(ctx,
-ggml_reshape_4d(ctx,
-ggml_cont(ctx, tensor->grad),
-ne2[0], ne2[1], ne2[2], ne2[3]));
-struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
-tensor2, // [ne0,1,ne1*ne2,ne3]
-1, 0, 2, 3));
src0->grad =
-ggml_add_impl(ctx,
-src0->grad, // [ne0,ne1,ne2,ne3]
-ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
-ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
-ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
-ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
-tensor2), // [ne0,1,ne1*ne2,ne3]
-ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
-tensor2_t, // [1,ne0,ne1*ne2,ne3]
-tensor2_t)), // [1,ne0,ne1*ne2,ne3]
-grad2), // [ne0,1,ne1*ne2,ne3]
-src0->grad),
-inplace);
+ggml_add_impl(ctx, src0->grad,
+ggml_soft_max_back(ctx, tensor->grad, tensor),
+inplace);
}
} break;
+case GGML_OP_SOFT_MAX_BACK:
+{
+GGML_ASSERT(false); // TODO: not implemented
+} break;
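
To check that the fused op agrees with the explicit-Jacobian construction it replaces, here is a small standalone comparison on a single row (a hypothetical test, not part of this commit):

#include <stdio.h>

int main(void) {
    // y is assumed to be an already-normalized soft_max output; dy is the upstream gradient
    const float y[3]  = {0.10f, 0.30f, 0.60f};
    const float dy[3] = {1.0f, -2.0f, 0.5f};

    double dot = 0.0; // <y, dy>, shared by all entries in the fused form
    for (int i = 0; i < 3; i++) {
        dot += (double) y[i] * dy[i];
    }

    for (int k = 0; k < 3; k++) {
        // old path: build the Jacobian row J_ki = y_k*(delta_ki - y_i) explicitly
        double dx_explicit = 0.0;
        for (int i = 0; i < 3; i++) {
            dx_explicit += (double) y[k] * ((i == k ? 1.0 : 0.0) - y[i]) * dy[i];
        }
        // fused path: dx_k = y_k*(dy_k - <y, dy>)
        const double dx_fused = (double) y[k] * (dy[k] - dot);
        printf("dx[%d]: explicit=%.6f fused=%.6f\n", k, dx_explicit, dx_fused);
    }
    return 0;
}

Both paths print identical values (up to rounding).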
case GGML_OP_ROPE:
{
@@ -13718,6 +13825,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{

ggml.h (12 changed lines)

@@ -307,6 +307,7 @@ extern "C" {
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
GGML_OP_SOFT_MAX,
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_BACK,
GGML_OP_ALIBI,
@@ -860,6 +861,17 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
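// backward pass of soft_max: a is the gradient of the soft_max output (dy),
// b is the soft_max output itself (y); argument roles assumed from the
// ggml_soft_max_back(ctx, tensor->grad, tensor) call site in ggml.c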
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// rotary position embedding
// if mode & 1 == 1, skip n_past elements
// if mode & 2 == 1, GPT-NeoX style