correctly implement softmax backward pass using new operation ggml_diag

ggml_diag constructs diagonal matrices from the entries of its input:
ggml_diag(shape[a,1,c,d]) -> shape[a,a,c,d]
xaedes 2023-04-27 00:13:43 +02:00
parent 54ab300cc4
commit 1a80e9a0fa
2 changed files with 137 additions and 14 deletions
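
Before the diff itself, here is a minimal sketch of what the new op computes. The context and graph calls follow the usual ggml API of this period; the 16 MB scratch size is an arbitrary choice for the sketch, not something taken from this commit.

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /* .mem_size   = */ 16*1024*1024,  // arbitrary scratch size for this sketch
            /* .mem_buffer = */ NULL,
            /* .no_alloc   = */ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // a 3-element vector: ne = [3, 1, 1, 1]
        struct ggml_tensor * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3);
        ((float *) v->data)[0] = 1.0f;
        ((float *) v->data)[1] = 2.0f;
        ((float *) v->data)[2] = 3.0f;

        struct ggml_tensor * d = ggml_diag(ctx, v);  // ne = [3, 3, 1, 1]

        struct ggml_cgraph gf = ggml_build_forward(d);
        ggml_graph_compute(ctx, &gf);

        // d->data now holds, row by row:
        //   1 0 0
        //   0 2 0
        //   0 0 3

        ggml_free(ctx);
        return 0;
    }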

ggml.c (146 changes)

@@ -3991,6 +3991,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"TRANSPOSE",
"GET_ROWS",
"GET_ROWS_BACK",
"DIAG",
"DIAG_MASK_INF",
"DIAG_MASK_ZERO",
"SOFT_MAX",
@@ -4007,7 +4008,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};
-static_assert(GGML_OP_COUNT == 45, "GGML_OP_COUNT != 45");
+static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4047,6 +4048,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"transpose(x)",
"get_rows(x)",
"get_rows_back(x)",
"diag(x)",
"diag_mask_inf(x)",
"diag_mask_zero(x)",
"soft_max(x)",
@@ -4063,7 +4065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
-static_assert(GGML_OP_COUNT == 45, "GGML_OP_COUNT != 45");
+static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -6175,6 +6177,30 @@ struct ggml_tensor * ggml_get_rows_back(
return result;
}
// ggml_diag
struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(a->ne[1] == 1);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne);
result->op = GGML_OP_DIAG;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL;
return result;
}
// ggml_diag_mask_inf
struct ggml_tensor * ggml_diag_mask_inf_impl(
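
The constructor asserts ne[1] == 1 and passes the outer (batch) dimensions through unchanged, which is exactly the [a,1,c,d] -> [a,a,c,d] contract from the commit message. A small sketch of that contract, reusing a ctx like the one in the sketch above (assert.h assumed):

    #include <assert.h>

    // shape contract: [a,1,c,d] -> [a,a,c,d]
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 1, 2, 3);
    struct ggml_tensor * y = ggml_diag(ctx, x);
    assert(y->ne[0] == 5 && y->ne[1] == 5);
    assert(y->ne[2] == 2 && y->ne[3] == 3);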
@@ -10269,6 +10295,79 @@ static void ggml_compute_forward_get_rows_back(
//}
}
// ggml_compute_forward_diag
static void ggml_compute_forward_diag_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
assert(ne00 == ne0);
assert(ne00 == ne1);
assert(ne01 == 1);
assert(ne02 == ne2);
assert(ne03 == ne3);
const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1];
const int nb02 = src0->nb[2];
const int nb03 = src0->nb[3];
const int nb0 = dst->nb[0];
const int nb1 = dst->nb[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
assert(nb00 == sizeof(float));
assert(nb0 == sizeof(float));
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = 0; i2 < ne2; i2++) {
for (int i1 = 0; i1 < ne1; i1++) {
float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
for (int i0 = 0; i0 < i1; i0++) {
d[i0] = 0;
}
d[i1] = s[i1];
for (int i0 = i1+1; i0 < ne0; i0++) {
d[i0] = 0;
}
}
}
}
}
static void ggml_compute_forward_diag(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_diag_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_diag_mask_inf
static void ggml_compute_forward_diag_mask_f32(
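
The kernel writes one full output row per (i3, i2, i1) index: zeros left of the diagonal, the source entry where i0 == i1, zeros to the right. Stripped of the stride arithmetic, the per-matrix logic is equivalent to this sketch (diag_slice_f32 is an illustrative name, not a function from the commit):

    // dst is an n x n matrix written from an n-element vector src
    static void diag_slice_f32(const float * src, float * dst, int n) {
        for (int i1 = 0; i1 < n; i1++) {
            for (int i0 = 0; i0 < n; i0++) {
                dst[i1*n + i0] = (i0 == i1) ? src[i1] : 0.0f;
            }
        }
    }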
@@ -10392,7 +10491,7 @@ static void ggml_compute_forward_soft_max_f32(
if (sp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
-//const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max);
+// const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
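
For context on the table lookup in the unchanged lines above: ggml precomputes expf for all 65536 fp16 bit patterns at init time, so the softmax inner loop costs one fp16 conversion, one memcpy of the bits into a uint16_t index, and one table read. The init-time fill looks roughly like this (a sketch from memory of ggml_init, not part of this diff):

    // table_exp_f16 has 1 << 16 entries, one per fp16 bit pattern
    for (int i = 0; i < (1 << 16); i++) {
        uint16_t ui = (uint16_t) i;
        ggml_fp16_t h;
        memcpy(&h, &ui, sizeof(h));
        table_exp_f16[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(h)));
    }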
@@ -12443,6 +12542,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor);
} break;
case GGML_OP_DIAG:
{
ggml_compute_forward_diag(params, tensor->src0, tensor);
} break;
case GGML_OP_DIAG_MASK_INF:
{
ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
@@ -12906,6 +13009,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// noop
}
} break;
case GGML_OP_DIAG:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_DIAG_MASK_INF:
{
// necessary for llama
@@ -12943,20 +13050,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
// y = softmax(x)
-// dx = dy * y - sum(dy * y) * y
-// dx = y * (dy - sum(dy * y))
+//
+// Jii = yi - yi*yi
+// Jij = -yi*yj
+// J = diag(y) - y*y^T
+// dx = J * dy
+// dxk = sum_j(Jkj * dyj)
+struct ggml_tensor * tensor_t = ggml_cont(ctx,
+ggml_permute(ctx,
+ggml_reshape(ctx,
+tensor,
+ggml_new_tensor(ctx,
+tensor->type,
+4, tensor->ne)),
+1, 0, 2, 3));
src0->grad =
ggml_add_impl(ctx,
src0->grad,
-ggml_mul(ctx,
-tensor,
-ggml_add1(ctx,
-tensor->grad,
-ggml_neg(ctx,
-ggml_sum(ctx,
-ggml_mul(ctx,
-tensor->grad,
-tensor))))),
+ggml_mul_mat(ctx,
+ggml_sub(ctx,
+ggml_diag(ctx, tensor),
+ggml_mul_mat(ctx, tensor_t, tensor_t)),
+tensor->grad),
inplace);
}
} break;
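
Why the new expression is the softmax gradient: with y = softmax(x), dy_i/dx_j = y_i*(delta_ij - y_j), so dx = J*dy with J = diag(y) - y*y^T. This is algebraically equal to the fused form y .* (dy - dot(y, dy)) that the removed code was built around. A standalone numerical check of the two forms, in plain C and independent of ggml:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        enum { n = 4 };
        const float x[n]  = { 0.1f, -0.5f, 2.0f, 0.3f };
        const float dy[n] = { 1.0f,  0.0f, -1.0f, 0.5f };
        float y[n];

        // y = softmax(x), with the usual max-subtraction for stability
        float max = x[0], sum = 0.0f;
        for (int i = 1; i < n; i++) if (x[i] > max) max = x[i];
        for (int i = 0; i < n; i++) { y[i] = expf(x[i] - max); sum += y[i]; }
        for (int i = 0; i < n; i++) y[i] /= sum;

        float dot = 0.0f;  // dot(y, dy)
        for (int i = 0; i < n; i++) dot += y[i]*dy[i];

        for (int k = 0; k < n; k++) {
            // form 1: explicit Jacobian, dxk = sum_j (diag(y) - y*y^T)[k][j] * dy[j]
            float dx1 = 0.0f;
            for (int j = 0; j < n; j++) {
                const float Jkj = (k == j ? y[k] : 0.0f) - y[k]*y[j];
                dx1 += Jkj * dy[j];
            }
            // form 2: fused, dxk = yk * (dyk - dot(y, dy))
            const float dx2 = y[k] * (dy[k] - dot);
            printf("k=%d  jacobian=% .6f  fused=% .6f\n", k, dx1, dx2);
        }
        return 0;
    }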
@@ -13480,6 +13597,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_DIAG_MASK_INF:
{
node->n_tasks = 1;

ggml.h (5 changes)

@@ -285,6 +285,7 @@ extern "C" {
GGML_OP_TRANSPOSE,
GGML_OP_GET_ROWS,
GGML_OP_GET_ROWS_BACK,
GGML_OP_DIAG,
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
GGML_OP_SOFT_MAX,
@@ -700,6 +701,10 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
struct ggml_tensor * a);
// set elements above the diagonal to -INF
GGML_API struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx,