ggml : implicitly pass src tensors through dst for Mamba-related ops

Francis Couture-Harpin 2024-03-04 10:10:50 -05:00
parent eefb794bd7
commit 2a99d1b243
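
The change relies on a convention ggml already has: when the compute graph is built, each result tensor records its operands in its src[] array, so a kernel can recover src0, src1, ... from dst alone instead of having every operand threaded through the call chain. Below is a minimal, self-contained sketch of that pattern; the mock tensor struct, compute_ssm_conv_like, and the placeholder arithmetic are illustrative stand-ins, not the real ggml types or kernels.

#include <stdio.h>
#include <stddef.h>

#define MAX_SRC 4

/* Mock of the relevant part of a ggml-style tensor (illustrative only):
 * the operands are recorded on the result when the graph is built,
 * so they travel into the compute phase along with it. */
struct tensor {
    const char    *name;
    float          value;        /* stand-in for real tensor data */
    struct tensor *src[MAX_SRC]; /* operands, set at graph-build time */
};

/* After the refactor, a kernel takes only dst and pulls its inputs out
 * of dst->src[], as ggml_compute_forward_ssm_conv_f32 does below. */
static void compute_ssm_conv_like(struct tensor *dst) {
    const struct tensor *conv_state = dst->src[0];
    const struct tensor *x          = dst->src[1];
    dst->value = conv_state->value + x->value; /* placeholder math */
}

int main(void) {
    struct tensor conv_state = { "conv_state", 1.0f, { NULL } };
    struct tensor x          = { "x",          2.0f, { NULL } };
    struct tensor out        = { "out",        0.0f, { &conv_state, &x } };

    compute_ssm_conv_like(&out); /* only dst crosses the call boundary */
    printf("%s = %g\n", out.name, out.value); /* prints: out = 3 */
    return 0;
}

Keeping the operands on dst also makes every kernel take the same two arguments, which is what lets the per-op dispatch in ggml_compute_forward (last hunk below) collapse to single-argument forwarding calls.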

ggml.c

@@ -14700,15 +14700,16 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // conv_state
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // conv1d.weight
-        const struct ggml_tensor * src3, // state_seq
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
 
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14808,15 +14809,11 @@ static void ggml_compute_forward_ssm_conv_f32(
 static void ggml_compute_forward_ssm_conv(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
         struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst);
+                ggml_compute_forward_ssm_conv_f32(params, dst);
             } break;
         default:
             {
@@ -14829,18 +14826,19 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // s
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // dt
-        const struct ggml_tensor * src3, // A
-        const struct ggml_tensor * src4, // B
-        const struct ggml_tensor * src5, // C
-        const struct ggml_tensor * src6, // sq
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
 
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // sq
+
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14936,18 +14934,11 @@ static void ggml_compute_forward_ssm_scan_f32(
 static void ggml_compute_forward_ssm_scan(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
-        const struct ggml_tensor * src4,
-        const struct ggml_tensor * src5,
-        const struct ggml_tensor * src6,
         struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst);
+                ggml_compute_forward_ssm_scan_f32(params, dst);
             } break;
         default:
             {
@@ -16009,11 +16000,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             } break;
         case GGML_OP_SSM_CONV:
             {
-                ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_ssm_conv(params, tensor);
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor);
             } break;
         case GGML_OP_WIN_PART:
             {