mtl : fix bug in f16 x f32 mul mat + speed-up computation
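
The f16 x f32 mat-vec kernel assumed packed rows of half values, so
non-contiguous and batched (ne02 > 1) sources were read at the wrong
offsets. The kernel now receives the byte strides (nb00..nb12), addresses
rows through them, and is dispatched with one threadgroup per
(row, column, batch). The Q4_0 kernel additionally spreads the 16 nibble
pairs of each block across a second threadgroup dimension, and the fast
matrix x matrix path is taken only for contiguous operands.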

commit 33671460b0
parent e55f7b0bdb
Author: Georgi Gerganov
Date:   2023-06-02 18:23:51 +03:00
4 changed files with 79 additions and 37 deletions


@@ -480,41 +480,41 @@ int llama_mtl_eval(
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
-//const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
-//const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
const int64_t ne10 = gf->nodes[i]->src1->ne[0];
const int64_t ne11 = gf->nodes[i]->src1->ne[1];
const int64_t ne12 = gf->nodes[i]->src1->ne[2];
-//const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
-//const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
+const uint64_t nb10 = gf->nodes[i]->src1->nb[0];
+const uint64_t nb11 = gf->nodes[i]->src1->nb[1];
const uint64_t nb12 = gf->nodes[i]->src1->nb[2];
const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2];
-//const uint64_t nb0 = gf->nodes[i]->nb[0];
-//const uint64_t nb1 = gf->nodes[i]->nb[1];
+const uint64_t nb0 = gf->nodes[i]->nb[0];
+const uint64_t nb1 = gf->nodes[i]->nb[1];
const uint64_t nb2 = gf->nodes[i]->nb[2];
-const int nth = 16;
const enum ggml_type src0t = gf->nodes[i]->src0->type;
const enum ggml_type src1t = gf->nodes[i]->src1->type;
const enum ggml_type dstt = gf->nodes[i]->type;
fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld]\n", ggml_type_name(src0t), ne00, ne01, ne02);
fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld]\n", ggml_type_name(src1t), ne10, ne11, ne12);
fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12);
-if ((src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+if (ggml_is_contiguous(gf->nodes[i]->src0) &&
+    ggml_is_contiguous(gf->nodes[i]->src1) &&
+    (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
@@ -555,25 +555,52 @@ int llama_mtl_eval(
encoder = [command_buffer computeCommandEncoder];
}
+int nth = 32;
// use custom matrix x vector kernel
switch (src0t) {
-case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; break;
-case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break;
+case GGML_TYPE_Q4_0:
+    {
+        GGML_ASSERT(ne02 == 1);
+        GGML_ASSERT(ne12 == 1);
+        nth = 4;
+        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+    } break;
+case GGML_TYPE_F16:
+    {
+        GGML_ASSERT(ne02 == ne12);
+        nth = 32;
+        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+    } break;
default: GGML_ASSERT(false && "not implemented");
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
if (src0t == GGML_TYPE_Q4_0) {
[encoder setThreadgroupMemoryLength:16*nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth, 16, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
}
} break;
case GGML_OP_GET_ROWS:

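For orientation, a small host-side sketch of the two dispatch geometries the
new switch sets up (plain C; the shapes and the names nth_q4 / nth_f16 are
made up for illustration, not part of the commit):

#include <stdio.h>

int main(void) {
    const long ne01 = 4096, ne11 = 8, ne12 = 1; // hypothetical tensor dims

    // Q4_0 path: nth = 4 threads along x, 16 along y (one per nibble pair),
    // so 16*nth partial sums are staged in threadgroup memory.
    const int nth_q4 = 4;
    printf("q4_0: grid=(%ld,%ld,1) tptg=(%d,16,1) shared=%zu bytes\n",
           ne01, ne11, nth_q4, 16*nth_q4*sizeof(float));

    // F16 path: nth = 32, one threadgroup per (row, column, batch) triple,
    // nth partial sums in threadgroup memory.
    const int nth_f16 = 32;
    printf("f16 : grid=(%ld,%ld,%ld) tptg=(%d,1,1) shared=%zu bytes\n",
           ne01, ne11, ne12, nth_f16, nth_f16*sizeof(float));
    return 0;
}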

@@ -265,7 +265,10 @@ kernel void kernel_mul_mat_q4_0_f32(
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10;
-sum[tpitg.x] = 0.0f;
+const uint nth = tptg.x*tptg.y;
+const uint ith = 16*tpitg.x + tpitg.y;
+sum[ith] = 0.0f;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uchar * x0p = (device const uchar *) (x + i)->qs;
@@ -273,7 +276,9 @@ kernel void kernel_mul_mat_q4_0_f32(
float acc = 0.0f;
-for (int j = 0; j < 16; ++j) {
+//for (int j = 0; j < 16; ++j) {
+const int j = tpitg.y;
+{
const uchar x0v = *(x0p + j);
const int x0 = x0v & 0x0F;
@@ -285,43 +290,50 @@ kernel void kernel_mul_mat_q4_0_f32(
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
-sum[tpitg.x] += acc * (x + i)->d;
+sum[ith] += acc * (x + i)->d;
}
// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
-for (uint i = tptg.x/2; i > 0; i /= 2) {
-if (tpitg.x < i) {
-sum[tpitg.x] += sum[tpitg.x + i];
+for (uint i = nth/2; i > 0; i /= 2) {
+if (ith < i) {
+sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
-if (tpitg.x == 0) {
+if (ith == 0) {
dst[r1*ne0 + r0] = sum[0];
}
}
kernel void kernel_mul_mat_f16_f32(
-device const half * src0,
-device const float * src1,
+device const char * src0,
+device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
-uint2 tgpig[[threadgroup_position_in_grid]],
-uint2 tpig[[thread_position_in_grid]],
-uint2 tpitg[[thread_position_in_threadgroup]],
-uint2 tptg[[threads_per_threadgroup]]) {
+uint3 tgpig[[threadgroup_position_in_grid]],
+uint3 tpig[[thread_position_in_grid]],
+uint3 tpitg[[thread_position_in_threadgroup]],
+uint3 tptg[[threads_per_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+const int64_t im = tgpig.z;
-device const half * x = src0 + r0*ne00;
-device const float * y = src1 + r1*ne10;
+device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
+device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f;
@@ -339,7 +351,7 @@ kernel void kernel_mul_mat_f16_f32(
}
if (tpitg.x == 0) {
-dst[r1*ne0 + r0] = sum[0];
+dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
}
}

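The core of the f16 fix is the pointer arithmetic above: rows are located
via byte strides instead of an assumed packed half[ne00] layout. A minimal
C sketch of the same addressing (the helper name is made up for
illustration):

#include <stdint.h>

// Layout-agnostic addressing as the kernel now does it: row r of batch
// slice im sits nb1 bytes per row and nb2 bytes per slice from the base.
static const void * row_ptr(const void * base, int64_t r, int64_t im,
                            uint64_t nb1, uint64_t nb2) {
    return (const char *) base + r*nb1 + im*nb2;
}
// The old code computed (device const half *) base + r*ne00 instead, which
// is only correct when rows are packed back-to-back and there is one slice.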
ggml.c

@@ -3821,11 +3821,11 @@ size_t ggml_tensor_overhead(void) {
return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
}
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
return tensor->nb[0] > tensor->nb[1];
}
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return

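ggml_is_transposed and ggml_is_contiguous lose static inline here so the
Metal host code above can call them. As a rough illustration of what the
contiguity test amounts to (simplified stand-alone version; the real ggml
check also accounts for the quantization block size in the row stride):

#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>

struct tensor4 {      // simplified stand-in for struct ggml_tensor
    int64_t  ne[4];   // elements per dimension
    uint64_t nb[4];   // byte stride per dimension
    size_t   elsz;    // bytes per element
};

// Contiguous means each byte stride is the previous stride times the
// previous dimension, with nb[0] equal to the element size.
static bool is_contiguous_sketch(const struct tensor4 * t) {
    return t->nb[0] == t->elsz           &&
           t->nb[1] == t->nb[0]*t->ne[0] &&
           t->nb[2] == t->nb[1]*t->ne[1] &&
           t->nb[3] == t->nb[2]*t->ne[2];
}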
ggml.h

@@ -442,6 +442,9 @@ extern "C" {
// TODO: temporary until model loading of ggml examples is refactored
GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
+GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
// use this to compute the memory overhead of a tensor
GGML_API size_t ggml_tensor_overhead(void);
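
With the declarations public, callers outside ggml.c can apply the same
layout guard the Metal path uses above; a hypothetical caller-side helper:

#include <stdbool.h>
#include "ggml.h"

// Hypothetical helper: true when t->data may be treated as one packed,
// row-major block -- the precondition for the fast mat x mat kernel.
static bool use_packed_path(const struct ggml_tensor * t) {
    return ggml_is_contiguous(t) && !ggml_is_transposed(t);
}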