ggml : add ggml_ssm_conv metal impl

parent 7a3df798fc
commit 9928f4bde3

3 changed files with 112 additions and 0 deletions
@@ -82,6 +82,7 @@ enum ggml_metal_kernel_type {
         GGML_METAL_KERNEL_TYPE_RMS_NORM,
         GGML_METAL_KERNEL_TYPE_GROUP_NORM,
         GGML_METAL_KERNEL_TYPE_NORM,
+        GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
         GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
         GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
         GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -542,6 +543,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -803,6 +805,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1538,6 +1542,39 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 }
             } break;
+        case GGML_OP_SSM_CONV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+                [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
+                [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
+                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
+                [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_MUL_MAT:
            {
                GGML_ASSERT(ne00 == ne10);
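Note on the dispatch above (a reading aid, not part of the diff): the op is launched as a (ne01, ne1, ne02) grid of single-thread threadgroups, so each kernel invocation produces exactly one output element. A minimal C sketch of the loop nest that grid corresponds to, assuming the same index roles as the kernel (names here are illustrative):

#include <stdint.h>

// Sketch only: the work covered by dispatchThreadgroups(ne01, ne1, ne02) with
// a threadsPerThreadgroup of (1, 1, 1).
static void ssm_conv_dispatch_shape(int64_t ne01, int64_t ne1, int64_t ne02) {
    for (int64_t i3 = 0; i3 < ne02; ++i3) {             // tgpig.z : sequence
        for (int64_t i2 = 0; i2 < ne1; ++i2) {          // tgpig.y : output token
            for (int64_t ir = 0; ir < ne01; ++ir) {     // tgpig.x : inner row
                // one kernel_ssm_conv_f32 threadgroup computes the single
                // dot product for this (ir, i2, i3) triple
            }
        }
    }
}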
@@ -667,6 +667,54 @@ kernel void kernel_diag_mask_inf_8(
     }
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const void * src0,
+        device const void * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = ne10;
+    const int64_t ncs = ne00;
+    const int64_t nr  = ne01;
+    const int64_t n_t = ne1;
+    const int64_t n_s = ne2;
+
+    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+    device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device float * dst,
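The kernel follows the CPU path it cites (ggml.c:ggml_compute_forward_ssm_conv_f32): each (ir, i2, i3) threadgroup dot-products a d_conv-wide window of src0, shifted by the output token index i2 (the i2*nb00 term above), with row ir of src1. A self-contained C sketch of that computation for contiguous F32 tensors — function and variable names here are illustrative, not from the source:

#include <stdint.h>

// Reference sketch (assumption: contiguous F32 tensors, same index roles as the
// Metal kernel): nc = ne10 (d_conv), ncs = ne00 (expected to be nc + n_t - 1),
// nr = ne01 (inner rows), n_t = ne1 (output tokens), n_s = ne2 (sequences).
static void ssm_conv_f32_ref(
        const float * src0,  // {ncs, nr, n_s}
        const float * src1,  // {nc,  nr}
        float       * dst,   // {nr,  n_t, n_s}
        int64_t nc, int64_t ncs, int64_t nr, int64_t n_t, int64_t n_s) {
    for (int64_t i3 = 0; i3 < n_s; ++i3) {
        for (int64_t i2 = 0; i2 < n_t; ++i2) {
            for (int64_t ir = 0; ir < nr; ++ir) {
                // sliding window of the conv state/input for this row, shifted
                // by the output token index (mirrors ir*nb01 + i2*nb00 + i3*nb02)
                const float * s = src0 + i3*nr*ncs + ir*ncs + i2;
                const float * c = src1 + ir*nc;

                float sumf = 0.0f;
                for (int64_t i0 = 0; i0 < nc; ++i0) {
                    sumf += s[i0] * c[i0];
                }

                // mirrors ir*nb0 + i2*nb1 + i3*nb2 on the destination
                dst[i3*n_t*nr + i2*nr + ir] = sumf;
            }
        }
    }
}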
@@ -949,6 +949,29 @@ struct test_rms_norm : public test_case {
     }
 };
 
+// GGML_OP_SSM_CONV
+struct test_ssm_conv : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const std::array<int64_t, 4> ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2240,6 +2263,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
    }
 
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
+
 #if 1
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
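A hedged reading of the three test shapes above, assuming the Mamba-style layout src0 = {d_conv - 1 + n_t, d_inner, n_s, 1} and src1 = {d_conv, d_inner, 1, 1}, with n_t derived as ne00(src0) - ne00(src1) + 1 as in the CPU reference:

// Illustrative shape arithmetic only (assumptions as stated in the lead-in):
//   a = {4, 1536, 1, 1}, b = {4, 1536, 1, 1}  ->  n_t = 4 - 4 + 1 = 1, dst = {1536, 1, 1}
//   a = {8, 1536, 1, 1}, b = {4, 1536, 1, 1}  ->  n_t = 8 - 4 + 1 = 5, dst = {1536, 5, 1}
//   a = {4, 1536, 4, 1}, b = {4, 1536, 1, 1}  ->  n_t = 1, n_s = 4,   dst = {1536, 1, 4}
// i.e. the cases cover a single token, a multi-token prompt, and multiple sequences.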