cleanup
ggml-ci
This commit is contained in:
parent
d68c935c8d
commit
997a9b5bd2
1 changed files with 10 additions and 15 deletions
25
ggml.c
25
ggml.c
|
@ -11018,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
|
||||
//GGML_ASSERT(ne0 == ne01);
|
||||
//GGML_ASSERT(ne1 == ne11);
|
||||
//GGML_ASSERT(ne2 == ne12);
|
||||
//GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
@ -11041,8 +11036,13 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
(char *) params->wdata :
|
||||
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT) {
|
||||
if (ith != 0) {
|
||||
|
@ -11069,9 +11069,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
|
||||
#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32))
|
||||
#define LO_I64(i64) ((int32_t)(i64))
|
||||
#define HI_I64(i64) ((int32_t)((i64) >> 32))
|
||||
|
||||
// group rows by src0 matrix
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
||||
|
@ -11080,7 +11077,7 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
|
||||
assert(i02 >= 0 && i02 < n_as);
|
||||
|
||||
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id);
|
||||
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
|
||||
matrix_row_counts[i02] += 1;
|
||||
}
|
||||
}
|
||||
|
@ -11143,10 +11140,11 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||
const int64_t _i12 = ir1; // logical row index for this expert
|
||||
|
||||
const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
@ -11177,9 +11175,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
}
|
||||
|
||||
#undef MMID_MATRIX_ROW
|
||||
#undef MAKE_I64
|
||||
#undef LO_I64
|
||||
#undef HI_I64
|
||||
}
|
||||
|
||||
// ggml_compute_forward_out_prod
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue