ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD.
This commit is contained in:
parent
ffc116f5ec
commit
78a853b788
1 changed files with 22 additions and 19 deletions
41
ggml.c
41
ggml.c
|
@ -14807,13 +14807,17 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int64_t nc = src1->ne[0];
|
const int64_t nc = src1->ne[0];
|
||||||
const int64_t n_b = src1->ne[2]; // number of batches
|
const int64_t n_t = src1->ne[2]; // number of tokens in the batch
|
||||||
const int64_t nr0 = ggml_nrows(src0);
|
const int64_t nr0 = ggml_nrows(src0);
|
||||||
|
|
||||||
GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1));
|
GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1));
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||||
|
// allow merging multiple rows in the same vec operation
|
||||||
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
|
GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float));
|
||||||
|
GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
const int dr = (nr0 + nth - 1)/nth;
|
const int dr = (nr0 + nth - 1)/nth;
|
||||||
|
@ -14821,27 +14825,26 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
// row range for this thread
|
// row range for this thread
|
||||||
const int ir0 = dr*ith;
|
const int ir0 = dr*ith;
|
||||||
const int ir1 = MIN(ir0 + dr, nr0);
|
const int ir1 = MIN(ir0 + dr, nr0);
|
||||||
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
// first batch
|
// first batch
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
{
|
||||||
float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1]));
|
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]));
|
||||||
float * s = (float *) ((char *) src0->data + i1*(src0->nb[1]));
|
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
|
||||||
float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1]));
|
float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
|
||||||
float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]));
|
float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]));
|
||||||
ggml_vec_mul_f32(nc, dest, s, dA);
|
ggml_vec_mul_f32(nc*ir, dest, s, dA);
|
||||||
ggml_vec_add_f32(nc, dest, dest, dB_x);
|
ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// rest of batches, state comes from dest
|
// compute state for rest of tokens, previous state comes from dest
|
||||||
for (int i2 = 1; i2 < n_b; i2++) {
|
for (int i2 = 1; i2 < n_t; i2++) {
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2]));
|
||||||
float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1]) + i2 *( dst->nb[2]));
|
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
|
||||||
float * s = (float *) ((char *) dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
|
float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1]) + i2 *(src1->nb[2]));
|
||||||
float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1]) + i2 *(src1->nb[2]));
|
float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) + i2 *(src2->nb[2]));
|
||||||
float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) + i2 *(src2->nb[2]));
|
ggml_vec_mul_f32(nc*ir, dest, s, dA);
|
||||||
ggml_vec_mul_f32(nc, dest, s, dA);
|
ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
|
||||||
ggml_vec_add_f32(nc, dest, dest, dB_x);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue