ggml : test dot product q4_0 x f32

This commit is contained in:
Georgi Gerganov 2023-04-18 19:20:37 +03:00
parent 42747220b4
commit 72cd433066
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

59
ggml.c
View file

@ -2754,6 +2754,33 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
*s = sumf;
}
// Dot product of a q4_0-quantized row with a plain f32 row.
//
//   n  - number of elements in each row; consumed one block (QK4_0 elements) at a time
//   s  - output: receives the dot product (accumulated in double, stored as float)
//   vx - quantized row, read as consecutive block_q4_0 blocks
//   y  - f32 row, read as n contiguous floats
//
// Every block contributes d * sum_j(y[j] * v[j]), where d is the block scale and
// v[j] is the 4-bit code of element j mapped through kValues (code c -> c - 8).
// Within each quant byte the low nibble pairs with the even element and the high
// nibble with the following odd element.
static void ggml_vec_dot_q4_0_f32(const int n, float * restrict s, const void * restrict vx, const float * restrict y) {
    // lookup table: 4-bit code -> signed value in [-8, 7]
    static const float kValues[16] = {-8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};

    const block_q4_0 * restrict x = vx;

    double sum = 0;

    for (int i = 0; i < n; i += QK4_0) {
        const float d = x->d;

        // Read the packed nibbles byte by byte. Going through a uint32_t pointer
        // would break the effective-type (strict aliasing) rules and would make
        // the nibble/element pairing depend on the host byte order.
        const uint8_t * q = (const uint8_t *) x->qs;

        float block_sum = 0;

        // 4 quant bytes = 8 elements per step; QK4_0/8 steps cover one block
        // (4 steps for 32-element blocks, which is what the loop was hard-coded to).
        for (int k = 0; k < QK4_0/8; ++k) {
            // The 8 products are added in this fixed order before being folded
            // into block_sum, so rounding matches the word-based formulation.
            block_sum += y[0]*kValues[q[0] & 0x0f] + y[1]*kValues[q[0] >> 4] +
                         y[2]*kValues[q[1] & 0x0f] + y[3]*kValues[q[1] >> 4] +
                         y[4]*kValues[q[2] & 0x0f] + y[5]*kValues[q[2] >> 4] +
                         y[6]*kValues[q[3] & 0x0f] + y[7]*kValues[q[3] >> 4];
            y += 8;
            q += 4;
        }

        // apply the per-block scale once, after the unscaled partial sum
        sum += block_sum*d;
        ++x;
    }

    *s = sum;
}
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK8_0;
@ -7659,17 +7686,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
#endif
if (params->type == GGML_TASK_INIT) {
char * wdata = params->wdata;
const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
//char * wdata = params->wdata;
//const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
//for (int64_t i13 = 0; i13 < ne13; ++i13) {
// for (int64_t i12 = 0; i12 < ne12; ++i12) {
// for (int64_t i11 = 0; i11 < ne11; ++i11) {
// quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
// wdata += row_size;
// }
// }
//}
return;
}
@ -7690,8 +7717,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
//void * wdata = params->wdata;
//const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
@ -7706,15 +7733,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
float * src1_col = (float *)((char *) src1->data + ( 0 + i12*nb12 + i13*nb13));
//char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
//vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
ggml_vec_dot_q4_0_f32(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
}
}