Arm AArch64: minor code refactoring
This commit is contained in:
parent
4ff0b223c3
commit
42724b4d02
3 changed files with 10 additions and 10 deletions
|
@@ -5,9 +5,6 @@
|
||||||
#include "ggml-quants.h"
|
#include "ggml-quants.h"
|
||||||
#include "ggml-impl.h"
|
#include "ggml-impl.h"
|
||||||
|
|
||||||
#define GGML_COMMON_IMPL_C
|
|
||||||
#include "ggml-common.h"
|
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
@@ -304,7 +301,8 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
|
||||||
else if (nrows_interleaved == 4) {
|
else if (nrows_interleaved == 4) {
|
||||||
out_ptr = (block_q4_0x4 *) dst;
|
out_ptr = (block_q4_0x4 *) dst;
|
||||||
}
|
}
|
||||||
block_q4_0 dst_tmp[nrows_interleaved];
|
assert(nrows_interleaved <= 8);
|
||||||
|
block_q4_0 dst_tmp[8];
|
||||||
|
|
||||||
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
|
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
|
||||||
|
|
||||||
|
|
|
@@ -2414,10 +2414,10 @@ extern "C" {
|
||||||
const void * GGML_RESTRICT y, size_t by, int nrc);
|
const void * GGML_RESTRICT y, size_t by, int nrc);
|
||||||
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
|
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
|
||||||
int64_t k, int64_t bx);
|
int64_t k, int64_t bx);
|
||||||
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx,
|
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||||
const void * GGML_RESTRICT vy, int nr, int nc);
|
const void * GGML_RESTRICT y, int nr, int nc);
|
||||||
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx,
|
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
|
||||||
const void * GGML_RESTRICT vy, int nr, int nc);
|
const void * GGML_RESTRICT y, int nr, int nc);
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
const char * type_name;
|
const char * type_name;
|
||||||
|
|
|
@@ -12383,13 +12383,15 @@ UseGgmlGemm2:;
|
||||||
if (src0_start >= src0_end) return;
|
if (src0_start >= src0_end) return;
|
||||||
|
|
||||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||||
if (gemm && (ne11 > 3))
|
if (gemm && (ne11 > 3)) {
|
||||||
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
||||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||||
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++)
|
}
|
||||||
|
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
|
||||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||||
src0_end - src0_start);
|
src0_end - src0_start);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue