Arm AArch64: minor code refactoring
This commit is contained in:
parent
110d143ece
commit
4ff0b223c3
1 changed files with 32 additions and 2 deletions
|
@ -12438,6 +12438,8 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||||
|
int64_t const matmul_num_cols = type_traits[type].ncols;
|
||||||
|
ggml_gemv_t const gemv = type_traits[type].gemv;
|
||||||
|
|
||||||
// we don't support permuted src0 or src1
|
// we don't support permuted src0 or src1
|
||||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||||
|
@ -12523,6 +12525,34 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
const int64_t nr0 = ne01; // src0 rows
|
const int64_t nr0 = ne01; // src0 rows
|
||||||
const int64_t nr1 = cne1; // src1 rows
|
const int64_t nr1 = cne1; // src1 rows
|
||||||
|
|
||||||
|
if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
|
||||||
|
int64_t src0_cur_start = (ith * ne01) / nth;
|
||||||
|
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
||||||
|
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
|
||||||
|
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
|
||||||
|
if (src0_cur_start >= src0_cur_end) return;
|
||||||
|
|
||||||
|
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
||||||
|
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
||||||
|
const int id = row_mapping.i1; // selected expert index
|
||||||
|
|
||||||
|
const int64_t i11 = id % ne11;
|
||||||
|
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||||
|
|
||||||
|
const int64_t i1 = id; // selected expert index
|
||||||
|
const int64_t i2 = i12; // row
|
||||||
|
|
||||||
|
const char * src1_col = (const char *) wdata +
|
||||||
|
(src1_cont || src1->type != vec_dot_type
|
||||||
|
? (i11 + i12 * ne11) * row_size
|
||||||
|
: (i11 * nb11 + i12 * nb12));
|
||||||
|
|
||||||
|
gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
||||||
|
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||||
|
|
||||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue