Arm AArch64: minimize changes in ggml_compute_forward_mul_mat

Author: Dibakar Gope
Date:   2024-06-26 07:32:53 +00:00
parent ffbfabb517
commit cbbfd69f42


@@ -12276,23 +12276,19 @@ UseGgmlGemm1:;
                 }
             }
         }
-        if (from_float_to_mat && gemm && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
-            for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
-                from_float_to_mat((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10);
-                wdata += row_size * 4;
-            }
-            for (int64_t i11 = (ne11 / 4) * 4; i11 < ne11; ++i11) {
-                from_float_to_vec_dot((float *)((char *) src1->data + i11 * nb11), (void *) wdata, ne10);
-                wdata += row_size;
-            }
-        }
-        else {
-            for (int64_t i13 = 0; i13 < ne13; ++i13) {
-                for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                int64_t i11_processed = 0;
+                if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
+                    for (int64_t i11 = 0; i11 < ne11 - ne11 % 4; i11 += 4) {
+                        from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size * 4;
                     }
+                    i11_processed = ne11 - ne11 % 4;
+                }
+                for (int64_t i11 = i11_processed; i11 < ne11; ++i11) {
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += row_size;
                 }
             }
         }
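
Note: the restructured conversion loop above can be exercised in isolation. Below is a minimal stand-alone sketch in C (a toy under stated assumptions: the two conversion kernels are replaced by call counters and the tensor shapes are made up) showing how i11_processed lets the 4-row interleaved path and the per-row remainder share a single i13/i12 loop nest:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // Toy shapes: a 2D src1 with 7 rows (so ggml_n_dims(src1) == 2 holds).
        const int64_t ne13 = 1, ne12 = 1, ne11 = 7;
        const int src1_is_2d = (ne12 == 1) && (ne13 == 1);  // stands in for ggml_n_dims(src1) == 2
        int mat_calls = 0;  // stands in for from_float_to_mat (4 rows per call)
        int vec_calls = 0;  // stands in for from_float_to_vec_dot (1 row per call)

        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                int64_t i11_processed = 0;
                if (src1_is_2d) {
                    // 4-row blocks take the interleaved, gemm-friendly layout.
                    for (int64_t i11 = 0; i11 < ne11 - ne11 % 4; i11 += 4) {
                        ++mat_calls;
                    }
                    i11_processed = ne11 - ne11 % 4;
                }
                // The remaining ne11 % 4 rows (or all rows for 3D/4D src1)
                // take the per-row vec_dot layout.
                for (int64_t i11 = i11_processed; i11 < ne11; ++i11) {
                    ++vec_calls;
                }
            }
        }
        printf("mat calls: %d, vec calls: %d\n", mat_calls, vec_calls);  // prints: 1, 3
        return 0;
    }

The point of the rewrite is visible here: the old code needed a separate else branch for batched (3D/4D) inputs, while the new shape keeps one loop nest and only toggles how many rows the fast path consumes.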
@@ -12374,51 +12370,46 @@ UseGgmlGemm2:;
     //if (ith == 0)
     //    printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
 
-    const void * src1_wdata      = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
-    int64_t src0_start = (ith * ne01) / nth;
-    int64_t src0_end = ((ith + 1) * ne01) / nth;
-    src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
-    src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
+    if ((ggml_n_dims(src0) == 2) && gemv) {
+        const void * src1_wdata      = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
+        int64_t src0_start = (ith * ne01) / nth;
+        int64_t src0_end = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
+        src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
 
-    if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
         if (src0_start >= src0_end) return;
 
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3)
+        if (gemm && (ne11 > 3))
             gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
                  (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++)
+        for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++)
             gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
                  (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
                  src0_end - src0_start);
-    } else if ((ggml_n_dims(src0) == 2) && gemv) {
-        if (src0_start >= src0_end) return;
-        for (int iter = 0; iter < ne11; iter++)
-            gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                 (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                 src0_end - src0_start);
-    } else {
-        // The first chunk comes from our thread_id, the rest will get auto-assigned.
-        int current_chunk = ith;
+        return;
+    }
 
-        while (current_chunk < nchunk0 * nchunk1) {
-            const int64_t ith0 = current_chunk % nchunk0;
-            const int64_t ith1 = current_chunk / nchunk0;
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
 
-            const int64_t ir0_start = dr0 * ith0;
-            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
+    while (current_chunk < nchunk0 * nchunk1) {
+        const int64_t ith0 = current_chunk % nchunk0;
+        const int64_t ith1 = current_chunk / nchunk0;
 
-            const int64_t ir1_start = dr1 * ith1;
-            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+        const int64_t ir0_start = dr0 * ith0;
+        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
 
-            ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+        const int64_t ir1_start = dr1 * ith1;
+        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-            if (nth >= nchunk0 * nchunk1) {
-                break;
-            }
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
-            current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
-        }
-    }
-}
+        if (nth >= nchunk0 * nchunk1) {
+            break;
+        }
+
+        current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
+    }
 }
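
A note on the src0_start/src0_end rounding in the gemv branch: both bounds are rounded up to a multiple of matmul_num_cols so that each thread owns whole interleaved column blocks, and a thread whose rounded range collapses simply returns. A minimal sketch of just that arithmetic (the values of ne01, nth, and matmul_num_cols here are assumptions for illustration, not taken from the kernels):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne01 = 16;            // rows of src0 (toy value, a multiple of the block width)
        const int     nth  = 5;             // worker threads (toy value)
        const int64_t matmul_num_cols = 8;  // kernel block width (assumed value)

        for (int ith = 0; ith < nth; ++ith) {
            int64_t src0_start = (ith * ne01) / nth;
            int64_t src0_end   = ((ith + 1) * ne01) / nth;
            // Round both bounds up to a multiple of matmul_num_cols, as in the hunk above.
            src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols) : src0_start;
            src0_end   = (src0_end   % matmul_num_cols) ? src0_end   + matmul_num_cols - (src0_end   % matmul_num_cols) : src0_end;
            if (src0_start >= src0_end) {
                printf("thread %d: no work\n", ith);  // mirrors the early return
                continue;
            }
            printf("thread %d: rows [%lld, %lld)\n", ith, (long long) src0_start, (long long) src0_end);
        }
        return 0;
    }

Because start and end are rounded with the same rule, adjacent ranges still tile [0, ne01) exactly; some threads just end up with an empty range and exit early.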
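
And the fallback path that the gemv branch now returns ahead of: each thread seeds itself with the chunk equal to its thread id, then claims further chunks through an atomic counter. A single-threaded walk-through in C11 (current_chunk_counter stands in for params->shared->current_chunk, and the serial driver in main is purely illustrative; the real loop runs on nth concurrent threads):

    #include <stdatomic.h>
    #include <stdio.h>

    static atomic_int current_chunk_counter;  // stands in for params->shared->current_chunk

    static void worker(int ith, int nth, int nchunk0, int nchunk1) {
        int current_chunk = ith;  // the first chunk comes from our thread_id
        while (current_chunk < nchunk0 * nchunk1) {
            const int ith0 = current_chunk % nchunk0;  // chunk index -> (row, col) block
            const int ith1 = current_chunk / nchunk0;
            printf("thread %d takes chunk (%d, %d)\n", ith, ith0, ith1);
            if (nth >= nchunk0 * nchunk1) {
                break;  // one chunk per thread suffices; skip the atomic entirely
            }
            current_chunk = atomic_fetch_add(&current_chunk_counter, 1);
        }
    }

    int main(void) {
        const int nth = 2, nchunk0 = 2, nchunk1 = 2;  // 4 chunks shared by 2 threads
        atomic_store(&current_chunk_counter, nth);    // chunks 0..nth-1 are pre-assigned
        for (int ith = 0; ith < nth; ++ith) {         // run serially for illustration
            worker(ith, nth, nchunk0, nchunk1);
        }
        return 0;
    }

Run serially, thread 0 drains chunks 0, 2, and 3 before thread 1 starts on chunk 1; with real threads the atomic_fetch_add spreads the remaining chunks across whichever workers finish first.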