ggml-ci
This commit is contained in:
slaren 2024-04-17 19:12:34 +02:00
parent d68c935c8d
commit 997a9b5bd2

25
ggml.c
View file

@ -11018,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
//GGML_ASSERT(ne0 == ne01);
//GGML_ASSERT(ne1 == ne11);
//GGML_ASSERT(ne2 == ne12);
//GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@ -11041,8 +11036,13 @@ static void ggml_compute_forward_mul_mat_id(
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
@ -11069,9 +11069,6 @@ static void ggml_compute_forward_mul_mat_id(
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32))
#define LO_I64(i64) ((int32_t)(i64))
#define HI_I64(i64) ((int32_t)((i64) >> 32))
// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
@ -11080,7 +11077,7 @@ static void ggml_compute_forward_mul_mat_id(
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
}
}
@ -11143,10 +11140,11 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t _i12 = ir1; // logical row index for this expert
const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
@ -11177,9 +11175,6 @@ static void ggml_compute_forward_mul_mat_id(
}
#undef MMID_MATRIX_ROW
#undef MAKE_I64
#undef LO_I64
#undef HI_I64
}
// ggml_compute_forward_out_prod