Reorg the code.
This commit is contained in:
parent
daa87b1813
commit
700c782dc1
1 changed file with 25 additions and 10 deletions
35
ggml.c
35
ggml.c
|
@ -12056,8 +12056,24 @@ UseGgmlGemm1:;
|
|||
UseGgmlGemm2:;
|
||||
#endif
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = ne1*ne12*ne13; // src1 rows
|
||||
#ifdef GGML_PERF
|
||||
int chunks_executed = 0;
|
||||
UNUSED(chunks_executed);
|
||||
#endif
|
||||
|
||||
//This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
|
||||
const int64_t nr0 = ne0;
|
||||
|
||||
//This is the size of the rest of the dimensions of the result
|
||||
const int64_t nr1 = ne1 * ne2 * ne3;
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
||||
|
||||
|
@ -12084,15 +12100,14 @@ UseGgmlGemm2:;
|
|||
return;
|
||||
}
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// These numbers are useful when trying to measure how well the threading scheduling works.
|
||||
//int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
|
||||
//float time = (ggml_perf_time_us() - t0);
|
||||
//printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mul_mat_id
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue