implement ggml_soft_max_back for more performant backward pass of soft_max
avoids creating big intermediate matrices of size n_embd x n_embd for llama layers and n_vocab x n_vocab for cross entropy loss
This commit is contained in:
parent
f89c278d83
commit
ec1aea09ec
2 changed files with 164 additions and 44 deletions
196
ggml.c
196
ggml.c
|
@ -3325,6 +3325,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
||||||
"DIAG_MASK_INF",
|
"DIAG_MASK_INF",
|
||||||
"DIAG_MASK_ZERO",
|
"DIAG_MASK_ZERO",
|
||||||
"SOFT_MAX",
|
"SOFT_MAX",
|
||||||
|
"SOFT_MAX_BACK",
|
||||||
"ROPE",
|
"ROPE",
|
||||||
"ROPE_BACK",
|
"ROPE_BACK",
|
||||||
"ALIBI",
|
"ALIBI",
|
||||||
|
@ -3338,7 +3339,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
||||||
"MAP_BINARY",
|
"MAP_BINARY",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
|
static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
@ -3385,6 +3386,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"diag_mask_inf(x)",
|
"diag_mask_inf(x)",
|
||||||
"diag_mask_zero(x)",
|
"diag_mask_zero(x)",
|
||||||
"soft_max(x)",
|
"soft_max(x)",
|
||||||
|
"soft_max_back(x)",
|
||||||
"rope(x)",
|
"rope(x)",
|
||||||
"rope_back(x)",
|
"rope_back(x)",
|
||||||
"alibi(x)",
|
"alibi(x)",
|
||||||
|
@ -3398,7 +3400,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"f(x,y)",
|
"f(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
|
static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
|
||||||
|
|
||||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||||
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
||||||
|
@ -5927,6 +5929,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
|
||||||
return ggml_soft_max_impl(ctx, a, true);
|
return ggml_soft_max_impl(ctx, a, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ggml_soft_max_back
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_soft_max_back_impl(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
bool inplace) {
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
if (a->grad || b->grad) {
|
||||||
|
is_node = true; // TODO : implement backward pass
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
|
result->op = GGML_OP_SOFT_MAX_BACK;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src0 = a;
|
||||||
|
result->src1 = b;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_soft_max_back(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b) {
|
||||||
|
return ggml_soft_max_back_impl(ctx, a, b, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b) {
|
||||||
|
return ggml_soft_max_back_impl(ctx, a, b, true);
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_rope
|
// ggml_rope
|
||||||
|
|
||||||
struct ggml_tensor * ggml_rope_impl(
|
struct ggml_tensor * ggml_rope_impl(
|
||||||
|
@ -10482,6 +10522,103 @@ static void ggml_compute_forward_soft_max(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_soft_max_back
|
||||||
|
|
||||||
|
static void ggml_compute_forward_soft_max_back_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(src1, dst));
|
||||||
|
|
||||||
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle transposed/permuted matrices
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src0->ne[0];
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||||
|
float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
|
||||||
|
float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int i = 0; i < nc; ++i) {
|
||||||
|
//printf("p[%d] = %f\n", i, p[i]);
|
||||||
|
assert(!isnan(s0[i]));
|
||||||
|
assert(!isnan(s1[i]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
// Jii = yi - yi*yi
|
||||||
|
// Jij = -yi*yj
|
||||||
|
// J = diag(y)-y.T*y
|
||||||
|
// dx = J * dy
|
||||||
|
// dxk = sum_i(Jki * dyi)
|
||||||
|
|
||||||
|
// quadratic runtime, linear memory
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
|
||||||
|
ggml_float sum = 0.0;
|
||||||
|
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
float Jki = -y[k]*y[i];
|
||||||
|
sum += (ggml_float) Jki * dy[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
float Jkk = y[k] - y[k]*y[k];
|
||||||
|
sum += (ggml_float) Jkk * dy[k];
|
||||||
|
|
||||||
|
for (int i = k+1; i < nc; i++) {
|
||||||
|
float Jki = -y[k]*y[i];
|
||||||
|
sum += (ggml_float) Jki * dy[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
dx[k] = (float) sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int i = 0; i < nc; ++i) {
|
||||||
|
assert(!isnan(dx[i]));
|
||||||
|
assert(!isinf(dx[i]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_soft_max_back(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_alibi
|
// ggml_compute_forward_alibi
|
||||||
|
|
||||||
static void ggml_compute_forward_alibi_f32(
|
static void ggml_compute_forward_alibi_f32(
|
||||||
|
@ -12529,6 +12666,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_max(params, tensor->src0, tensor);
|
ggml_compute_forward_soft_max(params, tensor->src0, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
||||||
|
@ -13146,50 +13287,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
{
|
{
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
// y = softmax(x)
|
|
||||||
//
|
|
||||||
// Jii = yi - yi*yi
|
|
||||||
// Jij = -yi*yj
|
|
||||||
// J = diag(y)-y.*y
|
|
||||||
// dx = J * dy
|
|
||||||
// dxk = sum(Jkj * dyk)
|
|
||||||
|
|
||||||
int64_t ne2[4] = {
|
|
||||||
tensor->ne[0],
|
|
||||||
1,
|
|
||||||
tensor->ne[1]*tensor->ne[2],
|
|
||||||
tensor->ne[3]
|
|
||||||
};
|
|
||||||
struct ggml_tensor * tensor2 = ggml_cont(ctx,
|
|
||||||
ggml_reshape_4d(ctx,
|
|
||||||
ggml_cont(ctx, tensor),
|
|
||||||
ne2[0], ne2[1], ne2[2], ne2[3]));
|
|
||||||
|
|
||||||
struct ggml_tensor * grad2 = ggml_cont(ctx,
|
|
||||||
ggml_reshape_4d(ctx,
|
|
||||||
ggml_cont(ctx, tensor->grad),
|
|
||||||
ne2[0], ne2[1], ne2[2], ne2[3]));
|
|
||||||
|
|
||||||
struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
|
|
||||||
ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
|
|
||||||
tensor2, // [ne0,1,ne1*ne2,ne3]
|
|
||||||
1, 0, 2, 3));
|
|
||||||
|
|
||||||
src0->grad =
|
src0->grad =
|
||||||
ggml_add_impl(ctx,
|
ggml_add_impl(ctx, src0->grad,
|
||||||
src0->grad, // [ne0,ne1,ne2,ne3]
|
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
||||||
ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
|
inplace);
|
||||||
ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
|
|
||||||
ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
|
||||||
ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
|
||||||
tensor2), // [ne0,1,ne1*ne2,ne3]
|
|
||||||
ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
|
|
||||||
tensor2_t, // [1,ne0,ne1*ne2,ne3]
|
|
||||||
tensor2_t)), // [1,ne0,ne1*ne2,ne3]
|
|
||||||
grad2), // [ne0,1,ne1*ne2,ne3]
|
|
||||||
src0->grad),
|
|
||||||
inplace);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} break;
|
||||||
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
|
@ -13718,6 +13825,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_ROPE_BACK:
|
case GGML_OP_ROPE_BACK:
|
||||||
{
|
{
|
||||||
|
|
12
ggml.h
12
ggml.h
|
@ -307,6 +307,7 @@ extern "C" {
|
||||||
GGML_OP_DIAG_MASK_INF,
|
GGML_OP_DIAG_MASK_INF,
|
||||||
GGML_OP_DIAG_MASK_ZERO,
|
GGML_OP_DIAG_MASK_ZERO,
|
||||||
GGML_OP_SOFT_MAX,
|
GGML_OP_SOFT_MAX,
|
||||||
|
GGML_OP_SOFT_MAX_BACK,
|
||||||
GGML_OP_ROPE,
|
GGML_OP_ROPE,
|
||||||
GGML_OP_ROPE_BACK,
|
GGML_OP_ROPE_BACK,
|
||||||
GGML_OP_ALIBI,
|
GGML_OP_ALIBI,
|
||||||
|
@ -860,6 +861,17 @@ extern "C" {
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
// in-place, returns view(a)
|
||||||
|
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// rotary position embedding
|
// rotary position embedding
|
||||||
// if mode & 1 == 1, skip n_past elements
|
// if mode & 1 == 1, skip n_past elements
|
||||||
// if mode & 2 == 1, GPT-NeoX style
|
// if mode & 2 == 1, GPT-NeoX style
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue