ggml : add ALiBi support for ggml_soft_max_ext (#5488)
* ggml : avoid recomputing alibi slopes (CPU) * llama : reuse hparams.f_max_alibi_bias in all cases ggml-ci * ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal) ggml-ci * ggml : handle all SRCs (do not break on first null) ggml-ci * tests : do not use slope for large soft_max accumulates too much error ggml-ci * ggml : alternative ALiBi without extra tensor We compute the slopes in the kernel ggml-ci * cuda : add ALiBi support in ggml_soft_max_ext ggml-ci * ggml : deprecate ggml_alibi * ggml : support multi-sequence ALiBi (Metal) ggml-ci * cuda : add multi-seq ALiBi + remote F16 soft_max ggml-ci * ggml : update deprecation message * ggml : fix pos ptr when no ALiBi ggml-ci * cuda : fix performance (pow -> powf) * cuda : precompute ALiBi constants * metal : pre-compute ALiBi slopes ggml-ci * llama : init kq_pos only if needed ggml-ci * test-backend-ops : add null pos test to soft_max test-backend-ops : replace soft_max tests ggml-ci --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
6e4e973b26
commit
8f1be0d42f
9 changed files with 348 additions and 357 deletions
|
@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|||
kernel void kernel_soft_max(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
|
@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|||
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|||
kernel void kernel_soft_max_4(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
|
@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue