diff --git a/ggml.c b/ggml.c
index 4fe4d748b..bef4cac8f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3966,6 +3966,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "SQR",
     "SQRT",
     "SUM",
+    "SUM_ROWS",
     "MEAN",
     "REPEAT",
     "ABS",
@@ -4008,7 +4009,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
+static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4023,6 +4024,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "x^2",
     "√x",
     "Σx",
+    "Σx_k",
     "Σx/n",
     "repeat(x)",
     "abs(x)",
@@ -4065,7 +4067,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
+static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -5322,6 +5324,33 @@ struct ggml_tensor * ggml_sum(
     return result;
 }
 
+
+// ggml_sum_rows
+
+struct ggml_tensor * ggml_sum_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    int64_t ne[4] = {1,1,1,1};
+    for (int i = 1; i < a->n_dims; ++i) {
+        ne[i] = a->ne[i];
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, a->n_dims, ne);
+
+    result->op   = GGML_OP_SUM_ROWS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_mean
 
 struct ggml_tensor * ggml_mean(
@@ -8502,6 +8531,73 @@ static void ggml_compute_forward_sum(
     }
 }
 
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    GGML_ASSERT(ne0 == 1);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float* src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float* dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float row_sum = 0;
+                ggml_vec_sum_f32(ne00, &row_sum, src_row);
+                dst_row[0] = row_sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sum_rows(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_rows_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_mean
 
 static void ggml_compute_forward_mean_f32(
@@ -12681,6 +12777,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_sum(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MEAN:
             {
                 ggml_compute_forward_mean(params, tensor->src0, tensor);
@@ -12980,6 +13080,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 inplace);
                 }
             } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                GGML_ASSERT(false); // TODO: implement
+            } break;
         case GGML_OP_MEAN:
            {
                 GGML_ASSERT(false); // TODO: implement
@@ -13758,6 +13862,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SQR:
                 case GGML_OP_SQRT:
                 case GGML_OP_SUM:
+                case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
                 case GGML_OP_REPEAT:
                 case GGML_OP_ABS:
diff --git a/ggml.h b/ggml.h
index 15a9f3faf..884d26b23 100644
--- a/ggml.h
+++ b/ggml.h
@@ -535,11 +535,15 @@ extern "C" {
             struct ggml_tensor * a);
 
     // return scalar
-    // TODO: compute sum along rows
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
+    GGML_API struct ggml_tensor * ggml_sum_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
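
For reviewers, a minimal usage sketch of the new op (not part of the patch): it builds a small graph and checks that a [4,3,1,1] input reduces to a [1,3,1,1] output, i.e. one scalar per row. It assumes the ggml API at this revision, where ggml_build_forward returns the graph by value and ggml_graph_compute takes the context; the file name sum_rows_example.c is hypothetical.

    // sum_rows_example.c (hypothetical) -- sketch only, not included in the patch
    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024, // plenty for this tiny graph
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // input shape [4,3,1,1]: 3 rows of 4 elements each
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
        ggml_set_f32(a, 2.0f); // fill every element with 2.0

        // output shape [1,3,1,1]: each row collapses to its sum
        struct ggml_tensor * s = ggml_sum_rows(ctx, a);

        struct ggml_cgraph gf = ggml_build_forward(s);
        ggml_graph_compute(ctx, &gf);

        for (int i = 0; i < 3; ++i) {
            // each row sums to 4 * 2.0 = 8.0
            printf("row %d sum = %f\n", i, ggml_get_f32_1d(s, i));
        }

        ggml_free(ctx);
        return 0;
    }

Note that the backward pass intentionally asserts for now (GGML_ASSERT(false); // TODO: implement), so the op is forward-only until the gradient is added.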