diff --git a/ggml.c b/ggml.c index 9b5d0302b..919f84fa9 100644 --- a/ggml.c +++ b/ggml.c @@ -14700,15 +14700,16 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_ssm_conv_f32( const struct ggml_compute_params * params, - const struct ggml_tensor * src0, // conv_state - const struct ggml_tensor * src1, // x - const struct ggml_tensor * src2, // conv1d.weight - const struct ggml_tensor * src3, // state_seq struct ggml_tensor * dst) { if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } + const struct ggml_tensor * src0 = dst->src[0]; // conv_state + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight + const struct ggml_tensor * src3 = dst->src[3]; // state_seq + const int ith = params->ith; const int nth = params->nth; @@ -14808,15 +14809,11 @@ static void ggml_compute_forward_ssm_conv_f32( static void ggml_compute_forward_ssm_conv( const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * src2, - const struct ggml_tensor * src3, struct ggml_tensor * dst) { - switch (src0->type) { + switch (dst->src[0]->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst); + ggml_compute_forward_ssm_conv_f32(params, dst); } break; default: { @@ -14829,18 +14826,19 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, - const struct ggml_tensor * src0, // s - const struct ggml_tensor * src1, // x - const struct ggml_tensor * src2, // dt - const struct ggml_tensor * src3, // A - const struct ggml_tensor * src4, // B - const struct ggml_tensor * src5, // C - const struct ggml_tensor * src6, // sq struct ggml_tensor * dst) { if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; } + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src6 = dst->src[6]; // sq + const int ith = params->ith; const int nth = params->nth; @@ -14936,18 +14934,11 @@ static void ggml_compute_forward_ssm_scan_f32( static void ggml_compute_forward_ssm_scan( const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * src2, - const struct ggml_tensor * src3, - const struct ggml_tensor * src4, - const struct ggml_tensor * src5, - const struct ggml_tensor * src6, struct ggml_tensor * dst) { - switch (src0->type) { + switch (dst->src[0]->type) { case GGML_TYPE_F32: { - ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst); + ggml_compute_forward_ssm_scan_f32(params, dst); } break; default: { @@ -16009,11 +16000,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SSM_CONV: { - ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + ggml_compute_forward_ssm_conv(params, tensor); } break; case GGML_OP_SSM_SCAN: { - ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor); + ggml_compute_forward_ssm_scan(params, tensor); } break; case GGML_OP_WIN_PART: {