Extend kernel_mul_mat_f16_f32 to handle GQA broadcast

This commit is contained in:
Matteo Boschini 2023-07-31 14:41:23 +02:00
parent fee39ecd48
commit d28b07ca7c
2 changed files with 17 additions and 71 deletions

View file

@ -65,7 +65,6 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
@ -183,7 +182,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
@ -720,8 +718,7 @@ void ggml_metal_graph_compute(
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64; // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
if (ggml_is_contiguous(src0) && if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) && ggml_is_contiguous(src1) &&
@ -775,15 +772,9 @@ void ggml_metal_graph_compute(
switch (src0t) { switch (src0t) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
nth0 = 64; nth0 = 64;
nth1 = 1; nth1 = 1;
if (llama_2_70_gqa_step) { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
}
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
{ {
@ -860,16 +851,18 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

View file

@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00, constant uint64_t & nb00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant int64_t & ne11, constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10, constant uint64_t & nb10,
constant uint64_t & nb11, constant uint64_t & nb11,
constant uint64_t & nb12, constant uint64_t & nb12,
@ -529,56 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z; const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f;
for (int i = tpitg.x; i < ne00; i += tptg.x) {
sum[tpitg.x] += (float) x[i] * (float) y[i];
}
// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = tptg.x/2; i > 0; i /= 2) {
if (tpitg.x < i) {
sum[tpitg.x] += sum[tpitg.x + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (tpitg.x == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
}
}
kernel void kernel_mul_mat_f16_f32_gqa8(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpig[[thread_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f; sum[tpitg.x] = 0.0f;