llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

This commit is contained in:
Georgi Gerganov 2023-09-17 19:42:39 +03:00
parent c5df72e848
commit 3b4bab6a38
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 128 additions and 31 deletions

View file

@ -736,25 +736,59 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
bool bcast_row = false;
if (ggml_nelements(src1) == ne10) {
int64_t nb = ne00;
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
[encoder setComputePipelineState:ctx->pipeline_add_row];
bcast_row = true;
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
const int64_t n = ggml_nelements(dst)/4;
if (bcast_row) {
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_MUL:
{