Moving some variable definitions into the chunk function.
This commit is contained in:
parent
7b932e4908
commit
4f95478ea3
1 changed files with 11 additions and 8 deletions
19
ggml.c
19
ggml.c
|
@ -11779,10 +11779,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
const int64_t ir1_end,
|
const int64_t ir1_end,
|
||||||
|
|
||||||
const bool src1_cont,
|
const bool src1_cont,
|
||||||
const int64_t r2,
|
|
||||||
const int64_t r3,
|
|
||||||
enum ggml_type vec_dot_type,
|
enum ggml_type vec_dot_type,
|
||||||
const void* wdata,
|
|
||||||
const size_t row_size,
|
const size_t row_size,
|
||||||
const ggml_vec_dot_t vec_dot
|
const ggml_vec_dot_t vec_dot
|
||||||
) {
|
) {
|
||||||
|
@ -11794,6 +11791,14 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
|
|
||||||
const enum ggml_type type = src0->type;
|
const enum ggml_type type = src0->type;
|
||||||
|
|
||||||
|
const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
|
|
||||||
|
// broadcast factors
|
||||||
|
const int64_t r2 = ne12 / ne02;
|
||||||
|
const int64_t r3 = ne13 / ne03;
|
||||||
|
|
||||||
|
//printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
|
||||||
|
|
||||||
assert(ne12 % ne02 == 0);
|
assert(ne12 % ne02 == 0);
|
||||||
assert(ne13 % ne03 == 0);
|
assert(ne13 % ne03 == 0);
|
||||||
|
|
||||||
|
@ -11891,8 +11896,8 @@ static void ggml_compute_forward_mul_mat(
|
||||||
GGML_ASSERT(nb2 <= nb3);
|
GGML_ASSERT(nb2 <= nb3);
|
||||||
|
|
||||||
// broadcast factors
|
// broadcast factors
|
||||||
const int64_t r2 = ne12/ne02;
|
const int64_t r2 = ne12 / ne02;
|
||||||
const int64_t r3 = ne13/ne03;
|
const int64_t r3 = ne13 / ne03;
|
||||||
|
|
||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
@ -12070,8 +12075,6 @@ UseGgmlGemm2:;
|
||||||
const int64_t ir1_start = dr1*ith1;
|
const int64_t ir1_start = dr1*ith1;
|
||||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||||
|
|
||||||
//printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
|
|
||||||
|
|
||||||
// threads with no work simply yield (not sure if it helps)
|
// threads with no work simply yield (not sure if it helps)
|
||||||
if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
|
if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
|
||||||
sched_yield();
|
sched_yield();
|
||||||
|
@ -12086,7 +12089,7 @@ UseGgmlGemm2:;
|
||||||
num_rows_per_vec_dot = 1;
|
num_rows_per_vec_dot = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, r2, r3, vec_dot_type, wdata, row_size, vec_dot);
|
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, vec_dot_type, row_size, vec_dot);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_mul_mat_id
|
// ggml_compute_forward_mul_mat_id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue