add ggml_out_prod and use it for mul_mat backward pass for improved performance

performance stats report an improvement from 37 seconds to 16 seconds of runtime during my training tests
xaedes 2023-05-15 14:17:42 +02:00
parent a703d7a85f
commit efa4bb78ea
2 changed files with 246 additions and 31 deletions
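
For context, the matmul gradient identities behind this change, written here in plain matrix notation as an illustration (this derivation is not part of the commit): with C = A·B and loss L,

\[
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\,B^{\top},
\qquad
\frac{\partial L}{\partial B} = A^{\top}\,\frac{\partial L}{\partial C}.
\]

Both right-hand sides are sums of outer products over the shared inner dimension, which is exactly what the new ggml_out_prod evaluates. The previous backward pass built the same results with ggml_mul_mat applied to ggml_cont(ggml_transpose(...)) operands, which materializes transposed copies; the comments in the diff below point to avoiding the transpose of the larger src0 as the main win.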

ggml.c (264 changes)

@@ -3310,6 +3310,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"RMS_NORM_BACK",
"MUL_MAT",
"OUT_PROD",
"SCALE",
"SET",
@@ -3339,7 +3340,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"MAP_BINARY",
};
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3370,6 +3371,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)",
"rms_norm_back(x)",
"X*Y",
"X*Y",
"x*v",
@@ -3400,7 +3402,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"f(x,y)",
};
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3566,6 +3568,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
(t0->ne[3] == t1->ne[3]);
}
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
(t0->ne[1] == t1->ne[1]) &&
(t0->ne[2] == t1->ne[2]) &&
(t0->ne[3] == t1->ne[3]);
}
bool ggml_is_quantized(enum ggml_type type) {
return GGML_IS_QUANTIZED[type];
}
@@ -5156,6 +5167,32 @@ struct ggml_tensor * ggml_mul_mat(
return result;
}
// ggml_out_prod
struct ggml_tensor * ggml_out_prod(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_can_out_prod(a, b));
GGML_ASSERT(!ggml_is_transposed(a));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
result->op = GGML_OP_OUT_PROD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
return result;
}
// ggml_scale
struct ggml_tensor * ggml_scale_impl(
@@ -9802,6 +9839,178 @@ static void ggml_compute_forward_mul_mat(
}
}
// ggml_compute_forward_out_prod
static void ggml_compute_forward_out_prod_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const int nb00 = src0->nb[0];
const int nb01 = src0->nb[1];
const int nb02 = src0->nb[2];
const int nb03 = src0->nb[3];
const int nb10 = src1->nb[0];
const int nb11 = src1->nb[1];
const int nb12 = src1->nb[2];
const int nb13 = src1->nb[3];
const int nb0 = dst->nb[0];
const int nb1 = dst->nb[1];
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
// GGML_ASSERT(nb0 <= nb1);
// GGML_ASSERT(nb1 <= nb2);
// GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (params->type == GGML_TASK_INIT) {
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
// parallelize by last two dimensions
// total parallel in src0
const int64_t np = ne02*ne03;
// per thread
const int64_t dp = (np + nth - 1)/nth;
// range for this thread
const int64_t ip0 = dp*ith;
const int64_t ip1 = MIN(ip0 + dp, np);
// dst[:,:,:,:] = 0
// for i2,i3:
// for i01:
// for i1:
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
for (int64_t ip = ip0; ip < ip1; ++ip) {
// src0 indices
const int64_t i3 = ip/ne02;
const int64_t i2 = ip - i3*ne02;
const int64_t i02 = i2;
const int64_t i03 = i3;
const int64_t i12 = i2;
const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01;
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int64_t i10 = i1;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_vec_mad_f32(ne0, d, s0, *s1);
// for (int64_t i0 = 0; i0 < ne0; ++i0) {
// d[i0] += s0[i0] * s1[i1];
// }
}
}
}
//int64_t t1 = ggml_perf_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
// printf("\n");
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
//}
}
static void ggml_compute_forward_out_prod(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
{
GGML_ASSERT(false); // todo
// ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(false); // todo
// ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_scale
static void ggml_compute_forward_scale_f32(
@@ -10380,7 +10589,7 @@ static void ggml_compute_forward_diag_mask_f32(
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int n = ggml_nrows(src0);
@@ -10541,7 +10750,7 @@ static void ggml_compute_forward_soft_max_back_f32(
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
@@ -10580,7 +10789,7 @@ static void ggml_compute_forward_soft_max_back_f32(
// dxk = -yk * dot(y, dy) + yk*dyk
// dxk = yk * (- dot(y, dy) + dyk)
// dxk = yk * (dyk - dot(y, dy))
//
// post-order:
// dot_y_dy := dot(y, dy)
// dx := dy
@@ -12611,6 +12820,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
} break;
case GGML_OP_OUT_PROD:
{
ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
} break;
case GGML_OP_SCALE:
{
ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13041,45 +13254,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
-// TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
src0->grad =
ggml_add_impl(ctx,
src0->grad,
-// ds0 = dt.dot(s1.T)
-// ggml_out_prod(ctx, // [n,m]
-// src1, // [n,p]
-// tensor->grad), // [m,p]
-// for now just using A*B==(B.T*A.T).T
-ggml_mul_mat(ctx, // [n,m]
-ggml_cont(ctx, // [p,n]
-ggml_transpose(ctx, // [p,n]
-src1)), // [n,p]
-ggml_cont(ctx, // [p,m]
-ggml_transpose(ctx, // [p,m]
-tensor->grad))), // [m,p]
+ggml_out_prod(ctx, // [n,m]
+src1, // [n,p]
+tensor->grad), // [m,p]
inplace);
}
if (src1->grad) {
src1->grad =
ggml_add_impl(ctx,
src1->grad,
-// ds1 = s0.T.dot(dt):
-ggml_mul_mat(ctx, // [n,p]
-ggml_cont(ctx, // [m,n]
-ggml_transpose(ctx, src0)), // [m,n]
-tensor->grad), // [m,p]
-// // when src0 is bigger than tensor->grad (this is the case in llama),
+// ggml_mul_mat(ctx, // [n,p]
+// ggml_cont(ctx, // [m,n]
+// ggml_transpose(ctx, src0)), // [m,n]
+// tensor->grad), // [m,p]
+// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
// // avoid transpose of src0, rather transpose smaller tensor->grad
// // and then use ggml_out_prod
-// ggml_out_prod(ctx, // [n,p]
-// src0, // [n,m]
-// ggml_cont(ctx, // [p,m]
-// ggml_transpose(ctx, // [p,m]
-// tensor->grad)), // [m,p]
+ggml_out_prod(ctx, // [n,p]
+src0, // [n,m]
+ggml_transpose(ctx, // [p,m]
+tensor->grad)), // [m,p]
inplace);
}
} break;
case GGML_OP_OUT_PROD:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_SCALE:
{
// necessary for llama
@@ -13757,6 +13962,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{
node->n_tasks = n_threads;
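
As an aside on the parallelization in ggml_compute_forward_out_prod_f32 above: the ne02*ne03 outer slices are divided among threads with a ceiling division, so each (i2, i3) pair is owned by exactly one thread and the dst rows written by different threads never overlap. A minimal standalone sketch of that pattern, with hypothetical names mirroring the diff (not part of the commit):

#include <stdint.h>

// np work items are split across nth threads; thread ith gets the
// half-open range [ip0, ip1), matching the dp/ip0/ip1 logic in the diff.
static void thread_range(int64_t np, int nth, int ith,
                         int64_t * ip0, int64_t * ip1) {
    const int64_t dp = (np + nth - 1)/nth;       // items per thread, rounded up
    *ip0 = dp*ith;                               // first item for this thread
    *ip1 = (*ip0 + dp < np) ? *ip0 + dp : np;    // clamp to the total count
}

With np = ne02*ne03, a thread whose range is empty simply skips its loop, which is the same behavior the MIN-based clamp in the diff produces.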

ggml.h (13 changes)

@@ -292,6 +292,7 @@ extern "C" {
GGML_OP_RMS_NORM_BACK,
GGML_OP_MUL_MAT,
GGML_OP_OUT_PROD,
GGML_OP_SCALE,
GGML_OP_SET,
@@ -643,14 +644,22 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
-// A: m rows, n columns
-// B: p rows, n columns (i.e. we transpose it internally)
+// A: n columns, m rows
+// B: n columns, p rows (i.e. we transpose it internally)
// result is m columns, p rows
GGML_API struct ggml_tensor * ggml_mul_mat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// A: m columns, n rows,
// B: p columns, n rows,
// result is m columns, p rows
GGML_API struct ggml_tensor * ggml_out_prod(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
//
// operations on tensors without backpropagation
//
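
To make the new ggml.h comment concrete, a hedged usage sketch (not from the commit; the context size and fill values are arbitrary assumptions, and the helper calls are the ones already declared in ggml.h at this point): an A with m columns and n rows combined with a B with p columns and n rows yields a result with m columns and p rows.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // Small scratch context for this toy example (size is an arbitrary assumption).
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // A: m columns, n rows; B: p columns, n rows (the convention from the comment above).
    const int m = 3, n = 2, p = 4;
    struct ggml_tensor * a = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, n), 1.0f);
    struct ggml_tensor * b = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, p, n), 2.0f);

    // result: m columns, p rows; each element is 1*2 summed over the n = 2 rows, i.e. 4.0
    struct ggml_tensor * c = ggml_out_prod(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);
    ggml_graph_compute(ctx, &gf);

    printf("ne = [%d, %d], c[0] = %f\n", (int) c->ne[0], (int) c->ne[1], ggml_get_f32_1d(c, 0));

    ggml_free(ctx);
    return 0;
}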