mtl : initial mul_mat Q4 kernel (wrong results)

This commit is contained in:
Georgi Gerganov 2023-05-30 22:02:54 +03:00
parent 64afc0b53a
commit 2a24994bad
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 144 additions and 19 deletions

View file

@ -38,6 +38,9 @@ struct ggml_mtl_context {
id<MTLFunction> function_rms_norm;
id<MTLComputePipelineState> pipeline_rms_norm;
id<MTLFunction> function_mul_mat_q4_0;
id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
};
// MSL code
@ -141,6 +144,10 @@ struct ggml_mtl_context * llama_mtl_init(
ctx->function_rms_norm = [ctx->library newFunctionWithName:@"kernel_rms_norm"];
ctx->pipeline_rms_norm = [ctx->device newComputePipelineStateWithFunction:ctx->function_rms_norm error:nil];
fprintf(stderr, "%s: loaded kernel_rms_norm: %p\n", __func__, (void *) ctx->pipeline_rms_norm);
ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
}
// MTLBuffer approach
@ -317,7 +324,9 @@ int llama_mtl_eval(
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_MUL_MAT:
{
if (gf->nodes[i]->src0->type == GGML_TYPE_F32) {
// for F32 x F32 we use MPS
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
@ -354,6 +363,43 @@ int llama_mtl_eval(
transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
} else {
// for Q4 x F32 we use custom kernel
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
GGML_ASSERT(gf->nodes[i]->src0->ne[2] == 1);
GGML_ASSERT(gf->nodes[i]->src1->ne[2] == 1);
id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
const int64_t ncols = gf->nodes[i]->ne[0];
const int64_t nrows = gf->nodes[i]->ne[1];
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3];
[encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4];
[encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5];
[encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6];
[encoder setBytes:&ncols length:sizeof(ncols) atIndex:7];
[encoder setBytes:&nrows length:sizeof(nrows) atIndex:8];
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows);
[encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_GET_ROWS:
{

View file

@ -7,8 +7,8 @@ using namespace metal;
#define QK4_0 32
#define QR4_0 2
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@ -38,8 +38,8 @@ kernel void kernel_add(
device const float * src0,
device const float * src1,
device float * dst,
uint gid[[thread_position_in_grid]]) {
dst[gid] = src0[gid] + src1[gid];
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig];
}
// assumption: src1 is a row
@ -49,15 +49,15 @@ kernel void kernel_mul(
device const float * src1,
device float * dst,
constant int64_t & ne00,
uint gid[[thread_position_in_grid]]) {
dst[gid] = src0[gid] * src1[gid % ne00];
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % ne00];
}
kernel void kernel_relu(
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
dst[gid] = max(0.0f, src0[gid]);
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = max(0.0f, src0[tpig]);
}
// TODO: broken
@ -85,8 +85,8 @@ kernel void kernel_get_rows_q4_0(
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb1,
uint gid[[thread_position_in_grid]]) {
const int i = gid;
uint tpig[[thread_position_in_grid]]) {
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
dequantize_row_q4_0(
@ -100,8 +100,8 @@ kernel void kernel_rms_norm(
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
uint gid[[thread_position_in_grid]]) {
device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
uint tpig[[thread_position_in_grid]]) {
device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
float sum = 0.0f;
for (int i00 = 0; i00 < ne00; i00++) {
@ -111,8 +111,84 @@ kernel void kernel_rms_norm(
const float mean = sum/ne00;
const float scale = 1.0f/sqrt(mean + eps);
device float * y = dst + gid*ne00;
device float * y = dst + tpig*ne00;
for (int i00 = 0; i00 < ne00; i00++) {
y[i00] = x[i00] * scale;
}
}
// Q4_0 x F32 matrix multiplication: dst[r1*ne0 + r0] = dot(row r0 of src0, row r1 of src1).
//
// One threadgroup computes one output element. The host dispatches a grid of
// (src0 rows, src1 rows, 1) threadgroups, so tgpig.x selects the quantized row
// and tgpig.y selects the float row. The threads of the group split the
// quantized blocks of the row between them and then tree-reduce their partial
// sums through threadgroup memory.
//
// Parameters:
//   src0       - quantized matrix, read as consecutive block_q4_0 (ne00/QK4_0 blocks per row)
//   src1       - float matrix, ne10 floats per row
//   dst        - float output, ne0 floats per row
//   ne00, ne01 - columns / rows of src0 (ne01 is not read by the kernel)
//   ne10, ne11 - columns / rows of src1 (ne11 is not read by the kernel)
//   ne0,  ne1  - columns / rows of dst  (ne1  is not read by the kernel)
//
// Quant layout used here (same pairing as the original unrolled expression:
// low nibbles with y[0..15], high nibbles with y[16..31]):
//   value[j]           = ((qs[j] & 0x0F) - 8) * d
//   value[j + QK4_0/2] = ((qs[j] >>   4) - 8) * d
// NOTE(review): confirm this matches dequantize_row_q4_0 in this file.
kernel void kernel_mul_mat_q4_0(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2  tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {
    const int64_t r0 = tgpig.x; // row of src0 -> column of dst
    const int64_t r1 = tgpig.y; // row of src1 -> row of dst

    const int qk = QK4_0;
    const int nb = ne00/qk;     // number of quantized blocks per src0 row

    device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
    device const float      * y = (device const float      *) (src1) + r1*ne10;

    // TODO: should be equal to threadgroup size (the host currently dispatches 32 threads)
    threadgroup float sum[32];
    sum[tpitg.x] = 0.0f;

    // Each thread handles blocks tpitg.x, tpitg.x + tptg.x, ... so that every
    // block is accumulated exactly once across the threadgroup.
    for (int i = tpitg.x; i < nb; i += tptg.x) {
        // Read the quants byte by byte from the qs field: block_q4_0 starts
        // with the half-precision scale, so the block address is neither the
        // start of the quants nor suitably aligned for a wide vector load.
        device const uint8_t * qs = x[i].qs;
        device const float   * yb = y + i*qk;

        const float d = (float) x[i].d;

        float acc = 0.0f;
        for (int j = 0; j < qk/2; ++j) {
            // Decode each nibble separately and remove the offset of 8 per
            // value, keeping the result signed.
            const int lo = (qs[j] & 0x0F) - 8;
            const int hi = (qs[j] >>   4) - 8;

            acc += lo*yb[j] + hi*yb[j + qk/2];
        }

        sum[tpitg.x] += acc*d;
    }

    // accumulate the sum from all threads in the threadgroup
    // (tree reduction; assumes the threadgroup width is a power of two)
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = tptg.x/2; i > 0; i /= 2) {
        if (tpitg.x < i) {
            sum[tpitg.x] += sum[tpitg.x + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // A single thread publishes the reduced result.
    if (tpitg.x == 0) {
        dst[r1*ne0 + r0] = sum[0];
    }
}

View file

@ -1266,16 +1266,19 @@ static bool llama_eval_internal(
// cur = cur*attention_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(cur, "mtl-check");
}
}
// self-attention
{
auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(x, "mtl-check");
}
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");