metal: very slightly faster TG for Q5_K
This commit is contained in:
parent
b42dfdcd89
commit
f34783d326
2 changed files with 44 additions and 26 deletions
|
@ -15,6 +15,7 @@
|
|||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
|
@ -143,6 +144,7 @@ struct cmd_params {
|
|||
std::vector<bool> low_vram;
|
||||
std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
|
||||
int reps;
|
||||
int sleep;
|
||||
bool verbose;
|
||||
output_formats output_format;
|
||||
};
|
||||
|
@ -160,6 +162,7 @@ static const cmd_params cmd_params_defaults = {
|
|||
/* low_vram */ {false},
|
||||
/* tensor_split */ {{}},
|
||||
/* reps */ 5,
|
||||
/* sleep */ 0,
|
||||
/* verbose */ false,
|
||||
/* output_format */ MARKDOWN
|
||||
};
|
||||
|
@ -181,6 +184,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
|||
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
|
||||
printf(" -ts, --tensor_split <ts0/ts1/..> \n");
|
||||
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
|
||||
printf(" -s, --sleep <n ms> (default: %d)\n", cmd_params_defaults.sleep);
|
||||
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", cmd_params_defaults.output_format == CSV ? "csv" : cmd_params_defaults.output_format == JSON ? "json" : cmd_params_defaults.output_format == MARKDOWN ? "md" : "sql");
|
||||
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
|
||||
printf("\n");
|
||||
|
@ -305,6 +309,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
params.reps = std::stoi(argv[i]);
|
||||
} else if (arg == "-s" || arg == "--sleep") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.sleep = std::stoi(argv[i]);
|
||||
} else if (arg == "-o" || arg == "--output") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -1003,6 +1013,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
uint64_t t_ns = get_time_ns() - t_start;
|
||||
t.samples_ns.push_back(t_ns);
|
||||
if (i < params.reps-1 && params.sleep > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(params.sleep));
|
||||
}
|
||||
}
|
||||
|
||||
p->print_test(t);
|
||||
|
|
|
@ -1558,7 +1558,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||
uint16_t sc16[4];
|
||||
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
||||
|
||||
const float4 norm = {1.f, 1.f/16.f, 1.f, 1.f/16.f};
|
||||
const float4 norm = {256.f, 16.f, 256.f, 16.f};
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
|
@ -1596,10 +1596,9 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|||
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dall = dh[0] / 256.f;
|
||||
float dmin = dh[1];
|
||||
acc1 += acc2 / 256.f;
|
||||
acc1 *= norm;
|
||||
acc1 = acc1 * norm + acc2;
|
||||
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
|
@ -1741,7 +1740,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||
|
||||
#if QK_K == 256
|
||||
#
|
||||
float yl[16], yh[16];
|
||||
float4 yl[8];
|
||||
|
||||
const uint16_t kmask1 = 0x3f3f;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
|
@ -1762,6 +1761,8 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||
const uint8_t hm3 = hm1 << 4;
|
||||
const uint8_t hm4 = hm2 << 4;
|
||||
|
||||
const float4 norm = {1/16.f, 1/256.f, 1/16.f, 1/256.f};
|
||||
|
||||
uint16_t sc16[4];
|
||||
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
||||
|
||||
|
@ -1777,10 +1778,12 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||
device const float * y2 = y1 + 128;
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
||||
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
||||
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
||||
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
||||
yl[l] = {y1[l], y1[l+32], y2[l], y2[l+32]};
|
||||
sumy += yl[l];
|
||||
//yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
||||
//yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
||||
//yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
||||
//yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
||||
}
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
@ -1796,22 +1799,24 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||
float4 acc2 = {0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
uint8_t h = qh[l];
|
||||
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
||||
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
||||
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
|
||||
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
|
||||
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
||||
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
||||
//acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||
//acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||
//acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
||||
//acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
||||
acc1[0] += yl[l][0] * (q1[l] & 0x0F);
|
||||
acc1[1] += yl[l][1] * (q1[l] & 0xF0);
|
||||
acc1[2] += yl[l][2] * (q2[l] & 0x0F);
|
||||
acc1[3] += yl[l][3] * (q2[l] & 0xF0);
|
||||
acc2[0] += h & hm1 ? yl[l][0] : 0.f;
|
||||
acc2[1] += h & hm2 ? yl[l][1] : 0.f;
|
||||
acc2[2] += h & hm3 ? yl[l][2] : 0.f;
|
||||
acc2[3] += h & hm4 ? yl[l][3] : 0.f;
|
||||
}
|
||||
const float dall = dh[0];
|
||||
const float dmin = dh[1];
|
||||
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
acc1 = acc1 * norm + acc2;
|
||||
sumf[row] += dall * (acc1[0] * sc8[0] + acc1[1] * sc8[1] + acc1[2] * sc8[4] + acc1[3] * sc8[5]) * 16.f
|
||||
- dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
q1 += step;
|
||||
qh += step;
|
||||
|
@ -1941,10 +1946,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
sums[0] += y[l+ 0] * ((int16_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int16_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int16_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int16_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
}
|
||||
|
||||
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue