mtl : initial mul_mat Q4 kernel (wrong results)
This commit is contained in:
parent
64afc0b53a
commit
2a24994bad
3 changed files with 144 additions and 19 deletions
|
@ -38,6 +38,9 @@ struct ggml_mtl_context {
|
|||
|
||||
id<MTLFunction> function_rms_norm;
|
||||
id<MTLComputePipelineState> pipeline_rms_norm;
|
||||
|
||||
id<MTLFunction> function_mul_mat_q4_0;
|
||||
id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
|
||||
};
|
||||
|
||||
// MSL code
|
||||
|
@ -141,6 +144,10 @@ struct ggml_mtl_context * llama_mtl_init(
|
|||
ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
|
||||
ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
|
||||
fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
|
||||
|
||||
ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
|
||||
ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
|
||||
fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
|
||||
}
|
||||
|
||||
// MTLBuffer approach
|
||||
|
@ -317,7 +324,9 @@ int llama_mtl_eval(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
|
||||
// for F32 x F32 we use MPS
|
||||
|
||||
if (encoder != nil) {
|
||||
[encoder endEncoding];
|
||||
encoder = nil;
|
||||
|
@ -354,6 +363,43 @@ int llama_mtl_eval(
|
|||
transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
|
||||
|
||||
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
||||
} else {
|
||||
// for Q4 x F32 we use custom kernel
|
||||
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
|
||||
GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
|
||||
|
||||
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
|
||||
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
|
||||
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
|
||||
|
||||
const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
|
||||
const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
|
||||
|
||||
const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
|
||||
const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
|
||||
|
||||
const int64_t ncols = gf->nodes[i]->ne[0];
|
||||
const int64_t nrows = gf->nodes[i]->ne[1];
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
|
||||
[encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
|
||||
[encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
|
||||
[encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
|
||||
[encoder setBytes:&ncols length:sizeof(ncols) atIndex:7];
|
||||
[encoder setBytes:&nrows length:sizeof(nrows) atIndex:8];
|
||||
|
||||
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
|
|
|
@ -38,8 +38,8 @@ kernel void kernel_add(
|
|||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
dst[gid] = src0[gid] + src1[gid];
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig];
|
||||
}
|
||||
|
||||
// assumption: src1 is a row
|
||||
|
@ -49,15 +49,15 @@ kernel void kernel_mul(
|
|||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
dst[gid] = src0[gid] * src1[gid % ne00];
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src1[tpig % ne00];
|
||||
}
|
||||
|
||||
kernel void kernel_relu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
dst[gid] = max(0.0f, src0[gid]);
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
// TODO: broken
|
||||
|
@ -85,8 +85,8 @@ kernel void kernel_get_rows_q4_0(
|
|||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb1,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
const int i = gid;
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const int i = tpig;
|
||||
const int r = ((device int32_t *) src1)[i];
|
||||
|
||||
dequantize_row_q4_0(
|
||||
|
@ -100,8 +100,8 @@ kernel void kernel_rms_norm(
|
|||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant float & eps,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
|
@ -111,8 +111,84 @@ kernel void kernel_rms_norm(
|
|||
const float mean = sum/ne00;
|
||||
const float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
device float * y = dst + gid*ne00;
|
||||
device float * y = dst + tpig*ne00;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Q4_0 x F32 matrix multiplication.
//
// One threadgroup computes one output element: the dot product of row r0 of
// the quantized matrix src0 with row r1 of the F32 matrix src1. The threads
// of the group split the row's Q4_0 blocks between them (thread t handles
// blocks t, t + nth, t + 2*nth, ...) and their partial sums are combined
// with a tree reduction in threadgroup memory.
//
//   src0 - Q4_0 blocks, ne00/QK4_0 blocks per row
//   src1 - F32 data, ne10 floats per row
//   dst  - F32 output, written at dst[r1*ne0 + r0]
//   ne01, ne11, ne1, tpig - part of the kernel interface, not used here
//
// Requirements carried over from the original implementation:
//   - at most 32 threads per threadgroup (size of the scratch array)
//   - a power-of-two threadgroup size (the halving reduction relies on it)
//   - ne00 is expected to be a multiple of QK4_0; any remainder is ignored
kernel void kernel_mul_mat_q4_0(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    const int qk = QK4_0;
    const int nb = ne00/qk;

    const int tid = (int) tpitg.x; // index of this thread inside the threadgroup
    const int nth = (int) tptg.x;  // number of threads in the threadgroup

    device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
    device const float      * y = (device const float      *) (src1) + r1*ne10;

    threadgroup float sum[32]; // TODO: should be equal to threadgroup size

    // Accumulate in a register and publish to threadgroup memory once.
    float acc = 0.0f;

    // Start at the thread's own index so the blocks are partitioned between
    // the threads instead of every thread redoing blocks 0, nth, 2*nth, ...
    for (int i = tid; i < nb; i += nth) {
        // Read the quants through the qs field, byte by byte: the packed
        // nibbles start after the scale d, and a block is not guaranteed to
        // be aligned for a 16-byte vector load.
        device const uchar * qs = (device const uchar *) (x + i)->qs;
        device const float * yb = y + i*qk;

        const float d = (x + i)->d;

        float block_sum = 0.0f;

        // Low nibble of byte j pairs with yb[j], high nibble with
        // yb[j + qk/2] - the same pairing the original expression used.
        // NOTE(review): assumes this matches the Q4_0 packing used by the
        // quantizer - confirm against the block_q4_0 definition.
        for (int j = 0; j < qk/2; ++j) {
            // Each nibble stores value + 8; widen to int before removing the
            // offset so the result can go negative.
            const int lo = (int) (qs[j] & 0x0F) - 8;
            const int hi = (int) (qs[j] >>   4) - 8;

            block_sum += lo*yb[j] + hi*yb[j + qk/2];
        }

        acc += block_sum*d;
    }

    sum[tid] = acc;

    // accumulate the sum from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int s = nth/2; s > 0; s /= 2) {
        if (tid < s) {
            sum[tid] += sum[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // A single thread stores the reduced result.
    if (tid == 0) {
        dst[r1*ne0 + r0] = sum[0];
    }
}
|
||||
|
|
13
llama.cpp
13
llama.cpp
|
@ -1266,16 +1266,19 @@ static bool llama_eval_internal(
|
|||
|
||||
// cur = cur*attention_norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
|
||||
// TODO: TMP !!!!
|
||||
if (il == 0) {
|
||||
ggml_set_name(cur, "mtl-check");
|
||||
}
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
// TODO: TMP !!!!
|
||||
if (il == 0) {
|
||||
ggml_set_name(x, "mtl-check");
|
||||
}
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue