metal : clean-up

This commit is contained in:
Georgi Gerganov 2024-04-19 15:30:27 +03:00
parent 703c6e6528
commit 97eaece7d6
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -2560,7 +2560,9 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
if (ne01 > 1 || (ne00%128 != 0)) {
bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
@ -2576,6 +2578,8 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
} else {
use_vec_kernel = true;
switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
@ -2588,7 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
// TODO: extend if necessary
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -2619,8 +2622,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
if (!use_vec_kernel) {
// half8x8 kernel
if (ne01 > 1 || (ne00%128 != 0)) {
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
@ -2635,7 +2638,7 @@ static enum ggml_status ggml_metal_graph_compute(
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
@ -2926,7 +2929,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
return NULL;
}
ggml_backend_metal_log_allocated_size(device, size_aligned);
//ggml_backend_metal_log_allocated_size(device, size_aligned);
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}