From c4dff1ec910a2057a3c17b170028cb9c1d418865 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Apr 2024 16:24:10 +0300 Subject: [PATCH] metal : reduce registers --- ggml-metal.m | 14 +------------- ggml-metal.metal | 18 +++++++++--------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6106bc7e3..07535828d 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -179,10 +179,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, @@ -625,10 +621,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); @@ -2521,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - if (ne01 > 1) { + if (ne01 > 1 || (ne00%128 != 0)) { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; @@ -2538,10 +2530,6 @@ static enum ggml_status ggml_metal_graph_compute( } } else { switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; default: diff --git a/ggml-metal.metal b/ggml-metal.metal index 404bd16e0..7709865c9 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2516,7 +2516,7 @@ kernel void kernel_flash_attn_ext_vec_f16( threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[Q][D4]; + half4 lo[Q][D4/NW]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { @@ -2534,7 +2534,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // zero out lo for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < D4; i += NW) { - lo[j][i] = 0.0h; + lo[j][i/NW] = 0.0h; } } @@ -2711,7 +2711,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { //simdgroup_multiply(lo[j][i], mm, lo[j][i]); - lo[j][i] = lo[j][i]*mm; + lo[j][i/NW] = lo[j][i/NW]*mm; } } @@ -2722,7 +2722,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short i = tiisg; i < D4; i += NW) { for (short j = 0; j < Q; ++j) { - lo[j][i] += pv4[i]*ss[j*T + cc]; + lo[j][i/NW] += pv4[i]*ss[j*T + cc]; } } } @@ -2743,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // store results to shared memory for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < D4; i += NW) { - sr4[i] = lo[j][i]; + sr4[i] = lo[j][i/NW]; } } @@ -2805,10 +2805,10 @@ kernel void kernel_flash_attn_ext_vec_f16( } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>; template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;