metal : clean-up
This commit is contained in:
parent
703c6e6528
commit
97eaece7d6
1 changed files with 178 additions and 175 deletions
13
ggml-metal.m
13
ggml-metal.m
|
@ -2560,7 +2560,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
bool use_vec_kernel = false;
|
||||
|
||||
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
|
@ -2576,6 +2578,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
|
@ -2588,7 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: extend if necessary
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
|
@ -2619,8 +2622,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
if (!use_vec_kernel) {
|
||||
// half8x8 kernel
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
|
@ -2635,7 +2638,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
|
@ -2926,7 +2929,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
|
|||
return NULL;
|
||||
}
|
||||
|
||||
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
||||
//ggml_backend_metal_log_allocated_size(device, size_aligned);
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue