metal: very slightly faster TG for Q5_K

This commit is contained in:
Iwan Kawrakow 2023-09-11 11:07:05 +02:00
parent b42dfdcd89
commit f34783d326
2 changed files with 44 additions and 26 deletions

View file

@ -15,6 +15,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <thread>
#include "ggml.h"
#include "llama.h"
@ -143,6 +144,7 @@ struct cmd_params {
std::vector<bool> low_vram;
std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
int reps;
int sleep;
bool verbose;
output_formats output_format;
};
@ -160,6 +162,7 @@ static const cmd_params cmd_params_defaults = {
/* low_vram */ {false},
/* tensor_split */ {{}},
/* reps */ 5,
/* sleep */ 0,
/* verbose */ false,
/* output_format */ MARKDOWN
};
@ -181,6 +184,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
printf(" -ts, --tensor_split <ts0/ts1/..> \n");
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
printf(" -s, --sleep <n ms> (default: %d)\n", cmd_params_defaults.sleep);
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql");
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf("\n");
@ -305,6 +309,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.reps = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--sleep") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.sleep = std::stoi(argv[i]);
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_param = true;
@ -1003,6 +1013,9 @@ int main(int argc, char ** argv) {
}
uint64_t t_ns = get_time_ns() - t_start;
t.samples_ns.push_back(t_ns);
if (i < params.reps-1 && params.sleep > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(params.sleep));
}
}
p->print_test(t);

View file

@ -1558,7 +1558,7 @@ kernel void kernel_mul_mat_q4_K_f32(
uint16_t sc16[4];
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f};
const float4 norm = {256.f, 16.f, 256.f, 16.f};
for (int ib = ix; ib < nb; ib += 4) {
@ -1596,10 +1596,9 @@ kernel void kernel_mul_mat_q4_K_f32(
}
float dall = dh[0];
float dall = dh[0] / 256.f;
float dmin = dh[1];
acc1 += acc2 / 256.f;
acc1 *= norm;
acc1 = acc1 * norm + acc2;
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
@ -1741,7 +1740,7 @@ kernel void kernel_mul_mat_q5_K_f32(
#if QK_K == 256
#
float yl[16], yh[16];
float4 yl[8];
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
@ -1762,6 +1761,8 @@ kernel void kernel_mul_mat_q5_K_f32(
const uint8_t hm3 = hm1 << 4;
const uint8_t hm4 = hm2 << 4;
const float4 norm = {1/16.f, 1/256.f, 1/16.f, 1/256.f};
uint16_t sc16[4];
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@ -1777,10 +1778,12 @@ kernel void kernel_mul_mat_q5_K_f32(
device const float * y2 = y1 + 128;
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
yl[l] = {y1[l], y1[l+32], y2[l], y2[l+32]};
sumy += yl[l];
//yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
//yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
//yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
//yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
}
for (int row = 0; row < 2; ++row) {
@ -1796,22 +1799,24 @@ kernel void kernel_mul_mat_q5_K_f32(
float4 acc2 = {0.f};
for (int l = 0; l < n; ++l) {
uint8_t h = qh[l];
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
//acc1[0] += yl[l+0] * (q1[l] & 0x0F);
//acc1[1] += yl[l+8] * (q1[l] & 0xF0);
//acc1[2] += yh[l+0] * (q2[l] & 0x0F);
//acc1[3] += yh[l+8] * (q2[l] & 0xF0);
acc1[0] += yl[l][0] * (q1[l] & 0x0F);
acc1[1] += yl[l][1] * (q1[l] & 0xF0);
acc1[2] += yl[l][2] * (q2[l] & 0x0F);
acc1[3] += yl[l][3] * (q2[l] & 0xF0);
acc2[0] += h & hm1 ? yl[l][0] : 0.f;
acc2[1] += h & hm2 ? yl[l][1] : 0.f;
acc2[2] += h & hm3 ? yl[l][2] : 0.f;
acc2[3] += h & hm4 ? yl[l][3] : 0.f;
}
const float dall = dh[0];
const float dmin = dh[1];
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
acc1 = acc1 * norm + acc2;
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) * 16.f
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
q1 += step;
qh += step;
@ -1941,10 +1946,10 @@ kernel void kernel_mul_mat_q6_K_f32(
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) {
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
sums[0] += y[l+ 0] * ((int16_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+32] * ((int16_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+64] * ((int16_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += y[l+96] * ((int16_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);