ggml : sync latest ggml_mul_mat_id
This commit is contained in:
parent
a3eefe95a8
commit
861cd67899
4 changed files with 114 additions and 75 deletions
21
ggml-metal.m
21
ggml-metal.m
|
@@ -177,6 +177,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|||
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
||||
} else {
|
||||
char* buffer2 = malloc(len+1);
|
||||
va_end(args);
|
||||
va_start(args, format);
|
||||
vsnprintf(buffer2, len+1, format, args);
|
||||
buffer2[len] = 0;
|
||||
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
||||
|
@@ -1193,7 +1195,9 @@ void ggml_metal_graph_compute(
|
|||
const float scale = ((float *) dst->op_params)[0];
|
||||
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
if (id_src1) {
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
|
@@ -1511,9 +1515,7 @@ void ggml_metal_graph_compute(
|
|||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||
}
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
||||
|
@@ -1523,7 +1525,7 @@ void ggml_metal_graph_compute(
|
|||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
|
||||
|
@@ -1538,7 +1540,14 @@ void ggml_metal_graph_compute(
|
|||
}
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
|
||||
for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue