mtl : fix soft_max kernel
parent 17a70362a6
commit 17930fbcb7

3 changed files with 32 additions and 35 deletions
@@ -378,11 +378,19 @@ int llama_mtl_eval(
                     id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                     id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
                     [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];

-                    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
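This hunk is the host-side half of the fix: the encoder previously launched a single 1x1x1 threadgroup, so at most one row of the input could ever be processed; it now uploads the source dimensions with setBytes and dispatches an ne01 x ne02 x ne03 grid, one threadgroup per softmax row. A minimal C sketch (mine, not code from the commit; row_offset is a hypothetical name) of how a grid position maps to the start of its row:

#include <stdint.h>

// Each threadgroup at grid position (i01, i02, i03) owns one contiguous
// row of ne00 floats; this is the same offset arithmetic the fixed
// kernel uses for its psrc0/pdst pointers.
static int64_t row_offset(int64_t i01, int64_t i02, int64_t i03,
                          int64_t ne00, int64_t ne01, int64_t ne02) {
    return i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
}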
@@ -31,9 +31,6 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
     }
 }

-// TODO: not needed
-constant int nsoftmax [[function_constant(0)]];
-
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -68,42 +65,34 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }

-// TODO: broken
 kernel void kernel_soft_max(
         device const float * src0,
-        device float * dst) {
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i03 = tpig[2];
+    const int64_t i02 = tpig[1];
+    const int64_t i01 = tpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
     float max = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        max = MAX(max, src0[i]);
+    for (int i = 0; i < ne00; i++) {
+        max = MAX(max, psrc0[i]);
     }
     float sum = 0.0f;
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] = exp(src0[i] - max);
-        sum += dst[i];
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] = exp(psrc0[i] - max);
+        sum += pdst[i];
     }
-    for (int i = 0; i < nsoftmax; i++) {
-        dst[i] /= sum;
+    for (int i = 0; i < ne00; i++) {
+        pdst[i] /= sum;
     }
 }

-//const int n = ggml_nrows(src0);
-//const int nc = src0->ne[0];
-//const int nr = src0->ne[1];
-//const int nz = n/nr;
-//
-//assert( dst->nb[0] == sizeof(float));
-//assert(src0->nb[0] == sizeof(float));
-//
-//for (int k = 0; k < nz; k++) {
-//    for (int j = ith; j < nr; j += nth) {
-//        for (int i = n_past; i < nc; i++) {
-//            if (i > n_past + j) {
-//                *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
-//            }
-//        }
-//    }
-//}
-
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,
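The shader-side half: kernel_soft_max now receives ne00..ne02 as constant arguments along with its thread position, and offsets psrc0/pdst to its own row instead of always operating on row 0 through the removed nsoftmax function constant, which is what the deleted // TODO: broken referred to. The per-row math is the standard numerically stable softmax. A self-contained C reference (my sketch, not code from the commit):

#include <math.h>
#include <stdio.h>

// Subtract the row max, exponentiate, normalize by the sum. The Metal
// kernel seeds max with 0.0f rather than -INFINITY; softmax is invariant
// to the constant subtracted from every element, so the result is the
// same in exact arithmetic (the 0.0f seed can underflow only if every
// input in the row is strongly negative).
static void softmax_row(const float * x, float * y, int n) {
    float max = -INFINITY;
    for (int i = 0; i < n; i++) {
        if (x[i] > max) max = x[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        y[i] = expf(x[i] - max);
        sum += y[i];
    }
    for (int i = 0; i < n; i++) {
        y[i] /= sum;
    }
}

int main(void) {
    const float x[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    float y[4];
    softmax_row(x, y, 4);
    for (int i = 0; i < 4; i++) {
        printf("%.6f\n", y[i]); // the outputs sum to 1
    }
    return 0;
}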
@@ -1336,15 +1336,15 @@ static bool llama_eval_internal(
         // KQ_masked = mask_past(KQ_scaled)
         struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
         ggml_set_name(KQ_masked, "KQ_masked");

-        // TODO: TMP !!!!
-        if (il == 0) {
-            ggml_set_name(KQ_masked, "mtl-check");
-        }
-
         // KQ = soft_max(KQ_masked)
         struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
         ggml_set_name(KQ_soft_max, "KQ_soft_max");

+        // TODO: TMP !!!!
+        if (il == 0) {
+            ggml_set_name(KQ_soft_max, "mtl-check");
+        }
+
         // split cached V into n_head heads
         struct ggml_tensor * V =
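In llama_eval_internal, the temporary "mtl-check" label moves from KQ_masked to KQ_soft_max, so the tensor exported for validating the Metal path is now the output of the softmax kernel fixed above. A hedged sketch of how such a tagged node could be located in the graph (my illustration; find_mtl_check is a hypothetical helper, and the gf->nodes / n_nodes / name fields are assumed from the ggml of this period):

#include <stddef.h>
#include <string.h>
#include "ggml.h"

// Scan the computation graph for the tensor tagged "mtl-check" so its
// CPU-computed data can be compared against the Metal kernel's output.
static struct ggml_tensor * find_mtl_check(struct ggml_cgraph * gf) {
    for (int i = 0; i < gf->n_nodes; i++) {
        if (strcmp(gf->nodes[i]->name, "mtl-check") == 0) {
            return gf->nodes[i];
        }
    }
    return NULL;
}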