Extend kernel_mul_mat_f16_f32 to handle gqa broadcast
This commit is contained in:
parent
fee39ecd48
commit
d28b07ca7c
2 changed files with 17 additions and 71 deletions
35
ggml-metal.m
35
ggml-metal.m
|
@ -65,7 +65,6 @@ struct ggml_metal_context {
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
|
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
||||||
|
@ -183,7 +182,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
|
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
||||||
|
@ -720,8 +718,7 @@ void ggml_metal_graph_compute(
|
||||||
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
||||||
|
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64;
|
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
||||||
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
|
|
||||||
|
|
||||||
if (ggml_is_contiguous(src0) &&
|
if (ggml_is_contiguous(src0) &&
|
||||||
ggml_is_contiguous(src1) &&
|
ggml_is_contiguous(src1) &&
|
||||||
|
@ -775,15 +772,9 @@ void ggml_metal_graph_compute(
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
|
|
||||||
|
|
||||||
nth0 = 64;
|
nth0 = 64;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (llama_2_70_gqa_step) {
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
|
|
||||||
} else {
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
||||||
}
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
|
@ -860,16 +851,18 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
||||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
||||||
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||||
|
|
|
@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
constant uint64_t & nb00,
|
constant uint64_t & nb00,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
constant uint64_t & nb10,
|
constant uint64_t & nb10,
|
||||||
constant uint64_t & nb11,
|
constant uint64_t & nb11,
|
||||||
constant uint64_t & nb12,
|
constant uint64_t & nb12,
|
||||||
|
@ -529,56 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
const int64_t im = tgpig.z;
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
||||||
|
|
||||||
sum[tpitg.x] = 0.0f;
|
|
||||||
|
|
||||||
for (int i = tpitg.x; i < ne00; i += tptg.x) {
|
|
||||||
sum[tpitg.x] += (float) x[i] * (float) y[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// accumulate the sum from all threads in the threadgroup
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = tptg.x/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg.x < i) {
|
|
||||||
sum[tpitg.x] += sum[tpitg.x + i];
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tpitg.x == 0) {
|
|
||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32_gqa8(
|
|
||||||
device const char * src0,
|
|
||||||
device const char * src1,
|
|
||||||
device float * dst,
|
|
||||||
constant int64_t & ne00,
|
|
||||||
constant int64_t & ne01,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
|
||||||
constant uint64_t & nb02,
|
|
||||||
constant int64_t & ne10,
|
|
||||||
constant int64_t & ne11,
|
|
||||||
constant uint64_t & nb10,
|
|
||||||
constant uint64_t & nb11,
|
|
||||||
constant uint64_t & nb12,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
constant int64_t & ne1,
|
|
||||||
threadgroup float * sum [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
uint3 tpig[[thread_position_in_grid]],
|
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
uint3 tptg[[threads_per_threadgroup]]) {
|
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
|
||||||
const int64_t r1 = tgpig.y;
|
|
||||||
const int64_t im = tgpig.z;
|
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02);
|
|
||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
sum[tpitg.x] = 0.0f;
|
sum[tpitg.x] = 0.0f;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue