diff --git a/ggml-aarch64.c b/ggml-aarch64.c index f5b6ec896..d7f7f5ed5 100644 --- a/ggml-aarch64.c +++ b/ggml-aarch64.c @@ -5,9 +5,6 @@ #include "ggml-quants.h" #include "ggml-impl.h" -#define GGML_COMMON_IMPL_C -#include "ggml-common.h" - #include #include #include @@ -304,7 +301,8 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds else if (nrows_interleaved == 4) { out_ptr = (block_q4_0x4 *) dst; } - block_q4_0 dst_tmp[nrows_interleaved]; + assert(nrows_interleaved <= 8); + block_q4_0 dst_tmp[8]; for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 42dd224e6..1e3677537 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2414,10 +2414,10 @@ extern "C" { const void * GGML_RESTRICT y, size_t by, int nrc); typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bx); - typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, int nr, int nc); - typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, int nr, int nc); + typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); typedef struct { const char * type_name; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bb515ee05..725e3fc7a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -12383,13 +12383,15 @@ UseGgmlGemm2:; if (src0_start >= src0_end) return; // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (gemm && (ne11 > 3)) + if (gemm && (ne11 > 3)) { gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) + } + for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, src0_end - src0_start); + } return; }