ggml : full ALiBi support
This commit is contained in:
parent
d11afd6652
commit
7fdca3348c
10 changed files with 82 additions and 680 deletions
|
@ -356,7 +356,6 @@ template<typename T>
|
|||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -378,10 +377,9 @@ kernel void kernel_soft_max(
|
|||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
|
@ -397,7 +395,7 @@ kernel void kernel_soft_max(
|
|||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
|
@ -422,7 +420,7 @@ kernel void kernel_soft_max(
|
|||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
@ -461,7 +459,6 @@ template<typename T>
|
|||
kernel void kernel_soft_max_4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -483,10 +480,9 @@ kernel void kernel_soft_max_4(
|
|||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
|
||||
float slope = 0.0f;
|
||||
float slope = 1.0f;
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
@ -501,7 +497,7 @@ kernel void kernel_soft_max_4(
|
|||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -527,7 +523,7 @@ kernel void kernel_soft_max_4(
|
|||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_alibi_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant int & n_heads_log2_floor,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
const int64_t k = i3*ne3 + i2;
|
||||
|
||||
float m_k;
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = pow(m0, k + 1);
|
||||
} else {
|
||||
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
const float src_v = *(device float *)(src_row + i00*nb00);
|
||||
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
||||
*dst_v = i00 * m_k + src_v;
|
||||
}
|
||||
}
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue