From 997a9b5bd2f2c9c16067cea3a901be96c1203b7d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 17 Apr 2024 19:12:34 +0200 Subject: [PATCH] cleanup ggml-ci --- ggml.c | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/ggml.c b/ggml.c index 821ce25ed..a3b312e4a 100644 --- a/ggml.c +++ b/ggml.c @@ -11018,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - //GGML_ASSERT(ne0 == ne01); - //GGML_ASSERT(ne1 == ne11); - //GGML_ASSERT(ne2 == ne12); - //GGML_ASSERT(ne3 == ne13); - // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); @@ -11041,8 +11036,13 @@ static void ggml_compute_forward_mul_mat_id( (char *) params->wdata : (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { @@ -11069,9 +11069,6 @@ static void ggml_compute_forward_mul_mat_id( memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] -#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32)) -#define LO_I64(i64) ((int32_t)(i64)) -#define HI_I64(i64) ((int32_t)((i64) >> 32)) // group rows by src0 matrix for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { @@ -11080,7 +11077,7 @@ static void ggml_compute_forward_mul_mat_id( assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id); + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; } } @@ -11143,10 +11140,11 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { const int64_t _i12 = ir1; // logical row index for this expert - const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index const int64_t i11 = id % ne11; - const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1 + const int64_t i12 = row_mapping.i2; // row index in src1 const int64_t i1 = id; // selected expert index const int64_t i2 = i12; // row @@ -11177,9 +11175,6 @@ static void ggml_compute_forward_mul_mat_id( } #undef MMID_MATRIX_ROW -#undef MAKE_I64 -#undef LO_I64 -#undef HI_I64 } // ggml_compute_forward_out_prod