metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel
This commit is contained in:
parent
260cdb2d08
commit
105332cc17
2 changed files with 361 additions and 32 deletions
119
ggml-metal.m
119
ggml-metal.m
|
@ -183,6 +183,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
|
@ -621,12 +623,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
|
@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend if necessary
|
||||
|
@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
// half8x8 kernel
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
// half1x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
nsg *= 2;
|
||||
}
|
||||
nsg /= 2;
|
||||
|
||||
// require power of 2
|
||||
//{
|
||||
// int64_t nsgm = 1;
|
||||
// while (nsgm < nsg) {
|
||||
// nsgm *= 2;
|
||||
// }
|
||||
// GGML_ASSERT(nsg == nsgm);
|
||||
//}
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue