ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation

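Each thread now processes its whole range of rows with a single
ggml_vec_mul_f32 call and a single ggml_vec_add_f32 call per token,
instead of one pair of calls per row. This requires the rows of the
source tensors to be contiguous in memory, which is now asserted.
This will help with performance on CPU if ggml_vec_mul_f32
and ggml_vec_add_f32 are ever optimized with SIMD.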
This commit is contained in:
Francis Couture-Harpin 2024-02-01 21:16:40 -05:00
parent ffc116f5ec
commit 78a853b788

41
ggml.c
View file

@ -14807,13 +14807,17 @@ static void ggml_compute_forward_ssm_scan_f32(
const int nth = params->nth; const int nth = params->nth;
const int64_t nc = src1->ne[0]; const int64_t nc = src1->ne[0];
const int64_t n_b = src1->ne[2]; // number of batches const int64_t n_t = src1->ne[2]; // number of tokens in the batch
const int64_t nr0 = ggml_nrows(src0); const int64_t nr0 = ggml_nrows(src0);
GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1)); GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float));
// allow merging multiple rows in the same vec operation
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float));
GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float));
// rows per thread // rows per thread
const int dr = (nr0 + nth - 1)/nth; const int dr = (nr0 + nth - 1)/nth;
@ -14821,27 +14825,26 @@ static void ggml_compute_forward_ssm_scan_f32(
// row range for this thread // row range for this thread
const int ir0 = dr*ith; const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr0); const int ir1 = MIN(ir0 + dr, nr0);
const int ir = ir1 - ir0;
// first batch // first batch
for (int i1 = ir0; i1 < ir1; i1++) { {
float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1])); float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]));
float * s = (float *) ((char *) src0->data + i1*(src0->nb[1])); float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1])); float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1])); float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]));
ggml_vec_mul_f32(nc, dest, s, dA); ggml_vec_mul_f32(nc*ir, dest, s, dA);
ggml_vec_add_f32(nc, dest, dest, dB_x); ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
} }
// rest of batches, state comes from dest // compute state for rest of tokens, previous state comes from dest
for (int i2 = 1; i2 < n_b; i2++) { for (int i2 = 1; i2 < n_t; i2++) {
for (int i1 = ir0; i1 < ir1; i1++) { float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2]));
float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1]) + i2 *( dst->nb[2])); float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
float * s = (float *) ((char *) dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2])); float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1]) + i2 *(src1->nb[2]));
float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1]) + i2 *(src1->nb[2])); float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) + i2 *(src2->nb[2]));
float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) + i2 *(src2->nb[2])); ggml_vec_mul_f32(nc*ir, dest, s, dA);
ggml_vec_mul_f32(nc, dest, s, dA); ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
ggml_vec_add_f32(nc, dest, dest, dB_x);
}
} }
} }