ggml : implicitly pass src tensors through dst for Mamba-related ops
parent eefb794bd7
commit 2a99d1b243
1 changed file with 19 additions and 28 deletions
ggml.c (+19 −28)
@@ -14700,15 +14700,16 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // conv_state
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // conv1d.weight
-        const struct ggml_tensor * src3, // state_seq
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
 
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -14808,15 +14809,11 @@ static void ggml_compute_forward_ssm_conv_f32(
 
 static void ggml_compute_forward_ssm_conv(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_conv_f32(params, src0, src1, src2, src3, dst);
+                ggml_compute_forward_ssm_conv_f32(params, dst);
             } break;
         default:
             {
@@ -14829,18 +14826,19 @@ static void ggml_compute_forward_ssm_conv(
 
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0, // s
-        const struct ggml_tensor * src1, // x
-        const struct ggml_tensor * src2, // dt
-        const struct ggml_tensor * src3, // A
-        const struct ggml_tensor * src4, // B
-        const struct ggml_tensor * src5, // C
-        const struct ggml_tensor * src6, // sq
               struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
 
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // sq
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -14936,18 +14934,11 @@ static void ggml_compute_forward_ssm_scan_f32(
 
 static void ggml_compute_forward_ssm_scan(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
-        const struct ggml_tensor * src3,
-        const struct ggml_tensor * src4,
-        const struct ggml_tensor * src5,
-        const struct ggml_tensor * src6,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->src[0]->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, src5, src6, dst);
+                ggml_compute_forward_ssm_scan_f32(params, dst);
             } break;
         default:
             {
@@ -16009,11 +16000,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_CONV:
             {
-                ggml_compute_forward_ssm_conv(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_ssm_conv(params, tensor);
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5], tensor->src[6], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor);
             } break;
         case GGML_OP_WIN_PART:
             {
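Taken together, the hunks above converge on one pattern: a compute function receives only params and dst, and recovers its operands from dst->src[], which every ggml op already carries. A minimal sketch of that pattern, using a hypothetical op name rather than actual ggml.c code:

#include "ggml.h"

// Hypothetical compute function following the new convention:
// only params and dst are passed; operands come from dst->src[].
static void ggml_compute_forward_example(
        const struct ggml_compute_params * params,
              struct ggml_tensor * dst) {
    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
        return;
    }

    const struct ggml_tensor * src0 = dst->src[0]; // first operand
    const struct ggml_tensor * src1 = dst->src[1]; // second operand

    // ... kernel body works on src0, src1 and dst exactly as before ...
    (void) src0;
    (void) src1;
}

One effect visible in the last hunk: the call sites in ggml_compute_forward become uniform ((params, tensor)) regardless of how many operands an op has.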