Ported mat mul from Metal

niansa 2023-07-05 13:28:40 +02:00
parent 2fc8249ba3
commit 6be93e6071


@@ -173,7 +173,7 @@ static std::vector<uint32_t> glsl_compile_source(const std::string& source, cons
     std::ofstream fileOut("tmp_kp_shader.comp");
     fileOut << source;
     fileOut.close();
-    if (system(std::string("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null").c_str()))
+    if (system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv > /dev/null"))
         throw std::runtime_error("Error running glslangValidator command");
     std::ifstream fileStream("tmp_kp_shader.comp.spv", std::ios::binary);
     std::vector<char> buffer;
@@ -883,131 +883,59 @@ void ggml_vk_diag_mask_inf(kp::Sequence& seq,
 static const std::string program_mul_mat_f16 =
         MULTILINE_QUOTE(
-layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 64) in;
 layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
 layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
 layout (push_constant) uniform parameter {
-    int M;
-    int N;
-    int K;
-    int inAStride;
-    int inBStride;
-    int outStride;
+    int64_t ne00;
+    int64_t ne01;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t ne10;
+    int64_t ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int64_t ne0;
+    int64_t ne1;
     uint inAOff;
     uint inBOff;
     uint outOff;
 } pcs;
-shared float16_t bufA[BM * (BK+1)];
-shared float16_t bufB[BN * (BK+1)];
+shared float sum[gl_WorkGroupSize.x];
 void main() {
-    const int ir = int(gl_WorkGroupID.x);
-    const int ic = int(gl_WorkGroupID.y);
-    const int rstride = BM / TM;
-    const int lr = int(gl_LocalInvocationID.x % rstride);
-    const int lc = int(gl_LocalInvocationID.x / rstride);
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-    const int loadstride = int(gl_WorkGroupSize.x);
-    int posA = ir * BM * pcs.inAStride;
-    int posB = ic * BN * pcs.inBStride;
-    float sums[TM * TN];
-    float16_t cacheA[TM];
-    float16_t cacheB[TN];
-    [[unroll]] for (int i = 0; i < TM*TN; i++) {
-        sums[i] = 0.0hf;
-    }
-    [[unroll]] for (int block = 0; block < pcs.K; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr];
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr];
-        }
-        barrier();
-        memoryBarrierShared();
-        posA += BK;
-        posB += BK;
-        [[unroll]] for (int i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (int j = 0; j < BM; j++) {
-                cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
-            }
-            [[unroll]] for (int j = 0; j < TN; j++) {
-                cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
-            }
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    sums[cc * TM + cr] += float(cacheA[cr]) * float(cacheB[cc]);
-                }
-            }
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-    const int dr = ir * BM + lr;
-    const int dc = ic * BN + lc * TN;
-    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-            out_[(dc + cc) * pcs.outStride + dr + cr*rstride] = sums[cc * TM + cr];
-        }
-    }
-}
-);
-static const std::string program_fast_mul_mat_f16 =
-        MULTILINE_QUOTE(
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-layout (push_constant) uniform parameter {
-    int M;
-    int N;
-    int K;
-    int inAStride;
-    int inBStride;
-    int outStride;
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-} pcs;
-void main() {
-    int row = int(gl_GlobalInvocationID.x);
-    int col = int(gl_GlobalInvocationID.y);
-    if (row < pcs.M && col < pcs.N) {
-        float sum = 0.0f;
-        for (int i = 0; i < pcs.K; i++) {
-            sum += float(inA[row * pcs.inAStride + i]) * float(inB[col * pcs.inBStride + i]);
-        }
-        out_[col * pcs.outStride + row] = sum;
-    }
-}
+    const int64_t r0 = gl_GlobalInvocationID.x;
+    const int64_t r1 = gl_GlobalInvocationID.y;
+    const int64_t im = gl_GlobalInvocationID.z;
+    const uint x = uint((r0*pcs.nb01 + im*pcs.nb02) / 2); // Based from inA
+    const uint y = uint((r1*pcs.nb11 + im*pcs.nb12) / 4); // based from inB
+    sum[gl_WorkGroupID.x] = 0.0f;
+    for (uint i = gl_WorkGroupID.x; i < pcs.ne00; i += gl_WorkGroupSize.x) {
+        sum[gl_WorkGroupID.x] += float(inA[x+i]) * float(inB[y+i]);
+    }
+    // accumulate the sum from all threads in the threadgroup
+    barrier();
+    for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+        if (gl_WorkGroupID.x < i) {
+            sum[gl_WorkGroupID.x] += sum[gl_WorkGroupID.x + i];
+        }
+        barrier();
+    }
+    if (gl_WorkGroupID.x == 0) {
+        out_[uint(im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0)] = sum[0];
+    }
+}
 );
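
Editor's note (not part of the diff): following the Metal kernel this shader ports, each element of dst is meant to be the dot product over ne00 of an fp16 row of src0 (inA) and an fp32 row of src1 (inB), with nb01/nb02 and nb11/nb12 interpreted as byte strides in the usual ggml way; the workgroup splits that inner loop across its invocations and accumulates the partial sums in shared memory. A minimal scalar sketch of the intended computation, assuming ggml's ggml_fp16_to_fp32 helper (the pointer arithmetic mirrors the shader's "/ 2" and "/ 4" index scaling):

    // Reference sketch only; parameter names follow the push constants above.
    #include <cstdint>

    extern "C" float ggml_fp16_to_fp32(uint16_t x); // assumed ggml half->float helper

    static void mul_mat_f16_ref(const uint16_t * inA, const float * inB, float * out_,
                                int64_t ne00, int64_t ne01, int64_t ne11, int64_t ne12,
                                uint64_t nb01, uint64_t nb02,  // byte strides of inA (src0)
                                uint64_t nb11, uint64_t nb12,  // byte strides of inB (src1)
                                int64_t ne0, int64_t ne1) {
        for (int64_t im = 0; im < ne12; im++)       // batch index        (z in the shader)
        for (int64_t r1 = 0; r1 < ne11; r1++)       // row of inB         (y in the shader)
        for (int64_t r0 = 0; r0 < ne01; r0++) {     // row of inA         (x in the shader)
            const uint16_t * a = inA + (r0*nb01 + im*nb02) / sizeof(uint16_t); // "/ 2" above
            const float    * b = inB + (r1*nb11 + im*nb12) / sizeof(float);    // "/ 4" above
            float sum = 0.0f;
            for (int64_t i = 0; i < ne00; i++) {    // split across the 64 invocations on the GPU
                sum += ggml_fp16_to_fp32(a[i]) * b[i];
            }
            out_[im*ne1*ne0 + r1*ne0 + r0] = sum;   // same flat index as the shader's out_ write
        }
    }
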
@@ -1016,48 +944,26 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
                          const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
                          const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
                          const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
-                         int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
-                         int64_t ne10, int64_t ne11,
-                         int nb10, int nb11, int nb12, int nb13,
-                         int nb2, int nb3) {
-    const static auto spirv = glsl_compile_source(program_source_head+program_fast_mul_mat_f16, __func__);
-    const bool inB_cont_rows = nb10 == sizeof(float);
-    const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float);
+                         int64_t ne00, int64_t ne01,
+                         uint64_t nb00, uint64_t nb01, uint64_t nb02,
+                         int64_t ne10, int64_t ne11, int64_t ne12,
+                         uint64_t nb10, uint64_t nb11, uint64_t nb12,
+                         int64_t ne0, int64_t ne1) {
+    const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
     struct PushConstants {
-        int32_t M, N, K, inAStride, inBStride, outStride;
+        int64_t ne00, ne01;
+        uint64_t nb00, nb01, nb02;
+        int64_t ne10, ne11;
+        uint64_t nb10, nb11, nb12;
+        int64_t ne0, ne1;
         uint32_t inAOff, inBOff, outOff;
     } pushConsts {
-        (int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
+        ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, ne0, ne1,
         inAOff, inBOff, outOff
     };
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            auto tmp = mgr.tensorT<half>(std::vector<half>(ne10*ne11));
-            if (inB_cont_rows) {
-                if (inB_cont_cols) {
-                    ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12)/sizeof(float), tmp, 0, ne10*ne11);
-                }
-                else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11)/sizeof(float), tmp, i01*ne10, ne10);
-                    }
-                }
-            } else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
-                        // Extremely slow because of single shader invocation
-                        ggml_vk_fp32_to_fp16_row(seq, inB, (i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)/sizeof(float), tmp, i01*ne10 + i00, 1);
-                    }
-                }
-            }
-            seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, tmp, out}, spirv, {uint32_t(ne01/128), uint32_t(ne11/128)}, {}, {pushConsts}));
-        }
-    }
+    seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne11), unsigned(ne12)}, {}, {pushConsts}));
 }
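
Editor's note (not part of the diff): the host side now records a single OpAlgoDispatch with a {ne01, ne11, ne12} workgroup grid, one 64-invocation workgroup per (r0, r1, im) output triple, instead of looping over i02/i03 and staging an fp16 copy of inB, since the shader now reads inB as float directly. The shared-memory accumulation at the end of the shader follows the standard power-of-two tree-reduction pattern of the Metal kernel; a plain C++ sketch of that pattern (hypothetical names, local_size standing in for gl_WorkGroupSize.x):

    #include <cstddef>
    #include <vector>

    // Emulates the shader's reduction: each invocation holds one partial sum;
    // after log2(local_size) halving steps the total ends up in partial[0].
    static float workgroup_tree_reduce(std::vector<float> partial) {
        const std::size_t local_size = partial.size();   // 64 in the shader; must be a power of two
        for (std::size_t stride = local_size / 2; stride > 0; stride /= 2) {
            for (std::size_t tid = 0; tid < stride; tid++) { // lower half adds the upper half
                partial[tid] += partial[tid + stride];
            }
            // barrier() in GLSL: all invocations see the updated values before the next step
        }
        return partial[0];
    }
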
@@ -1179,7 +1085,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 {
                     if (src0->type == GGML_TYPE_F16
                         && src1->type == GGML_TYPE_F32) {
-                        ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
+                        ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1);
                         break;
                     }
                 }