From f34783d326d7d8c4ff29b95c7fe3cfd2dc009d6d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 11 Sep 2023 11:07:05 +0200 Subject: [PATCH] metal: very slightly faster TG for Q5_K --- examples/llama-bench/llama-bench.cpp | 13 +++++++ ggml-metal.metal | 57 +++++++++++++++------------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index dedaa34fd..6208daf16 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "ggml.h" #include "llama.h" @@ -143,6 +144,7 @@ struct cmd_params { std::vector low_vram; std::vector> tensor_split; int reps; + int sleep; bool verbose; output_formats output_format; }; @@ -160,6 +162,7 @@ static const cmd_params cmd_params_defaults = { /* low_vram */ {false}, /* tensor_split */ {{}}, /* reps */ 5, + /* sleep */ 0, /* verbose */ false, /* output_format */ MARKDOWN }; @@ -181,6 +184,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str()); printf(" -ts, --tensor_split \n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); + printf(" -s, --sleep (default: %d)\n", cmd_params_defaults.sleep); printf(" -o, --output (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql"); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf("\n"); @@ -305,6 +309,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.reps = std::stoi(argv[i]); + } else if (arg == "-s" || arg == "--sleep") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.sleep = std::stoi(argv[i]); } else if (arg == "-o" || arg == "--output") { if (++i >= argc) { invalid_param = true; @@ -1003,6 +1013,9 @@ int main(int argc, char ** argv) { } uint64_t t_ns = get_time_ns() - t_start; t.samples_ns.push_back(t_ns); + if (i < params.reps-1 && params.sleep > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(params.sleep)); + } } p->print_test(t); diff --git a/ggml-metal.metal b/ggml-metal.metal index a7f9c9b6d..d5f72dba0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1558,7 +1558,7 @@ kernel void kernel_mul_mat_q4_K_f32( uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f}; + const float4 norm = {256.f, 16.f, 256.f, 16.f}; for (int ib = ix; ib < nb; ib += 4) { @@ -1596,10 +1596,9 @@ kernel void kernel_mul_mat_q4_K_f32( } - float dall = dh[0]; + float dall = dh[0] / 256.f; float dmin = dh[1]; - acc1 += acc2 / 256.f; - acc1 *= norm; + acc1 = acc1 * norm + acc2; sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); @@ -1741,7 +1740,7 @@ kernel void kernel_mul_mat_q5_K_f32( #if QK_K == 256 # - float yl[16], yh[16]; + float4 yl[8]; const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -1762,6 +1761,8 @@ kernel void kernel_mul_mat_q5_K_f32( const uint8_t hm3 = hm1 << 4; const uint8_t hm4 = hm2 << 4; + const float4 norm = {1/16.f, 1/256.f, 1/16.f, 1/256.f}; + uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; @@ -1777,10 +1778,12 @@ kernel void kernel_mul_mat_q5_K_f32( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + yl[l] = {y1[l], y1[l+32], y2[l], y2[l+32]}; + sumy += yl[l]; + //yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + //yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + //yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + //yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } for (int row = 0; row < 2; ++row) { @@ -1796,22 +1799,24 @@ kernel void kernel_mul_mat_q5_K_f32( float4 acc2 = {0.f}; for (int l = 0; l < n; ++l) { uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; + //acc1[0] += yl[l+0] * (q1[l] & 0x0F); + //acc1[1] += yl[l+8] * (q1[l] & 0xF0); + //acc1[2] += yh[l+0] * (q2[l] & 0x0F); + //acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc1[0] += yl[l][0] * (q1[l] & 0x0F); + acc1[1] += yl[l][1] * (q1[l] & 0xF0); + acc1[2] += yl[l][2] * (q2[l] & 0x0F); + acc1[3] += yl[l][3] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l][0] : 0.f; + acc2[1] += h & hm2 ? yl[l][1] : 0.f; + acc2[2] += h & hm3 ? yl[l][2] : 0.f; + acc2[3] += h & hm4 ? yl[l][3] : 0.f; } const float dall = dh[0]; const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + acc1 = acc1 * norm + acc2; + sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) * 16.f + - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += step; qh += step; @@ -1941,10 +1946,10 @@ kernel void kernel_mul_mat_q6_K_f32( float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int16_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int16_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int16_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int16_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);