metal: fix performance degradation from gqa
Integers are slow on the GPU, and 64-bit divides are extremely slow. In the context of GQA, we introduce a 64-bit divide that cannot be optimized out by the compiler, which results in a decrease of ~8% in inference performance. This commit fixes that issue by computing part of the offset with a 32-bit divide. Naturally, this limits the size of a single matrix to ~4 GB; however, this limitation should suffice for the near future.
This commit is contained in:
parent
5f6de2a2bb
commit
bfa455de43
2 changed files with 37 additions and 19 deletions
|
@@ -712,6 +712,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
||||||
|
uint gqa = ne12/ne02;
|
||||||
GGML_ASSERT(ne03 == ne13);
|
GGML_ASSERT(ne03 == ne13);
|
||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
|
@@ -743,6 +744,7 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
||||||
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
}
|
}
|
||||||
|
@@ -845,6 +847,7 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||||
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||||
|
|
|
@@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
||||||
// N_DST, so this is another explicit assumption of the implementation.
|
// N_DST, so this is another explicit assumption of the implementation.
|
||||||
template<typename block_q_type, int nr, int nsg, int nw>
|
template<typename block_q_type, int nr, int nsg, int nw>
|
||||||
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
||||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0,
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
|
||||||
uint3 tgpig, uint tiisg, uint sgitg) {
|
uint3 tgpig, uint tiisg, uint sgitg) {
|
||||||
const int nb = ne00/QK4_0;
|
const int nb = ne00/QK4_0;
|
||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
const int first_row = (r0 * nsg + sgitg) * nr;
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
||||||
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0);
|
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
|
||||||
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
|
||||||
float yl[16]; // src1 vector cache
|
float yl[16]; // src1 vector cache
|
||||||
float sumf[nr]={0.f};
|
float sumf[nr]={0.f};
|
||||||
|
@@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_1_f32(
|
kernel void kernel_mul_mat_q4_1_f32(
|
||||||
|
@@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mat_f16_f32(
|
||||||
|
@@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32(
|
||||||
|
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
||||||
float yl[32];
|
float yl[32];
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
|
@@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -946,11 +952,11 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
const int64_t r2 = tgpig.x;
|
const int64_t r2 = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||||
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
|
||||||
|
|
||||||
float yl[16];
|
float yl[16];
|
||||||
|
@@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -1062,11 +1069,11 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
const int64_t r2 = tgpig.x;
|
const int64_t r2 = tgpig.z;
|
||||||
|
|
||||||
const int row = 2 * r0 + sgitg;
|
const int row = 2 * r0 + sgitg;
|
||||||
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K);
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
|
||||||
const int ix = tiisg/4;
|
const int ix = tiisg/4;
|
||||||
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
||||||
|
@@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
||||||
float yl[16];
|
float yl[16];
|
||||||
float yh[16];
|
float yh[16];
|
||||||
|
@@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
|
||||||
float yl[8];
|
float yl[8];
|
||||||
float yh[8];
|
float yh[8];
|
||||||
|
@@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||||
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
|
||||||
|
|
||||||
float sumf[2]={0.f};
|
float sumf[2]={0.f};
|
||||||
|
@@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32(
|
||||||
constant int64_t & ne10[[buffer(9)]],
|
constant int64_t & ne10[[buffer(9)]],
|
||||||
constant int64_t & ne12[[buffer(11)]],
|
constant int64_t & ne12[[buffer(11)]],
|
||||||
constant int64_t & ne0[[buffer(15)]],
|
constant int64_t & ne0[[buffer(15)]],
|
||||||
|
constant uint & gqa[[buffer(17)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
||||||
const int r2 = tgpig.z;
|
const int r2 = tgpig.z;
|
||||||
|
|
||||||
const int row = 2 * r0 + sgitg;
|
const int row = 2 * r0 + sgitg;
|
||||||
|
const uint offset0 = r2/gqa*(ne02/QK_K);
|
||||||
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K);
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
@@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
|
constant uint & gqa,
|
||||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
@@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||||
}
|
}
|
||||||
|
|
||||||
short il = (tiitg % THREAD_PER_ROW);
|
short il = (tiitg % THREAD_PER_ROW);
|
||||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl;
|
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
|
||||||
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||||
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
||||||
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
|
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
|
||||||
|
|
||||||
|
@@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
||||||
|
|
||||||
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
||||||
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
||||||
constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint, uint);
|
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue