metal : fix kernel_norm

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-09-07 14:11:21 +03:00
parent fec2fb19e4
commit 5e1c4089d8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 24 additions and 23 deletions

View file

@ -995,8 +995,12 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
int64_t ny = (ne11 + 3)/4;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
// TODO: this breaks for Q4_0 - understand why and fix it
//int64_t ny = (ne11 + 3)/4;
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;