Bug fix suggested by Georgi
large dot product kernel selection is now consistent
This commit is contained in:
parent
26cb415267
commit
aa3fd500b1
1 changed file with 3 additions and 1 deletion
|
@@ -1573,6 +1573,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
+ bool is_large = false;
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
|
@@ -1592,6 +1593,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
if (ne01 > 128) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
|
||||
+ is_large = true;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
||||
}
|
||||
|
@@ -1784,7 +1786,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||
- if (ne01 > 128) {
|
||||
+ if (is_large) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue