diff --git a/ggml.c b/ggml.c index 9a0a07aa5..77b654809 100644 --- a/ggml.c +++ b/ggml.c @@ -3310,6 +3310,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "RMS_NORM_BACK", "MUL_MAT", + "OUT_PROD", "SCALE", "SET", @@ -3339,7 +3340,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "MAP_BINARY", }; -static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); +static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3370,6 +3371,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm(x)", "rms_norm_back(x)", + "X*Y", "X*Y", "x*v", @@ -3400,7 +3402,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "f(x,y)", }; -static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); +static_assert(GGML_OP_COUNT == 52, "GGML_OP_COUNT != 52"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -3566,6 +3568,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct (t0->ne[3] == t1->ne[3]); } +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + bool ggml_is_quantized(enum ggml_type type) { return GGML_IS_QUANTIZED[type]; } @@ -5156,6 +5167,32 @@ struct ggml_tensor * ggml_mul_mat( return result; } +// ggml_out_prod + +struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_out_prod(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_OUT_PROD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_scale struct ggml_tensor * ggml_scale_impl( @@ -9802,6 +9839,178 @@ static void ggml_compute_forward_mul_mat( } } +// ggml_compute_forward_out_prod + + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by last two dimensions + + // total parallel in src0 + const int64_t np = ne02*ne03; + + // per thread + const int64_t dp = (np + nth - 1)/nth; + + // range for this thread + const int64_t ip0 = dp*ith; + const int64_t ip1 = MIN(ip0 + dp, np); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i01: + // for i1: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + for (int64_t ip = ip0; ip < ip1; ++ip) { + // src0 indices + const int64_t i3 = ip/ne02; + const int64_t i2 = ip - i3*ne02; + + const int64_t i02 = i2; + const int64_t i03 = i3; + + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t i10 = i1; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + // for (int64_t i0 = 0; i0 < ne0; ++i0) { + // d[i0] += s0[i0] * s1[i1]; + // } + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_out_prod_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_scale static void ggml_compute_forward_scale_f32( @@ -10380,7 +10589,7 @@ static void ggml_compute_forward_diag_mask_f32( if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - + // TODO: handle transposed/permuted matrices const int n = ggml_nrows(src0); @@ -10541,7 +10750,7 @@ static void ggml_compute_forward_soft_max_back_f32( if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - + // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -10580,7 +10789,7 @@ static void ggml_compute_forward_soft_max_back_f32( // dxk = -yk * dot(y, dy) + yk*dyk // dxk = yk * (- dot(y, dy) + dyk) // dxk = yk * (dyk - dot(y, dy)) - // + // // post-order: // dot_y_dy := dot(y, dy) // dx := dy @@ -12611,6 +12820,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_OUT_PROD: + { + ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_SCALE: { ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); @@ -13041,45 +13254,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { - // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); src0->grad = ggml_add_impl(ctx, src0->grad, - // ds0 = dt.dot(s1.T) - // ggml_out_prod(ctx, // [n,m] - // src1, // [n,p] - // tensor->grad), // [m,p] - // for now just using A*B==(B.T*A.T).T - ggml_mul_mat(ctx, // [n,m] - ggml_cont(ctx, // [p,n] - ggml_transpose(ctx, // [p,n] - src1)), // [n,p] - ggml_cont(ctx, // [p,m] - ggml_transpose(ctx, // [p,m] - tensor->grad))), // [m,p] + ggml_out_prod(ctx, // [n,m] + src1, // [n,p] + tensor->grad), // [m,p] inplace); } if (src1->grad) { src1->grad = ggml_add_impl(ctx, src1->grad, - // ds1 = s0.T.dot(dt): - ggml_mul_mat(ctx, // [n,p] - ggml_cont(ctx, // [m,n] - ggml_transpose(ctx, src0)), // [m,n] - tensor->grad), // [m,p] + // ggml_mul_mat(ctx, // [n,p] + // ggml_cont(ctx, // [m,n] + // ggml_transpose(ctx, src0)), // [m,n] + // tensor->grad), // [m,p] - // // when src0 is bigger than tensor->grad (this is the case in llama), + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), // // avoid transpose of src0, rather transpose smaller tensor->grad // // and then use ggml_out_prod - // ggml_out_prod(ctx, // [n,p] - // src0, // [n,m] - // ggml_cont(ctx, // [p,m] - // ggml_transpose(ctx, // [p,m] - // tensor->grad)), // [m,p] + ggml_out_prod(ctx, // [n,p] + src0, // [n,m] + ggml_transpose(ctx, // [p,m] + tensor->grad)), // [m,p] inplace); } } break; + case GGML_OP_OUT_PROD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SCALE: { // necessary for llama @@ -13757,6 +13962,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = n_threads; } break; case GGML_OP_MUL_MAT: + case GGML_OP_OUT_PROD: { node->n_tasks = n_threads; diff --git a/ggml.h b/ggml.h index 0a0989516..aa75fd726 100644 --- a/ggml.h +++ b/ggml.h @@ -292,6 +292,7 @@ extern "C" { GGML_OP_RMS_NORM_BACK, GGML_OP_MUL_MAT, + GGML_OP_OUT_PROD, GGML_OP_SCALE, GGML_OP_SET, @@ -643,14 +644,22 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // A: m rows, n columns - // B: p rows, n columns (i.e. we transpose it internally) + // A: n columns, m rows + // B: n columns, p rows (i.e. we transpose it internally) // result is m columns, p rows GGML_API struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // // operations on tensors without backpropagation //