ggml : implicitly pass src tensors through dst for Mamba-related ops

Author: Francis Couture-Harpin
Date:   2024-03-04 10:10:50 -05:00
parent  eefb794bd7
commit  2a99d1b243
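
Background on the pattern this commit adopts: when ggml builds a graph node, the result tensor already records pointers to its operands in its src[] array, so a compute kernel can recover every input from dst alone instead of receiving each operand as a separate parameter. Below is a minimal, self-contained sketch of that convention; the struct is a toy stand-in, not the real struct ggml_tensor (which has many more fields and sizes src[] with GGML_MAX_SRC).

    #include <stdio.h>

    /* Toy stand-in for struct ggml_tensor: only the fields that matter here.
       The real struct has many more members and uses GGML_MAX_SRC for src[]. */
    struct toy_tensor {
        int                 type;    // e.g. a GGML_TYPE_* value
        struct toy_tensor * src[4];  // operand pointers filled in by the graph builder
        const char        * name;
    };

    /* Old calling convention: every operand is an explicit parameter. */
    static void compute_explicit(const struct toy_tensor * src0,
                                 const struct toy_tensor * src1,
                                 struct toy_tensor * dst) {
        printf("%s <- %s, %s\n", dst->name, src0->name, src1->name);
    }

    /* New calling convention (what this commit switches to): only dst is
       passed, and the operands are looked up through dst->src[]. */
    static void compute_implicit(struct toy_tensor * dst) {
        const struct toy_tensor * src0 = dst->src[0];
        const struct toy_tensor * src1 = dst->src[1];
        printf("%s <- %s, %s\n", dst->name, src0->name, src1->name);
    }

    int main(void) {
        struct toy_tensor a = { .type = 0, .name = "a" };
        struct toy_tensor b = { .type = 0, .name = "b" };
        struct toy_tensor c = { .type = 0, .name = "c", .src = { &a, &b } };

        compute_explicit(&a, &b, &c); // before this commit
        compute_implicit(&c);         // after this commit
        return 0;
    }

Both calls do the same work; the second form is the one the ggml.c functions in the diff below now use.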

ggml.c (47 changed lines)
@@ -14700,15 +14700,16 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // conv_state
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // conv1d.weight
-        const struct ggml_tensor * src3, // state_seq
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
+
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14808,15 +14809,11 @@ static void ggml_compute_forward_ssm_conv_f32(
 static void ggml_compute_forward_ssm_conv(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst);
+                ggml_compute_forward_ssm_conv_f32(params, dst);
             } break;
         default:
             {
@@ -14829,18 +14826,19 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // s
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // dt
-        const struct ggml_tensor * src3, // A
-        const struct ggml_tensor * src4, // B
-        const struct ggml_tensor * src5, // C
-        const struct ggml_tensor * src6, // sq
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
+
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // sq
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14936,18 +14934,11 @@ static void ggml_compute_forward_ssm_scan_f32(
 static void ggml_compute_forward_ssm_scan(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
-        const struct ggml_tensor * src4,
-        const struct ggml_tensor * src5,
-        const struct ggml_tensor * src6,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst);
+                ggml_compute_forward_ssm_scan_f32(params, dst);
             } break;
         default:
             {
@@ -16009,11 +16000,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_CONV:
             {
-                ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_ssm_conv(params, tensor);
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor);
             } break;
         case GGML_OP_WIN_PART:
             {
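
The other half of the contract lives in the op-builder functions (ggml_ssm_conv / ggml_ssm_scan in ggml.c), which populate result->src[] when the graph node is created; that is what the compute functions above now read back through dst. The snippet below is an illustrative sketch under that assumption, not the actual builder code: the real ggml_ssm_conv also validates the operand shapes and computes the proper result dimensions.

    #include "ggml.h"

    /* Illustrative sketch only: roughly how an op builder records its operands
       on the result node so the compute pass can later fetch them via dst->src[i].
       Real code builds the correct result shape and asserts on the inputs. */
    struct ggml_tensor * toy_ssm_conv(
            struct ggml_context * ctx,
            struct ggml_tensor  * conv_state,     // becomes src0
            struct ggml_tensor  * x,              // becomes src1
            struct ggml_tensor  * conv1d_weight,  // becomes src2
            struct ggml_tensor  * state_seq) {    // becomes src3
        // placeholder 1-element result; the real op computes the actual shape
        struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);

        result->op     = GGML_OP_SSM_CONV;
        result->src[0] = conv_state;
        result->src[1] = x;
        result->src[2] = conv1d_weight;
        result->src[3] = state_seq;

        return result;
    }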