ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext
* cuda : implement soft_max_ext
* ggml : implement soft_max_ext (CPU)
* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32
* ggml : update soft max cpu
* cuda : do warp-based block reduce
* cuda : increase max block size to 1024
* cuda : fix warp reduction initialization of shared mem
* metal : warp-based reduction for soft max kernel
* metal : warp-based reduce for rms_norm
* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug
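For reference, the fused operation the new kernels implement is a row-wise softmax of scale*src0 + mask, where the mask (src1) is optional. A minimal scalar sketch of these semantics for one row of ne00 elements (the helper name soft_max_ext_row_ref is illustrative, not part of the ggml API):

    #include <float.h>
    #include <math.h>

    // Reference semantics of the fused op: dst = softmax(scale*src0 + mask), per row.
    // The mask pointer may be NULL, in which case only the scale is applied.
    static void soft_max_ext_row_ref(float * dst, const float * src0, const float * mask,
                                     int ne00, float scale) {
        // 1) row maximum of the scaled/masked values (for numerical stability)
        float max_val = -FLT_MAX;
        for (int i = 0; i < ne00; i++) {
            const float v = src0[i]*scale + (mask ? mask[i] : 0.0f);
            max_val = v > max_val ? v : max_val;
        }

        // 2) exponentiate and accumulate the sum, caching the exp() results in dst
        float sum = 0.0f;
        for (int i = 0; i < ne00; i++) {
            const float e = expf(src0[i]*scale + (mask ? mask[i] : 0.0f) - max_val);
            dst[i] = e;
            sum += e;
        }

        // 3) normalize by the reciprocal of the sum, as the kernels do
        const float inv_sum = 1.0f/sum;
        for (int i = 0; i < ne00; i++) {
            dst[i] *= inv_sum;
        }
    }

As in the Metal and CUDA kernels below, the exp() results are cached in dst so they are not recomputed during normalization.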
parent 1d144112c0
commit ef47ec18da

8 changed files with 311 additions and 196 deletions
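On the graph-building side, the commit adds a ggml_soft_max_ext() API that takes the optional mask tensor and the scale directly, so callers no longer need a separate scale/add-mask chain before the softmax. A hedged usage sketch for attention scores; the names kq, kq_mask and n_embd_head are placeholders, and ggml.h at this commit is the authority on the exact prototype:

    #include <math.h>
    #include "ggml.h"

    // Builds the fused "scale + mask + softmax" node for a tensor of attention
    // scores. kq, kq_mask and n_embd_head are supplied by the caller.
    static struct ggml_tensor * build_soft_max_ext(struct ggml_context * ctx,
                                                   struct ggml_tensor  * kq,
                                                   struct ggml_tensor  * kq_mask,
                                                   int                   n_embd_head) {
        // one call replaces the previous scale -> add(mask) -> soft_max sequence
        return ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head));
    }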
ggml-metal.metal (210 changed lines)
@@ -39,6 +39,8 @@ typedef struct {
     int8_t qs[QK8_0]; // quants
 } block_q8_0;
 
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
 // general-purpose kernel for addition of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
 // cons: not very efficient
@@ -180,10 +182,12 @@ kernel void kernel_gelu(
 
 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +198,77 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
+    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+    float lmax = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    max = buf[0];
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
        pdst[i00] = exp_psrc0;
     }
 
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    sum = buf[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }
 
 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +279,68 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+    float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    max = buf[0];
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    sum = buf[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }
 
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
+        threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 
     float4 sumf = 0;
     float all_sum = 0;
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (tiisg == 0) {
-        sum[sgitg] = all_sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
-        }
-        sum[0] /= ne00;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float mean  = sum[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }
 
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
 //      giard against the number of rows not being divisible by
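The three updated kernels share the same reduction shape: each thread reduces a strided slice of the row, simd_max/simd_sum collapses that to one value per SIMD group, and only when the threadgroup holds more than one SIMD group (ntg > N_SIMDWIDTH) do the per-group partials go through threadgroup memory and a final SIMD-level reduction. A serial C model of the max stage, assuming ntg is a multiple of the 32-lane SIMD width (the function name block_max_model is illustrative only):

    #include <math.h>

    #define N_SIMDWIDTH 32   // SIMD group size assumed by the kernels

    // Serial model of the two-stage block max used by the updated kernels:
    // stage 1 collapses each 32-lane SIMD group to one partial maximum,
    // stage 2 reduces the per-group partials (held in threadgroup memory on the GPU).
    static float block_max_model(const float * x, int n, int ntg) {
        float partial[1024 / N_SIMDWIDTH];   // one slot per SIMD group (ntg <= 1024)
        const int n_groups = ntg / N_SIMDWIDTH;

        for (int g = 0; g < n_groups; g++) {
            float m = -INFINITY;
            // each "thread" t of group g walks the row with stride ntg;
            // simd_max would then combine the 32 per-thread partials
            for (int t = 0; t < N_SIMDWIDTH; t++) {
                for (int i = g*N_SIMDWIDTH + t; i < n; i += ntg) {
                    m = fmaxf(m, x[i]);
                }
            }
            partial[g] = m;
        }

        // second simd_max over the per-group partials
        float m = partial[0];
        for (int g = 1; g < n_groups; g++) {
            m = fmaxf(m, partial[g]);
        }
        return m;
    }

The sum reductions in kernel_soft_max and kernel_rms_norm follow the same two stages with simd_sum and a 0.0f initial value instead of -INFINITY.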