ggml : add ALiBi support for ggml_soft_max_ext (#5488)
* ggml : avoid recomputing alibi slopes (CPU) * llama : reuse hparams.f_max_alibi_bias in all cases ggml-ci * ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal) ggml-ci * ggml : handle all SRCs (do not break on first null) ggml-ci * tests : do not use slope for large soft_max accumulates too much error ggml-ci * ggml : alternative ALiBi without extra tensor We compute the slopes in the kernel ggml-ci * cuda : add ALiBi support in ggml_soft_max_ext ggml-ci * ggml : deprecate ggml_alibi * ggml : support multi-sequence ALiBi (Metal) ggml-ci * cuda : add multi-seq ALiBi + remote F16 soft_max ggml-ci * ggml : update deprecation message * ggml : fix pos ptr when no ALiBi ggml-ci * cuda : fix performance (pow -> powf) * cuda : precompute ALiBi constants * metal : pre-compute ALiBi slopes ggml-ci * llama : init kq_pos only if needed ggml-ci * test-backend-ops : add null pos test to soft_max test-backend-ops : replace soft_max tests ggml-ci --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
6e4e973b26
commit
8f1be0d42f
9 changed files with 348 additions and 357 deletions
35
ggml-metal.m
35
ggml-metal.m
|
@ -728,6 +728,7 @@ static bool ggml_metal_graph_compute(
|
|||
|
||||
size_t offs_src0 = 0;
|
||||
size_t offs_src1 = 0;
|
||||
size_t offs_src2 = 0;
|
||||
size_t offs_dst = 0;
|
||||
|
||||
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
||||
|
@ -746,6 +747,7 @@ static bool ggml_metal_graph_compute(
|
|||
|
||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
struct ggml_tensor * dst = gf->nodes[i];
|
||||
|
||||
switch (dst->op) {
|
||||
|
@ -807,6 +809,7 @@ static bool ggml_metal_graph_compute(
|
|||
|
||||
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
||||
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
||||
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
||||
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
||||
|
||||
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
||||
|
@ -1188,7 +1191,16 @@ static bool ggml_metal_graph_compute(
|
|||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||
}
|
||||
|
||||
const float scale = ((float *) dst->op_params)[0];
|
||||
const float scale = ((float *) dst->op_params)[0];
|
||||
const float max_bias = ((float *) dst->op_params)[1];
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
@ -1197,11 +1209,20 @@ static bool ggml_metal_graph_compute(
|
|||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
if (id_src2) {
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
@ -1514,8 +1535,6 @@ static bool ggml_metal_graph_compute(
|
|||
// max size of the src1ids array in the kernel stack
|
||||
GGML_ASSERT(ne11 <= 512);
|
||||
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue