Arm AArch64: minor code refactoring

This commit is contained in:
Dibakar Gope 2024-07-08 04:19:04 +00:00
parent 4ff0b223c3
commit 42724b4d02
3 changed files with 10 additions and 10 deletions

View file

@ -5,9 +5,6 @@
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include <math.h> #include <math.h>
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>
@ -304,7 +301,8 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
else if (nrows_interleaved == 4) { else if (nrows_interleaved == 4) {
out_ptr = (block_q4_0x4 *) dst; out_ptr = (block_q4_0x4 *) dst;
} }
block_q4_0 dst_tmp[nrows_interleaved]; assert(nrows_interleaved <= 8);
block_q4_0 dst_tmp[8];
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {

View file

@ -2414,10 +2414,10 @@ extern "C" {
const void * GGML_RESTRICT y, size_t by, int nrc); const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
int64_t k, int64_t bx); int64_t k, int64_t bx);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT vy, int nr, int nc); const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT vy, int nr, int nc); const void * GGML_RESTRICT y, int nr, int nc);
typedef struct { typedef struct {
const char * type_name; const char * type_name;

View file

@ -12383,13 +12383,15 @@ UseGgmlGemm2:;
if (src0_start >= src0_end) return; if (src0_start >= src0_end) return;
// If there are more than three rows in src1, use gemm; otherwise, use gemv. // If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (gemm && (ne11 > 3)) if (gemm && (ne11 > 3)) {
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) }
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start); src0_end - src0_start);
}
return; return;
} }