Bug fix suggested by Georgi

large dot product kernel selection is now consistent
This commit is contained in:
Alexander Komarov 2024-05-25 07:49:18 -07:00 committed by GitHub
parent 26cb415267
commit aa3fd500b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1573,6 +1573,7 @@ static enum ggml_status ggml_metal_graph_compute(
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
id<MTLComputePipelineState> pipeline = nil;
bool is_large = false;
// use custom matrix x vector kernel
switch (src0t) {
@ -1592,6 +1593,7 @@ static enum ggml_status ggml_metal_graph_compute(
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
if (ne01 > 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
is_large = true;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
}
@ -1784,7 +1786,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (ne11 + nrows - 1)/nrows;
if (ne01 > 128) {
if (is_large) {
[encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];