diff --git a/ggml.c b/ggml.c index 3b5e0719a..e0cca5854 100644 --- a/ggml.c +++ b/ggml.c @@ -11779,9 +11779,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t ir1_end, const bool src1_cont, - enum ggml_type vec_dot_type, - const size_t row_size, - const ggml_vec_dot_t vec_dot + const size_t row_size ) { const struct ggml_tensor* src0 = dst->src[0]; @@ -11791,7 +11789,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( const enum ggml_type type = src0->type; - const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; // broadcast factors const int64_t r2 = ne12 / ne02; @@ -11799,6 +11798,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); + const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -11875,7 +11876,6 @@ static void ggml_compute_forward_mul_mat( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; int64_t const vec_dot_num_rows = type_traits[type].nrows; @@ -12089,7 +12089,7 @@ UseGgmlGemm2:; num_rows_per_vec_dot = 1; } - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, vec_dot_type, row_size, vec_dot); + ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, row_size); } // ggml_compute_forward_mul_mat_id