Arm AArch64: minor code refactoring
This commit is contained in:
parent
4ff0b223c3
commit
42724b4d02
3 changed files with 10 additions and 10 deletions
|
@ -5,9 +5,6 @@
|
|||
#include "ggml-quants.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
|
@ -304,7 +301,8 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
|
|||
else if (nrows_interleaved == 4) {
|
||||
out_ptr = (block_q4_0x4 *) dst;
|
||||
}
|
||||
block_q4_0 dst_tmp[nrows_interleaved];
|
||||
assert(nrows_interleaved <= 8);
|
||||
block_q4_0 dst_tmp[8];
|
||||
|
||||
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
|
||||
|
||||
|
|
|
@ -2414,10 +2414,10 @@ extern "C" {
|
|||
const void * GGML_RESTRICT y, size_t by, int nrc);
|
||||
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
|
||||
int64_t k, int64_t bx);
|
||||
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||
const void * GGML_RESTRICT y, int nr, int nc);
|
||||
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||
const void * GGML_RESTRICT y, int nr, int nc);
|
||||
|
||||
typedef struct {
|
||||
const char * type_name;
|
||||
|
|
|
@ -12383,13 +12383,15 @@ UseGgmlGemm2:;
|
|||
if (src0_start >= src0_end) return;
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (gemm && (ne11 > 3))
|
||||
if (gemm && (ne11 > 3)) {
|
||||
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++)
|
||||
}
|
||||
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue