metal : reduce registers
This commit is contained in:
parent
e51778de5e
commit
c4dff1ec91
2 changed files with 10 additions and 22 deletions
14
ggml-metal.m
14
ggml-metal.m
|
@ -179,10 +179,6 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
|
@ -625,10 +621,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
|
@ -2521,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ne01 > 1) {
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
|
@ -2538,10 +2530,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
}
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
default:
|
||||
|
|
|
@ -2516,7 +2516,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
half4 lo[Q][D4];
|
||||
half4 lo[Q][D4/NW];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (short j = sgitg; j < Q; j += nsg) {
|
||||
|
@ -2534,7 +2534,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
// zero out lo
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
lo[j][i] = 0.0h;
|
||||
lo[j][i/NW] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2711,7 +2711,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
//simdgroup_multiply(lo[j][i], mm, lo[j][i]);
|
||||
lo[j][i] = lo[j][i]*mm;
|
||||
lo[j][i/NW] = lo[j][i/NW]*mm;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2722,7 +2722,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
lo[j][i] += pv4[i]*ss[j*T + cc];
|
||||
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2743,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
// store results to shared memory
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
sr4[i] = lo[j][i];
|
||||
sr4[i] = lo[j][i/NW];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2805,10 +2805,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue