ggml : test dot product q4_0 x f32
This commit is contained in:
parent
42747220b4
commit
72cd433066
1 changed files with 44 additions and 15 deletions
59
ggml.c
59
ggml.c
|
@ -2754,6 +2754,33 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vec_dot_q4_0_f32(const int n, float * restrict s, const void * restrict vx, const float * restrict y) {
|
||||||
|
const static float kValues[16] = {-8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
|
||||||
|
const uint32_t kMask1 = 0x0f0f0f0f;
|
||||||
|
uint32_t u1, u2;
|
||||||
|
const uint8_t * q1 = (const uint8_t*)&u1;
|
||||||
|
const uint8_t * q2 = (const uint8_t*)&u2;
|
||||||
|
const block_q4_0 * restrict x = vx;
|
||||||
|
double sum = 0;
|
||||||
|
for (int i=0; i<n; i += QK4_0) {
|
||||||
|
float d = x->d;
|
||||||
|
const uint32_t * u = (const uint32_t *)(x->qs);
|
||||||
|
float s = 0;
|
||||||
|
for (int k=0; k<4; ++k) {
|
||||||
|
u1 = (u[k] ) & kMask1;
|
||||||
|
u2 = (u[k] >> 4) & kMask1;
|
||||||
|
s += y[0]*kValues[q1[0]] + y[1]*kValues[q2[0]] +
|
||||||
|
y[2]*kValues[q1[1]] + y[3]*kValues[q2[1]] +
|
||||||
|
y[4]*kValues[q1[2]] + y[5]*kValues[q2[2]] +
|
||||||
|
y[6]*kValues[q1[3]] + y[7]*kValues[q2[3]];
|
||||||
|
y += 8;
|
||||||
|
}
|
||||||
|
sum += s*d;
|
||||||
|
++x;
|
||||||
|
}
|
||||||
|
*s = sum;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
const int nb = n / QK8_0;
|
const int nb = n / QK8_0;
|
||||||
|
|
||||||
|
@ -7659,17 +7686,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT) {
|
if (params->type == GGML_TASK_INIT) {
|
||||||
char * wdata = params->wdata;
|
//char * wdata = params->wdata;
|
||||||
const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
//const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
||||||
|
|
||||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
//for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
// for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
// for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||||
quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
// quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
||||||
wdata += row_size;
|
// wdata += row_size;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
//}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -7690,8 +7717,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
const int ir0 = dr*ith;
|
const int ir0 = dr*ith;
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
void * wdata = params->wdata;
|
//void * wdata = params->wdata;
|
||||||
const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
//const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// src0 indices
|
// src0 indices
|
||||||
|
@ -7706,15 +7733,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
const int i2 = i02;
|
const int i2 = i02;
|
||||||
const int i3 = i03;
|
const int i3 = i03;
|
||||||
|
|
||||||
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
float * src1_col = (float *)((char *) src1->data + ( 0 + i12*nb12 + i13*nb13));
|
||||||
|
//char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
||||||
|
|
||||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
assert(ne00 % 32 == 0);
|
assert(ne00 % 32 == 0);
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
for (int64_t ic = 0; ic < ne11; ++ic) {
|
||||||
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
//vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||||
|
ggml_vec_dot_q4_0_f32(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue