implement ggml_soft_max_back for a more performant backward pass of soft_max

avoids creating big intermediate matrices of size n_embd x n_embd for llama layers and n_vocab x n_vocab for cross entropy loss
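For context: with y = soft_max(x), the Jacobian is J = diag(y) - y*y^T, so the backward pass needs dx = J*dy per row. That product can be formed from the two rows y and dy alone; a minimal standalone sketch of the folded form (hypothetical helper, not part of this commit; the kernel below keeps an explicit loop over J instead):

// hypothetical reference, not part of the commit:
// dx = (diag(y) - y*y^T) * dy folds into dx[k] = y[k]*(dy[k] - dot(y, dy)),
// so no n x n intermediate is ever needed
static void softmax_backward_row_ref(const float * y, const float * dy, float * dx, int n) {
    double dot = 0.0;
    for (int i = 0; i < n; i++) {
        dot += (double) y[i] * dy[i];
    }
    for (int k = 0; k < n; k++) {
        dx[k] = (float) (y[k] * ((double) dy[k] - dot));
    }
}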
parent f89c278d83
commit ec1aea09ec
2 changed files with 164 additions and 44 deletions
196 ggml.c
@@ -3325,6 +3325,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3338,7 +3339,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3385,6 +3386,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3398,7 +3400,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -5927,6 +5929,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
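A hedged usage sketch of the new constructors (sizes and setup are illustrative, not from the commit). Note the argument order used by the backward pass further down: the first tensor is the incoming gradient dy, the second is the softmax output y.

// illustrative only: build dx = soft_max_back(dy, y) in a ggml graph
struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
struct ggml_context * ctx = ggml_init(ip);

struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 32); // 32 rows of 512
struct ggml_tensor * y  = ggml_soft_max(ctx, x);                           // forward
struct ggml_tensor * dy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 32); // gradient w.r.t. y
struct ggml_tensor * dx = ggml_soft_max_back(ctx, dy, y);                  // gradient w.r.t. x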
@@ -10482,6 +10522,103 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y) - y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+
+        // quadratic runtime, linear memory
+        for (int k = 0; k < nc; k++) {
+            ggml_float sum = 0.0;
+
+            for (int i = 0; i < k; i++) {
+                float Jki = -y[k]*y[i];
+                sum += (ggml_float) Jki * dy[i];
+            }
+
+            float Jkk = y[k] - y[k]*y[k];
+            sum += (ggml_float) Jkk * dy[k];
+
+            for (int i = k+1; i < nc; i++) {
+                float Jki = -y[k]*y[i];
+                sum += (ggml_float) Jki * dy[i];
+            }
+
+            dx[k] = (float) sum;
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_alibi
 
 static void ggml_compute_forward_alibi_f32(
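The kernel above deliberately trades time for memory ("quadratic runtime, linear memory"). A small self-contained check (hypothetical, not part of the commit) that the explicit Jacobian loop agrees with the folded form sketched under the commit message:

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    enum { N = 8 };
    float y[N], dy[N], dx_loop[N], dx_fold[N];

    // make y a valid softmax output (positive, sums to 1) and dy arbitrary
    double norm = 0.0;
    for (int i = 0; i < N; i++) {
        y[i]  = 1.0f + (float) (rand() % 100);
        dy[i] = (float) (rand() % 100) / 50.0f - 1.0f;
        norm += y[i];
    }
    for (int i = 0; i < N; i++) {
        y[i] /= (float) norm;
    }

    // explicit Jacobian product, mirroring ggml_compute_forward_soft_max_back_f32
    for (int k = 0; k < N; k++) {
        double sum = 0.0;
        for (int i = 0; i < N; i++) {
            const double Jki = (i == k) ? y[k] - y[k]*y[k] : -y[k]*y[i];
            sum += Jki * dy[i];
        }
        dx_loop[k] = (float) sum;
    }

    // folded form: one dot product, then one scale per element
    double dot = 0.0;
    for (int i = 0; i < N; i++) {
        dot += (double) y[i] * dy[i];
    }
    for (int k = 0; k < N; k++) {
        dx_fold[k] = (float) (y[k] * (dy[k] - dot));
    }

    for (int k = 0; k < N; k++) {
        assert(fabs((double) dx_loop[k] - dx_fold[k]) < 1e-6);
    }
    printf("soft_max_back check ok\n");
    return 0;
}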
@@ -12529,6 +12666,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13146,50 +13287,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,                             // [ne0,ne1,ne2,ne3]
-                            ggml_reshape(ctx,                       // [ne0,ne1,ne2,ne3]
-                                ggml_mul_mat(ctx,                   // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,                   // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,              // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),               // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx,           // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,              // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),            // [1,ne0,ne1*ne2,ne3]
-                                    grad2),                         // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                            inplace);
                 }
 
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ROPE:
             {
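To put a number on the savings (illustrative, not from the commit): the removed graph materialized an [ne0, ne0, ne1*ne2, ne3] intermediate via ggml_diag and ggml_mul_mat. For the cross entropy loss with llama's n_vocab = 32000, that is a 32000 x 32000 f32 block per softmax row, roughly 32000 * 32000 * 4 bytes ≈ 4.1 GB, whereas ggml_soft_max_back only reads the already existing rows dy and y.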
@@ -13718,6 +13825,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                 case GGML_OP_SOFT_MAX:
+                case GGML_OP_SOFT_MAX_BACK:
                 case GGML_OP_ROPE:
                 case GGML_OP_ROPE_BACK:
                     {
12 ggml.h
@@ -307,6 +307,7 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
@@ -860,6 +861,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style