metal: minor q4 optimization and reduce code size (#2248)

* metal: use uint16_t instead of uint8_t.

Apple GPU doesn't like uint8_t. For every operation on uint8_t
the gpu need to copy the uint8_t to an empty 16 bit register, then
it can issue other instructions.

For the matrix-vector multiplication kernel only, we observed a
340~350 GB/s memory read speed on M1 Max after this commit, which is
very close to the reported hardware limit.

* metal: update rms_norm kernel

This commit double the speed of rms_norm operations by using 512 threads
per threadgroup, combining with SIMD primitives to minimize the need for
thread group barriers.

* metal: use template to reduce size

Revert modifications on block_q4_0 and block_q4_1.
This commit is contained in:
Shouzheng Liu 2023-07-20 06:32:22 -04:00 committed by GitHub
parent 294f424554
commit 417a85a001
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 164 deletions

View file

@ -792,7 +792,7 @@ void ggml_metal_graph_compute(
const float eps = 1e-6f;
const int nth = 256;
const int nth = 512;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -800,7 +800,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);