diff --git a/ggml.c b/ggml.c index 177d8b3c3..b132bec68 100644 --- a/ggml.c +++ b/ggml.c @@ -14807,13 +14807,17 @@ static void ggml_compute_forward_ssm_scan_f32( const int nth = params->nth; const int64_t nc = src1->ne[0]; - const int64_t n_b = src1->ne[2]; // number of batches + const int64_t n_t = src1->ne[2]; // number of tokens in the batch const int64_t nr0 = ggml_nrows(src0); - GGML_ASSERT(nc*n_b*nr0 == ggml_nelements(src1)); + GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); + // allow merging multiple rows in the same vec operation + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float)); + GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr0 + nth - 1)/nth; @@ -14821,27 +14825,26 @@ static void ggml_compute_forward_ssm_scan_f32( // row range for this thread const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr0); + const int ir = ir1 - ir0; // first batch - for (int i1 = ir0; i1 < ir1; i1++) { - float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1])); - float * s = (float *) ((char *) src0->data + i1*(src0->nb[1])); - float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1])); - float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1])); - ggml_vec_mul_f32(nc, dest, s, dA); - ggml_vec_add_f32(nc, dest, dest, dB_x); + { + float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); + float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); + float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1])); + float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1])); + ggml_vec_mul_f32(nc*ir, dest, s, dA); + ggml_vec_add_f32(nc*ir, dest, dest, dB_x); } - // rest of batches, state comes from dest - for (int i2 = 1; i2 < n_b; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dest = (float *) ((char *) dst->data + i1*( dst->nb[1]) + i2 *( dst->nb[2])); - float * s = (float *) ((char *) dst->data + i1*( dst->nb[1]) + (i2-1)*( dst->nb[2])); - float * dA = (float *) ((char *) src1->data + i1*(src1->nb[1]) + i2 *(src1->nb[2])); - float * dB_x = (float *) ((char *) src2->data + i1*(src2->nb[1]) + i2 *(src2->nb[2])); - ggml_vec_mul_f32(nc, dest, s, dA); - ggml_vec_add_f32(nc, dest, dest, dB_x); - } + // compute state for rest of tokens, previous state comes from dest + for (int i2 = 1; i2 < n_t; i2++) { + float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); + float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); + float * dA = (float *) ((char *) src1->data + ir0*(src1->nb[1]) + i2 *(src1->nb[2])); + float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) + i2 *(src2->nb[2])); + ggml_vec_mul_f32(nc*ir, dest, s, dA); + ggml_vec_add_f32(nc*ir, dest, dest, dB_x); } }