add operation ggml_sum_rows
ggml_sum_rows(shape[a,b,c,d]) -> shape[1,b,c,d]
This commit is contained in:
parent
2277053839
commit
c4539ede53
2 changed files with 112 additions and 3 deletions
109
ggml.c
109
ggml.c
|
@ -3966,6 +3966,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"SQR",
|
||||
"SQRT",
|
||||
"SUM",
|
||||
"SUM_ROWS",
|
||||
"MEAN",
|
||||
"REPEAT",
|
||||
"ABS",
|
||||
|
@ -4008,7 +4009,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"MAP_BINARY",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
|
||||
static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -4023,6 +4024,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"x^2",
|
||||
"√x",
|
||||
"Σx",
|
||||
"Σx_k",
|
||||
"Σx/n",
|
||||
"repeat(x)",
|
||||
"abs(x)",
|
||||
|
@ -4065,7 +4067,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"f(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 46, "GGML_OP_COUNT != 46");
|
||||
static_assert(GGML_OP_COUNT == 47, "GGML_OP_COUNT != 47");
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
||||
|
@ -5322,6 +5324,33 @@ struct ggml_tensor * ggml_sum(
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
// ggml_sum_rows
|
||||
|
||||
struct ggml_tensor * ggml_sum_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
int64_t ne[4] = {1,1,1,1};
|
||||
for (int i=1; i<a->n_dims; ++i) {
|
||||
ne[i] = a->ne[i];
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, a->n_dims, ne);
|
||||
|
||||
result->op = GGML_OP_SUM_ROWS;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = NULL;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_mean
|
||||
|
||||
struct ggml_tensor * ggml_mean(
|
||||
|
@ -8502,6 +8531,73 @@ static void ggml_compute_forward_sum(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sum_rows
|
||||
|
||||
static void ggml_compute_forward_sum_rows_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(params->ith == 0);
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
GGML_ASSERT(ne0 == 1);
|
||||
GGML_ASSERT(ne1 == ne01);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
const size_t nb03 = src0->nb[3];
|
||||
|
||||
const size_t nb1 = dst->nb[1];
|
||||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne03; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne02; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne01; i1++) {
|
||||
float* src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
float* dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
float row_sum = 0;
|
||||
ggml_vec_sum_f32(ne00, &row_sum, src_row);
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_sum_rows(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_sum_rows_f32(params, src0, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mean
|
||||
|
||||
static void ggml_compute_forward_mean_f32(
|
||||
|
@ -12681,6 +12777,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_sum(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
{
|
||||
ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor->src0, tensor);
|
||||
|
@ -12980,6 +13080,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
inplace);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
|
@ -13758,6 +13862,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_ABS:
|
||||
|
|
6
ggml.h
6
ggml.h
|
@ -535,11 +535,15 @@ extern "C" {
|
|||
struct ggml_tensor * a);
|
||||
|
||||
// return scalar
|
||||
// TODO: compute sum along rows
|
||||
GGML_API struct ggml_tensor * ggml_sum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// sums along rows: for an input of shape [a,b,c,d], returns a tensor of shape [1,b,c,d]
|
||||
GGML_API struct ggml_tensor * ggml_sum_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// mean along rows
|
||||
GGML_API struct ggml_tensor * ggml_mean(
|
||||
struct ggml_context * ctx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue