ggml : add ggml_flash_attn_ext API
This commit is contained in:
parent
ad19812cda
commit
a1c004ef2e
6 changed files with 456 additions and 38 deletions
50
ggml-metal.m
50
ggml-metal.m
|
@ -147,6 +147,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
|
@ -511,6 +512,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
|
@ -665,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
|
@ -2161,6 +2164,53 @@ static bool ggml_metal_graph_compute(
|
|||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
||||
|
||||
size_t offs_src2 = 0;
|
||||
size_t offs_src3 = 0;
|
||||
|
||||
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil;
|
||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil;
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline;
|
||||
|
||||
// TODO: extend if necessary
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:21];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue