diff --git a/ggml/src/ggml-cuda/ssm_conv.cu b/ggml/src/ggml-cuda/ssm_conv.cu
index b6c62893d..fcaddf3a8 100644
--- a/ggml/src/ggml-cuda/ssm_conv.cu
+++ b/ggml/src/ggml-cuda/ssm_conv.cu
@@ -2,13 +2,12 @@
 template
 static __global__ void ssm_conv_f32(
-    const float * src0, const float * src1, const float * src2,
-    const int src0_nb1, const int src0_nb2,
-    const int src1_nb0, const int src1_nb1, const int src1_nb2,
-    const int src2_nb1,
+    const float * src0, const float * src1,
+    const int src0_nb0, const int src0_nb1, const int src0_nb2,
+    const int src1_nb1,
     float * dst,
     const int dst_nb0, const int dst_nb1, const int dst_nb2,
-    const int nc, const int nr, const int n_t, const int n_s) {
+    const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
 
 //  const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
@@ -24,118 +23,80 @@ static __global__ void ssm_conv_f32(
     const int ir1 = min(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
-    // This would avoid having to copy into an intermediate buffer, but the state would be bigger.
-
-//  float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
-    extern __shared__ float wdata_f32[]; // work buffer for all threads
-    float * s = (float *) wdata_f32 + nc*dr*ith;
-
     for (int i3 = 0; i3 < n_s; ++i3) {
-        float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_conv, d_inner, n_s}
-
-        // copy the state into working memory
-        // can't use memcpy because (d_conv) != (d_conv - 1)
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-            }
-        }
-
         for (int i2 = 0; i2 < n_t; ++i2) {
-            float * x  = (float *) ((char *) dst  + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
-            float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
-            float * c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
-
-            // shift state left
-            //memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
-            for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
-                s[i4] = s[i4+1];
-            }
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
             // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // insert x on the last column
-                s[(nc - 1) + i1*nc] = x0[i1];
-            }
-
-            // it seems a little faster when this is separate from the state shift
             for (int i1 = 0; i1 < ir; ++i1) {
                 // rowwise dot product
                 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
                 float sumf = 0.0f;
+
+                // d_conv
                 for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    sumf += s[i] * c[i];
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
                 }
                 x[i1] = sumf;
             }
         }
-
-        // copy the state out of it
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
-            }
-        }
     }
 }
 
 static void ssm_conv_f32_cuda(
-    const float * src0, const float * src1, const float * src2,
-    const int src0_nb1, const int src0_nb2,
-    const int src1_nb0, const int src1_nb1, const int src1_nb2,
-    const int src2_nb1,
+    const float * src0, const float * src1,
+    const int src0_nb0, const int src0_nb1, const int src0_nb2,
+    const int src1_nb1,
     float * dst,
     const int dst_nb0, const int dst_nb1, const int dst_nb2,
-    const int nc, const int nr, const int n_t, const int n_s,
+    const int nc, const int ncs, const int nr, const int n_t, const int n_s,
     cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const int nblocks = 1; // TODO
-    const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO
 
-    ssm_conv_f32<<<nblocks, block_dims, shmem_size, stream>>>(
-        src0, src1, src2,
-        src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1, src1_nb2,
-        src2_nb1,
+    ssm_conv_f32<<<nblocks, block_dims, 0, stream>>>(
+        src0, src1,
+        src0_nb0, src0_nb1, src0_nb2,
+        src1_nb1,
         dst, dst_nb0, dst_nb1, dst_nb2,
-        nc, nr, n_t, n_s);
+        nc, ncs, nr, n_t, n_s);
 }
 
 void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 
-    const int nc  = src2->ne[0]; // d_conv
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
     const int nr  = src0->ne[1]; // d_inner
-    const int n_t = src1->ne[1]; // tokens per sequence
-    const int n_s = src0->ne[2]; // number of sequences in the batch
+    const int n_t = dst->ne[1]; // tokens per sequence
+    const int n_s = dst->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
-    const float * src2_d = (const float *)src2->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
-        src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1], src1->nb[2],
-        src2->nb[1],
+    ssm_conv_f32_cuda(src0_d, src1_d,
+        src0->nb[0], src0->nb[1], src0->nb[2],
+        src1->nb[1],
         dst_d, dst->nb[0], dst->nb[1], dst->nb[2],
-        nc, nr, n_t, n_s,
+        nc, ncs, nr, n_t, n_s,
         stream);
 }
diff --git a/ggml/src/ggml-cuda/ssm_scan.cu b/ggml/src/ggml-cuda/ssm_scan.cu
index f19088fdd..4cc32b776 100644
--- a/ggml/src/ggml-cuda/ssm_scan.cu
+++ b/ggml/src/ggml-cuda/ssm_scan.cu
@@ -5,13 +5,12 @@
 static __global__ void ssm_scan_f32(
     const float * src0, const float * src1, const float * src2, const float * src3,
     const float * src4, const float * src5,
     const int src0_nb1, const int src0_nb2,
-    const int src1_nb0, const int src1_nb1, const int src1_nb2,
+    const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
     const int src2_nb0, const int src2_nb1, const int src2_nb2,
     const int src3_nb1,
     const int src4_nb1, const int src4_nb2,
     const int src5_nb1, const int src5_nb2,
     float * dst,
-    const int dst_nb0, const int dst_nb1, const int dst_nb2,
     const int nc, const int nr, const int n_t, const int n_s) {
 
 //  const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -30,13 +29,17 @@ static __global__ void ssm_scan_f32(
     for (int i3 = 0; i3 < n_s; ++i3) {
         for (int i2 = 0; i2 < n_t; ++i2) {
-            float * y  = (float *) ((char *) dst  + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
-            float * s  = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
-            float * x  = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
-            float * dt = (float *) ((char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
-            float * A  = (float *) ((char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
-            float * B  = (float *) ((char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
-            float * C  = (float *) ((char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
+            const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
 
             // d_inner
             for (int i1 = 0; i1 < ir; ++i1) {
@@ -48,7 +51,7 @@ static __global__ void ssm_scan_f32(
                 for (int i0 = 0; i0 < nc; ++i0) {
                     int i = i0 + i1*nc;
                     // state = prev_state * dA + dB * x
-                    float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
                     // y = rowwise_dotprod(state, C)
                     sumf += state * C[i0];
                     s[i] = state;
@@ -63,13 +66,12 @@ static void ssm_scan_f32_cuda(
     const float * src0, const float * src1, const float * src2, const float * src3,
     const float * src4, const float * src5,
     const int src0_nb1, const int src0_nb2,
-    const int src1_nb0, const int src1_nb1, const int src1_nb2,
+    const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
     const int src2_nb0, const int src2_nb1, const int src2_nb2,
     const int src3_nb1,
     const int src4_nb1, const int src4_nb2,
     const int src5_nb1, const int src5_nb2,
     float * dst,
-    const int dst_nb0, const int dst_nb1, const int dst_nb2,
     const int nc, const int nr, const int n_t, const int n_s,
     cudaStream_t stream) {
@@ -80,13 +82,12 @@ static void ssm_scan_f32_cuda(
         src0, src1, src2, src3, src4, src5,
         src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1, src1_nb2,
+        src1_nb0, src1_nb1, src1_nb2, src1_nb3,
         src2_nb0, src2_nb1, src2_nb2,
         src3_nb1,
         src4_nb1, src4_nb2,
         src5_nb1, src5_nb2,
         dst,
-        dst_nb0, dst_nb1, dst_nb2,
         nc, nr, n_t, n_s);
 }
@@ -103,7 +104,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t n_t = src1->ne[1]; // number of tokens per sequence
     const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
@@ -112,6 +113,10 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
@@ -129,13 +134,12 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0_d, src1_d, src2_d, src3_d, src4_d, src5_d,
         src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1], src1->nb[2],
+        src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
         src2->nb[0], src2->nb[1], src2->nb[2],
         src3->nb[1],
         src4->nb[1], src4->nb[2],
         src5->nb[1], src5->nb[2],
         dst_d,
-        dst->nb[0], dst->nb[1], dst->nb[2],
        nc, nr, n_t, n_s,
         stream);
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 592656048..2b8a99d20 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1662,10 +1662,9 @@ struct test_ssm_conv : public test_case {
         : type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
+        ggml_tensor * sx = ggml_new_tensor_3d(ctx, type, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
         ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
-        ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
+        ggml_tensor * out = ggml_ssm_conv(ctx, sx, c);
         return out;
     }
 };
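For reference, the patch replaces the old copy/shift/insert state handling in ssm_conv with a direct sliding window over a single input of shape {d_conv - 1 + n_t, d_inner, n_seqs}: for token i2 and inner row i1, the output is the dot product of the nc conv taps with the window starting at column i2, using ncs = d_conv - 1 + n_t as the row stride. The following is a minimal standalone CPU sketch of that indexing, not part of the patch; the helper name ssm_conv_ref and the single-sequence layout are illustrative assumptions.

// ssm_conv_ref: hypothetical reference for the sliding-window indexing used
// by the updated CUDA kernel (single sequence, row-major buffers).
#include <cstdio>
#include <vector>

static void ssm_conv_ref(const std::vector<float> & sx, // {ncs, nr}, row stride ncs
                         const std::vector<float> & c,  // {nc,  nr}, row stride nc
                         std::vector<float> & out,      // {nr,  n_t}
                         int nc, int ncs, int nr, int n_t) {
    for (int i2 = 0; i2 < n_t; ++i2) {          // token
        for (int i1 = 0; i1 < nr; ++i1) {       // inner dim (row)
            float sumf = 0.0f;
            for (int i0 = 0; i0 < nc; ++i0) {   // conv tap, window starts at column i2
                sumf += sx[i1*ncs + i2 + i0] * c[i1*nc + i0];
            }
            out[i2*nr + i1] = sumf;
        }
    }
}

int main() {
    const int nc = 4, n_t = 3, nr = 2;
    const int ncs = nc - 1 + n_t; // d_conv - 1 + n_t, as in the patch
    std::vector<float> sx(ncs*nr, 1.0f), c(nc*nr, 0.5f), out(nr*n_t);
    ssm_conv_ref(sx, c, out, nc, ncs, nr, n_t);
    std::printf("out[0] = %f\n", out[0]); // 4 taps * 1.0 * 0.5 = 2.0
    return 0;
}

This mirrors sumf += s[i0 + i1*ncs] * c[i0 + i1*nc] in the kernel, where the per-token window offset i2 is applied through the pointer s rather than in the index.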