Fix ggml_vec_mad_q4_1 too
This commit is contained in:
parent
a2e9d4951b
commit
edbd4a0534
1 changed files with 7 additions and 6 deletions
13
ggml.c
13
ggml.c
|
@ -1925,16 +1925,17 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
|
|||
assert(n % QK == 0);
|
||||
|
||||
const int nb = n / QK;
|
||||
const size_t bs = 2*sizeof(float) + QK/2;
|
||||
|
||||
const float * restrict pm = (const float *) (x);
|
||||
const float * restrict pd = (const float *) (pm + nb);
|
||||
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
|
||||
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
||||
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
||||
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float m = pm[i];
|
||||
const float d = pd[i];
|
||||
const float d = *(const float *) (pd + i*bs);
|
||||
const float m = *(const float *) (pm + i*bs);
|
||||
|
||||
const uint8_t * restrict pp = pb + i*QK/2;
|
||||
const uint8_t * restrict pp = pb + i*bs;
|
||||
|
||||
for (int l = 0; l < QK; l += 2) {
|
||||
const uint8_t vi = pp[l/2];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue