ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

This commit is contained in:
Johannes Gäßler 2024-10-03 17:29:59 +02:00 committed by Georgi Gerganov
parent eee39bdc96
commit fabdc3bda3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 389 additions and 8 deletions

View file

@ -2994,6 +2994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SUM_ROWS",
"MEAN",
"ARGMAX",
"COUNT_EQUAL",
"REPEAT",
"REPEAT_BACK",
"CONCAT",
@ -3067,7 +3068,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -3088,6 +3089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"Σx_k",
"Σx/n",
"argmax(x)",
"count_equal(x)",
"repeat(x)",
"repeat_back(x)",
"concat(x, y)",
@ -3161,7 +3163,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5222,6 +5224,23 @@ struct ggml_tensor * ggml_argmax(
return result;
}
// ggml_count_equal
struct ggml_tensor * ggml_count_equal(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_are_same_shape(a, b));
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
result->op = GGML_OP_COUNT_EQUAL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_repeat
struct ggml_tensor * ggml_repeat(
@ -10809,6 +10828,86 @@ static void ggml_compute_forward_argmax(
}
}
// ggml_compute_forward_count_equal
static void ggml_compute_forward_count_equal_i32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS;
GGML_ASSERT(src0->type == GGML_TYPE_I32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_scalar(dst));
GGML_ASSERT(dst->type == GGML_TYPE_I64);
const int64_t nr = ggml_nrows(src0);
const int ith = params->ith;
const int nth = params->nth;
int64_t * sums = (int64_t *) params->wdata;
int64_t sum_thread = 0;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir / (ne02*ne01);
const int64_t i02 = (ir - i03*ne03) / ne01;
const int64_t i01 = ir - i03*ne03 - i02*ne02;
const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
for (int64_t i00 = 0; i00 < ne00; ++i00) {
const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
sum_thread += val0 == val1;
}
}
if (ith != 0) {
sums[ith] = sum_thread;
}
ggml_barrier(params->threadpool);
if (ith != 0) {
return;
}
for (int ith_other = 1; ith_other < nth; ++ith_other) {
sum_thread += sums[ith_other];
}
*((int64_t *) dst->data) = sum_thread;
}
static void ggml_compute_forward_count_equal(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_I32:
{
ggml_compute_forward_count_equal_i32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_repeat
static void ggml_compute_forward_repeat_f32(
@ -17187,6 +17286,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argmax(params, tensor);
} break;
case GGML_OP_COUNT_EQUAL:
{
ggml_compute_forward_count_equal(params, tensor);
} break;
case GGML_OP_REPEAT:
{
ggml_compute_forward_repeat(params, tensor);
@ -17937,6 +18040,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
{
GGML_ABORT("fatal error"); // TODO: implement
}
@ -18710,6 +18814,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
for (int i = 0; i < gf->n_nodes; ++i) {
struct ggml_tensor * node = gf->nodes[i];
if (node->type == GGML_TYPE_I32) {
continue;
}
bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
bool ignore_src[GGML_MAX_SRC] = {false};
switch (node->op) {
@ -19113,6 +19221,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
{
n_tasks = 1;
} break;
case GGML_OP_COUNT_EQUAL:
{
n_tasks = n_threads;
} break;
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_LEAKY_RELU:
@ -19611,6 +19726,10 @@ struct ggml_cplan ggml_graph_plan(
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
}
} break;
case GGML_OP_COUNT_EQUAL:
{
cur = ggml_type_size(node->type)*n_tasks;
} break;
case GGML_OP_MUL_MAT:
{
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;