mtl : fix soft_max kernel

Georgi Gerganov 2023-06-01 20:48:24 +03:00
parent 17a70362a6
commit 17930fbcb7
3 changed files with 32 additions and 35 deletions


@@ -378,11 +378,19 @@ int llama_mtl_eval(
                         id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                         id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
 
+                        const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                        const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                        const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                        const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
                         [encoder setComputePipelineState:ctx->pipeline_soft_max];
                         [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
                     {
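For orientation (this note is not part of the diff): the old code dispatched a single threadgroup and relied on the nsoftmax function constant, while the new dispatch launches one single-thread threadgroup per row, so the (ne01, ne02, ne03) grid enumerates every row of src0 and each kernel invocation handles the ne00 elements of its row. A plain-C sketch of the row addressing the kernel performs from its grid position (soft_max_row is an illustrative helper, not something from the commit):

// Sketch only: how a grid position (i01, i02, i03) produced by
// dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) selects the start of a row.
static const float * soft_max_row(const float * src0,
                                  int64_t ne00, int64_t ne01, int64_t ne02,
                                  int64_t i01, int64_t i02, int64_t i03) {
    return src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
}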


@@ -31,9 +31,6 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     }
 }
 
-// TODO: not needed
-constant int nsoftmax [[function_constant(0)]];
-
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -68,42 +65,34 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-// TODO: broken
 kernel void kernel_soft_max(
         device const float * src0,
-        device float * dst) {
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i03 = tpig[2];
+    const int64_t i02 = tpig[1];
+    const int64_t i01 = tpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
     float max = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src0[i]);
+    for (int i = 0; i < ne00; i++) {
+        max = MAX(max, psrc0[i]);
     }
     float sum = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src0[i] - max);
-        sum += dst[i];
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] = exp(psrc0[i] - max);
+        sum += pdst[i];
     }
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] /= sum;
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] /= sum;
     }
 }
 
-//const int n = ggml_nrows(src0);
-//const int nc = src0->ne[0];
-//const int nr = src0->ne[1];
-//const int nz = n/nr;
-//
-//assert( dst->nb[0] == sizeof(float));
-//assert(src0->nb[0] == sizeof(float));
-//
-//for (int k = 0; k < nz; k++) {
-//    for (int j = ith; j < nr; j += nth) {
-//        for (int i = n_past; i < nc; i++) {
-//            if (i > n_past + j) {
-//                *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
-//            }
-//        }
-//    }
-//}
-
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,
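As a sanity check on what the updated kernel computes, here is a CPU reference of the same row-wise softmax (a sketch only; soft_max_ref is an illustrative name, and contiguous float data of shape ne00 x ne01 x ne02 x ne03 is assumed):

#include <math.h>
#include <stdint.h>

// Reference (CPU) version of kernel_soft_max: an independent softmax over each
// row of ne00 elements. The max accumulator starts at 0.0f, mirroring the Metal code.
static void soft_max_ref(const float * src, float * dst,
                         int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
    for (int64_t r = 0; r < ne01*ne02*ne03; r++) {
        const float * s = src + r*ne00;
        float       * d = dst + r*ne00;

        float max = 0.0f;
        for (int64_t i = 0; i < ne00; i++) {
            max = fmaxf(max, s[i]);
        }

        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; i++) {
            d[i] = expf(s[i] - max);
            sum += d[i];
        }

        for (int64_t i = 0; i < ne00; i++) {
            d[i] /= sum;
        }
    }
}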


@@ -1336,15 +1336,15 @@ static bool llama_eval_internal(
         // KQ_masked = mask_past(KQ_scaled)
         struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
         ggml_set_name(KQ_masked, "KQ_masked");
 
-        // TODO: TMP !!!!
-        if (il == 0) {
-            ggml_set_name(KQ_masked, "mtl-check");
-        }
-
         // KQ = soft_max(KQ_masked)
         struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
         ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
+        // TODO: TMP !!!!
+        if (il == 0) {
+            ggml_set_name(KQ_soft_max, "mtl-check");
+        }
+
         // split cached V into n_head heads
         struct ggml_tensor * V =