Don't force aligned matmul
This commit is contained in:
parent
105fd199be
commit
9e97cb0baf
6 changed files with 402 additions and 90 deletions
2
Makefile
2
Makefile
|
@ -233,7 +233,9 @@ ifdef LLAMA_VULKAN
|
||||||
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv
|
||||||
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32_aligned.glsl -o vk_shaders/matmul_f32_aligned.spv
|
||||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv
|
||||||
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_aligned.glsl -o vk_shaders/matmul_f16_aligned.spv
|
||||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv
|
||||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv
|
||||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv
|
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv
|
||||||
|
|
101
ggml-vulkan.cpp
101
ggml-vulkan.cpp
|
@ -134,6 +134,7 @@ vk_queue vk_compute_queue;
|
||||||
vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
|
vk_queue vk_transfer_queues[VK_TRANSFER_QUEUE_COUNT];
|
||||||
VmaAllocator vk_allocator;
|
VmaAllocator vk_allocator;
|
||||||
vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
|
vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
|
||||||
|
vk_pipeline vk_pipeline_matmul_f32_aligned_l, vk_pipeline_matmul_f32_aligned_m, vk_pipeline_matmul_f32_aligned_s, vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m, vk_pipeline_matmul_f16_aligned_s;
|
||||||
vk_pipeline vk_pipeline_matmul_split_k_reduce;
|
vk_pipeline vk_pipeline_matmul_split_k_reduce;
|
||||||
vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
|
vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
|
||||||
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
|
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
|
||||||
|
@ -644,10 +645,16 @@ void ggml_vk_init(void) {
|
||||||
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||||
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||||
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||||
|
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||||
|
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||||
|
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||||
if (vk_fp16_support) {
|
if (vk_fp16_support) {
|
||||||
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||||
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||||
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||||
|
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||||
|
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||||
|
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||||
}
|
}
|
||||||
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
|
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
|
||||||
|
|
||||||
|
@ -1242,7 +1249,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, size_t align, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
|
static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
|
||||||
#ifdef VK_DEBUG
|
#ifdef VK_DEBUG
|
||||||
std::cerr << "ggml_vk_h2d_tensor_2d()" << std::endl;
|
std::cerr << "ggml_vk_h2d_tensor_2d()" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
|
@ -1259,11 +1266,10 @@ static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const st
|
||||||
|
|
||||||
const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
|
const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
|
||||||
if (nb0 == ts && nb1 == row_length) {
|
if (nb0 == ts && nb1 == row_length) {
|
||||||
// return ggml_vk_buffer_write_async(dst, offset, x, ne1*nb1, q, std::move(wait_semaphores), std::move(signal_semaphores));
|
return ggml_vk_buffer_write_async(dst, offset, x, ne1*nb1, q, std::move(wait_semaphores), std::move(signal_semaphores));
|
||||||
return ggml_vk_buffer_write_2d_async_zeropad(dst, offset, x, nb1, row_length, ne1, align, q, std::move(wait_semaphores), std::move(signal_semaphores));
|
|
||||||
}
|
}
|
||||||
if (nb0 == ts) {
|
if (nb0 == ts) {
|
||||||
return ggml_vk_buffer_write_2d_async_zeropad(dst, offset, x, nb1, row_length, ne1, align, q, std::move(wait_semaphores), std::move(signal_semaphores));
|
return ggml_vk_buffer_write_2d_async(dst, offset, x, nb1, row_length, ne1, q, std::move(wait_semaphores), std::move(signal_semaphores));
|
||||||
}
|
}
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
// TODO: also needs handling of staging buffers
|
// TODO: also needs handling of staging buffers
|
||||||
|
@ -1287,27 +1293,40 @@ static int ggml_vk_guess_split_k(int m, int n, int k) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline* ggml_vk_guess_matmul_pipeline(bool bit16, int m, int n) {
|
static uint32_t ggml_vk_guess_matmul_pipeline_align(int m, int n) {
|
||||||
|
#ifdef VK_DEBUG
|
||||||
|
std::cerr << "ggml_vk_guess_matmul_pipeline_padding()" << std::endl;
|
||||||
|
#endif
|
||||||
|
if (m <= 32 || n <= 32) {
|
||||||
|
return vk_pipeline_matmul_f32_s.align;
|
||||||
|
}
|
||||||
|
if (m <= 64 || n <= 64) {
|
||||||
|
return vk_pipeline_matmul_f32_m.align;
|
||||||
|
}
|
||||||
|
return vk_pipeline_matmul_f32_l.align;
|
||||||
|
}
|
||||||
|
|
||||||
|
static vk_pipeline* ggml_vk_guess_matmul_pipeline(bool bit16, int m, int n, bool aligned) {
|
||||||
#ifdef VK_DEBUG
|
#ifdef VK_DEBUG
|
||||||
std::cerr << "ggml_vk_guess_matmul_pipeline()" << std::endl;
|
std::cerr << "ggml_vk_guess_matmul_pipeline()" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
if (bit16) {
|
if (bit16) {
|
||||||
if (m <= 32 || n <= 32) {
|
if (m <= 32 || n <= 32) {
|
||||||
return &vk_pipeline_matmul_f16_s;
|
return aligned ? &vk_pipeline_matmul_f16_aligned_s : &vk_pipeline_matmul_f16_s;
|
||||||
}
|
}
|
||||||
if (m <= 64 || n <= 64) {
|
if (m <= 64 || n <= 64) {
|
||||||
return &vk_pipeline_matmul_f16_m;
|
return aligned ? &vk_pipeline_matmul_f16_aligned_m : &vk_pipeline_matmul_f16_m;
|
||||||
}
|
}
|
||||||
return &vk_pipeline_matmul_f16_l;
|
return aligned ? &vk_pipeline_matmul_f16_aligned_l : &vk_pipeline_matmul_f16_l;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m <= 32 || n <= 32) {
|
if (m <= 32 || n <= 32) {
|
||||||
return &vk_pipeline_matmul_f32_s;
|
return aligned ? &vk_pipeline_matmul_f32_aligned_s : &vk_pipeline_matmul_f32_s;
|
||||||
}
|
}
|
||||||
if (m <= 64 || n <= 64) {
|
if (m <= 64 || n <= 64) {
|
||||||
return &vk_pipeline_matmul_f32_m;
|
return aligned ? &vk_pipeline_matmul_f32_aligned_m : &vk_pipeline_matmul_f32_m;
|
||||||
}
|
}
|
||||||
return &vk_pipeline_matmul_f32_l;
|
return aligned ? &vk_pipeline_matmul_f32_aligned_l : &vk_pipeline_matmul_f32_l;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer& b, vk_buffer& d, int m, int n, int k, int stride_a, int stride_b, int stride_d, int split_k, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
|
static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_buffer& a, vk_buffer& b, vk_buffer& d, int m, int n, int k, int stride_a, int stride_b, int stride_d, int split_k, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
|
||||||
|
@ -1341,7 +1360,6 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
|
std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
|
||||||
std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
|
std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
const int64_t ne00 = src0->ne[0];
|
|
||||||
const int64_t ne01 = src0->ne[1];
|
const int64_t ne01 = src0->ne[1];
|
||||||
const int64_t ne02 = src0->ne[2];
|
const int64_t ne02 = src0->ne[2];
|
||||||
const int64_t ne03 = src0->ne[3];
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
@ -1355,9 +1373,10 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
||||||
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11);
|
|
||||||
|
|
||||||
const int kpad = ggml_vk_align_size(ne10, pipeline->align);
|
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
|
||||||
|
|
||||||
|
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11, ne10 == kpad);
|
||||||
|
|
||||||
vk_buffer d_X;
|
vk_buffer d_X;
|
||||||
vk_buffer d_Y;
|
vk_buffer d_Y;
|
||||||
|
@ -1392,20 +1411,20 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
semaphores.push_back(s_x);
|
semaphores.push_back(s_x);
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[0], {}, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], {}, { s_x }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous matmul to be done before writing to the input buffers again
|
// Wait for previous matmul to be done before writing to the input buffers again
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[0], { s_it_x }, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], { s_it_x }, { s_x }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_vk_submit(vk_transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
|
ggml_vk_submit(vk_transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
|
||||||
|
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[1], {}, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], {}, { s_y }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous matmul to be done before writing to the input buffers again
|
// Wait for previous matmul to be done before writing to the input buffers again
|
||||||
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[1], { s_it_y }, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], { s_it_y }, { s_y }));
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
|
@ -1415,13 +1434,13 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
if (load_x) {
|
if (load_x) {
|
||||||
s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_x, s_it_y }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_x, s_it_y }));
|
||||||
} else {
|
} else {
|
||||||
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy dst to host
|
// copy dst to host
|
||||||
|
@ -1473,12 +1492,14 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
const int nb2 = dst->nb[2];
|
const int nb2 = dst->nb[2];
|
||||||
const int nb3 = dst->nb[3];
|
const int nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne11 * ne01;
|
const int d_ne = ne11 * ne01;
|
||||||
|
|
||||||
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
||||||
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, ne01, ne11);
|
|
||||||
|
|
||||||
const int kpad = ggml_vk_align_size(ne10, pipeline->align);
|
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
|
||||||
|
|
||||||
|
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, ne01, ne11, ne10 == kpad);
|
||||||
|
|
||||||
vk_buffer d_X;
|
vk_buffer d_X;
|
||||||
vk_buffer d_Y;
|
vk_buffer d_Y;
|
||||||
|
@ -1519,10 +1540,10 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
semaphores.push_back(s_x);
|
semaphores.push_back(s_x);
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, pipeline->align * sizeof(ggml_fp16_t), vk_transfer_queues[0], {}, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], {}, { s_x }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous matmul to be done before writing to the input buffers again
|
// Wait for previous matmul to be done before writing to the input buffers again
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, pipeline->align * sizeof(ggml_fp16_t), vk_transfer_queues[0], { s_it_x }, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_X, 0, src0, i03, i02, vk_transfer_queues[0], { s_it_x }, { s_x }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1530,7 +1551,6 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
|
|
||||||
// convert src1 to fp16
|
// convert src1 to fp16
|
||||||
// TODO: use multiple threads
|
// TODO: use multiple threads
|
||||||
// TODO: This memory isn't pinned
|
|
||||||
ggml_fp16_t * const tmp = fp16_staging + (ne11 * ne10) * (i03 * ne02 + i02);
|
ggml_fp16_t * const tmp = fp16_staging + (ne11 * ne10) * (i03 * ne02 + i02);
|
||||||
char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
|
char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
|
||||||
if (src1_cont_rows) {
|
if (src1_cont_rows) {
|
||||||
|
@ -1552,10 +1572,10 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
}
|
}
|
||||||
|
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_1_seqs.push_back(ggml_vk_buffer_write_2d_async_zeropad(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * ne10, sizeof(ggml_fp16_t) * ne10, ne11, pipeline->align * sizeof(ggml_fp16_t), vk_transfer_queues[1], {}, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_buffer_write_async(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1], {}, { s_y }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous matmul to be done before writing to the input buffers again
|
// Wait for previous matmul to be done before writing to the input buffers again
|
||||||
transfer_1_seqs.push_back(ggml_vk_buffer_write_2d_async_zeropad(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * ne10, sizeof(ggml_fp16_t) * ne10, ne11, pipeline->align * sizeof(ggml_fp16_t), vk_transfer_queues[1], { s_it_y }, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_buffer_write_async(&d_Y, 0, tmp, sizeof(ggml_fp16_t) * y_ne, vk_transfer_queues[1], { s_it_y }, { s_y }));
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
|
@ -1564,13 +1584,13 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||||
if (load_x) {
|
if (load_x) {
|
||||||
s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_x = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_x, s_it_y }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_x, s_it_y }));
|
||||||
} else {
|
} else {
|
||||||
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy dst to host
|
// copy dst to host
|
||||||
|
@ -1623,9 +1643,10 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
||||||
|
|
||||||
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
||||||
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11);
|
|
||||||
|
|
||||||
const int kpad = ggml_vk_align_size(ne10, pipeline->align);
|
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
|
||||||
|
|
||||||
|
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11, ne10 == kpad);
|
||||||
|
|
||||||
vk_buffer d_X;
|
vk_buffer d_X;
|
||||||
vk_buffer d_Y;
|
vk_buffer d_Y;
|
||||||
|
@ -1672,10 +1693,10 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
s_x = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
q_semaphores.push_back(s_x);
|
q_semaphores.push_back(s_x);
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, 1, vk_transfer_queues[0], {}, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0], {}, { s_x }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous dequant to be done before writing to the input buffers again
|
// Wait for previous dequant to be done before writing to the input buffers again
|
||||||
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, 1, vk_transfer_queues[0], { s_it_x }, { s_x }));
|
transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Q, 0, src0, i03, i02, vk_transfer_queues[0], { s_it_x }, { s_x }));
|
||||||
}
|
}
|
||||||
} else if (src0->backend == GGML_BACKEND_GPU) {
|
} else if (src0->backend == GGML_BACKEND_GPU) {
|
||||||
d_Q = *(vk_buffer *) src0->data;
|
d_Q = *(vk_buffer *) src0->data;
|
||||||
|
@ -1687,10 +1708,10 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
|
|
||||||
// copy src1 to device
|
// copy src1 to device
|
||||||
if (first) {
|
if (first) {
|
||||||
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[1], {}, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], {}, { s_y }));
|
||||||
} else {
|
} else {
|
||||||
// Wait for previous matmul to be done before writing to the input buffers again
|
// Wait for previous matmul to be done before writing to the input buffers again
|
||||||
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, pipeline->align * sizeof(float), vk_transfer_queues[1], { s_it_y }, { s_y }));
|
transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, 0, src1, i03, i02, vk_transfer_queues[1], { s_it_y }, { s_y }));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
|
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
|
||||||
|
@ -1714,7 +1735,7 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
|
|
||||||
// convert src0 to fp32 on device
|
// convert src0 to fp32 on device
|
||||||
vk_submission s = ggml_vk_begin_submission(vk_compute_queue);
|
vk_submission s = ggml_vk_begin_submission(vk_compute_queue);
|
||||||
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, kpad };
|
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
|
||||||
ggml_vk_sync_buffers(s.buffer, { d_Q }, vk_compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
|
ggml_vk_sync_buffers(s.buffer, { d_Q }, vk_compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
|
||||||
ggml_vk_sync_buffers(s.buffer, { d_X }, vk_compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
|
ggml_vk_sync_buffers(s.buffer, { d_X }, vk_compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
|
||||||
ggml_vk_dispatch_pipeline(s, *to_fp32_vk, {d_Q, d_X}, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1}, vk_compute_queue);
|
ggml_vk_dispatch_pipeline(s, *to_fp32_vk, {d_Q, d_X}, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1}, vk_compute_queue);
|
||||||
|
@ -1729,9 +1750,9 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||||
// compute
|
// compute
|
||||||
if (!last) {
|
if (!last) {
|
||||||
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
s_it_y = ggml_vk_create_semaphore(vk_compute_queue);
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm, s_it_y }));
|
||||||
} else {
|
} else {
|
||||||
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, kpad, kpad, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
compute_seqs.push_back(ggml_vk_matmul(*pipeline, d_X, d_Y, d_D, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_compute_queue, std::move(semaphores), { s_mm }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
|
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||||
layout (binding = 1) readonly buffer B { f16mat2x4 data_b[]; };
|
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
|
||||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
|
@ -52,16 +52,16 @@ void main() {
|
||||||
const int tiwr = tiw % (WSUBM / TM);
|
const int tiwr = tiw % (WSUBM / TM);
|
||||||
const int tiwc = tiw / (WSUBM / TM);
|
const int tiwc = tiw / (WSUBM / TM);
|
||||||
|
|
||||||
const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
|
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||||
const int loadc = int(gl_LocalInvocationID.x / (BK / 8));
|
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||||
|
|
||||||
const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;
|
const int loadstride = int(gl_WorkGroupSize.x);
|
||||||
|
|
||||||
const int start_k = ik * p.k_split;
|
const int start_k = ik * p.k_split;
|
||||||
const int end_k = (ik + 1) * p.k_split;
|
const int end_k = (ik + 1) * p.k_split;
|
||||||
|
|
||||||
int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
|
int pos_a = ir * BM * p.stride_a + start_k;
|
||||||
int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;
|
int pos_b = ic * BN * p.stride_b + start_k;
|
||||||
|
|
||||||
float sums[WMITER * TM * WNITER * TN];
|
float sums[WMITER * TM * WNITER * TN];
|
||||||
float16_t cache_a[WMITER * TM];
|
float16_t cache_a[WMITER * TM];
|
||||||
|
@ -72,33 +72,29 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
||||||
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
|
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||||
f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
|
const int lr = l % BK;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
|
const int lc = l / BK;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
|
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
|
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
|
} else {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
|
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
|
}
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
|
}
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
|
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||||
|
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||||
|
} else {
|
||||||
|
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
|
||||||
}
|
}
|
||||||
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
|
|
||||||
f16mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
pos_a += BK / 8;
|
pos_a += BK;
|
||||||
pos_b += BK / 8;
|
pos_b += BK;
|
||||||
|
|
||||||
for (int i = 0; i < min(BK, p.K - block); i++) {
|
for (int i = 0; i < min(BK, p.K - block); i++) {
|
||||||
// Load from shared into cache
|
// Load from shared into cache
|
||||||
|
|
149
vk_shaders/matmul_f16_aligned.glsl
Normal file
149
vk_shaders/matmul_f16_aligned.glsl
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#define WARP 32
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
|
||||||
|
layout (binding = 1) readonly buffer B { f16mat2x4 data_b[]; };
|
||||||
|
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter
|
||||||
|
{
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int stride_a;
|
||||||
|
int stride_b;
|
||||||
|
int stride_d;
|
||||||
|
int k_split;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
layout (constant_id = 1) const int BM = 64;
|
||||||
|
layout (constant_id = 2) const int BN = 64;
|
||||||
|
layout (constant_id = 3) const int BK = 16;
|
||||||
|
layout (constant_id = 4) const int WM = 32;
|
||||||
|
layout (constant_id = 5) const int WN = 32;
|
||||||
|
layout (constant_id = 6) const int WMITER = 2;
|
||||||
|
layout (constant_id = 7) const int TM = 4;
|
||||||
|
layout (constant_id = 8) const int TN = 2;
|
||||||
|
|
||||||
|
shared float16_t buf_a[BM * (BK+1)];
|
||||||
|
shared float16_t buf_b[BN * (BK+1)];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const int blocks_x = (p.M + BM - 1) / BM;
|
||||||
|
const int ir = int(gl_WorkGroupID.x) % blocks_x;
|
||||||
|
const int ik = int(gl_WorkGroupID.x) / blocks_x;
|
||||||
|
const int ic = int(gl_WorkGroupID.y);
|
||||||
|
|
||||||
|
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||||
|
const int warp_r = warp_i % (BM / WM);
|
||||||
|
const int warp_c = warp_i / (BM / WM);
|
||||||
|
|
||||||
|
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||||
|
const int WSUBM = WM / WMITER;
|
||||||
|
const int WSUBN = WN / WNITER;
|
||||||
|
|
||||||
|
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||||
|
const int tiwr = tiw % (WSUBM / TM);
|
||||||
|
const int tiwc = tiw / (WSUBM / TM);
|
||||||
|
|
||||||
|
const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
|
||||||
|
const int loadc = int(gl_LocalInvocationID.x / (BK / 8));
|
||||||
|
|
||||||
|
const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;
|
||||||
|
|
||||||
|
const int start_k = ik * p.k_split;
|
||||||
|
const int end_k = (ik + 1) * p.k_split;
|
||||||
|
|
||||||
|
int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
|
||||||
|
int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;
|
||||||
|
|
||||||
|
float sums[WMITER * TM * WNITER * TN];
|
||||||
|
float16_t cache_a[WMITER * TM];
|
||||||
|
float16_t cache_b[WNITER * TN];
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||||
|
sums[i] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
||||||
|
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
|
||||||
|
f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
|
||||||
|
}
|
||||||
|
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
|
||||||
|
f16mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
pos_a += BK / 8;
|
||||||
|
pos_b += BK / 8;
|
||||||
|
|
||||||
|
for (int i = 0; i < min(BK, p.K - block); i++) {
|
||||||
|
// Load from shared into cache
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||||
|
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||||
|
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int dr = ir * BM + warp_r * WM;
|
||||||
|
const int dc = ic * BN + warp_c * WN;
|
||||||
|
|
||||||
|
const int k_split_offset = ik * p.M * p.N;
|
||||||
|
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
|
||||||
|
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||||
|
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||||
|
data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,8 +6,8 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A { vec4 data_a[]; };
|
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||||
layout (binding = 1) readonly buffer B { vec4 data_b[]; };
|
layout (binding = 1) readonly buffer B { float data_b[]; };
|
||||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
|
@ -51,16 +51,16 @@ void main() {
|
||||||
const int tiwr = tiw % (WSUBM / TM);
|
const int tiwr = tiw % (WSUBM / TM);
|
||||||
const int tiwc = tiw / (WSUBM / TM);
|
const int tiwc = tiw / (WSUBM / TM);
|
||||||
|
|
||||||
const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
|
const int loadr = int(gl_LocalInvocationID.x % BK);
|
||||||
const int loadc = int(gl_LocalInvocationID.x / (BK / 4));
|
const int loadc = int(gl_LocalInvocationID.x / BK);
|
||||||
|
|
||||||
const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;
|
const int loadstride = int(gl_WorkGroupSize.x);
|
||||||
|
|
||||||
const int start_k = ik * p.k_split;
|
const int start_k = ik * p.k_split;
|
||||||
const int end_k = (ik + 1) * p.k_split;
|
const int end_k = (ik + 1) * p.k_split;
|
||||||
|
|
||||||
int pos_a = ir * BM * p.stride_a / 4 + start_k / 4;
|
int pos_a = ir * BM * p.stride_a + start_k;
|
||||||
int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;
|
int pos_b = ic * BN * p.stride_b + start_k;
|
||||||
|
|
||||||
float sums[WMITER * TM * WNITER * TN];
|
float sums[WMITER * TM * WNITER * TN];
|
||||||
float cache_a[WMITER * TM];
|
float cache_a[WMITER * TM];
|
||||||
|
@ -71,25 +71,29 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
||||||
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
|
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||||
vec4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 4 + loadr];
|
const int lr = l % BK;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
|
const int lc = l / BK;
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
|
if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
|
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
|
} else {
|
||||||
|
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||||
|
const int lr = l % BK;
|
||||||
|
const int lc = l / BK;
|
||||||
|
if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
|
||||||
|
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
|
||||||
|
} else {
|
||||||
|
buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
|
||||||
}
|
}
|
||||||
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
|
|
||||||
vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
|
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
pos_a += BK / 4;
|
pos_a += BK;
|
||||||
pos_b += BK / 4;
|
pos_b += BK;
|
||||||
|
|
||||||
for (int i = 0; i < min(BK, p.K - block); i++) {
|
for (int i = 0; i < min(BK, p.K - block); i++) {
|
||||||
// Load from shared into cache
|
// Load from shared into cache
|
||||||
|
|
140
vk_shaders/matmul_f32_aligned.glsl
Normal file
140
vk_shaders/matmul_f32_aligned.glsl
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#define WARP 32
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A { vec4 data_a[]; };
|
||||||
|
layout (binding = 1) readonly buffer B { vec4 data_b[]; };
|
||||||
|
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter
|
||||||
|
{
|
||||||
|
int M;
|
||||||
|
int N;
|
||||||
|
int K;
|
||||||
|
int stride_a;
|
||||||
|
int stride_b;
|
||||||
|
int stride_d;
|
||||||
|
int k_split;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
layout (constant_id = 1) const int BM = 64;
|
||||||
|
layout (constant_id = 2) const int BN = 64;
|
||||||
|
layout (constant_id = 3) const int BK = 16;
|
||||||
|
layout (constant_id = 4) const int WM = 32;
|
||||||
|
layout (constant_id = 5) const int WN = 32;
|
||||||
|
layout (constant_id = 6) const int WMITER = 2;
|
||||||
|
layout (constant_id = 7) const int TM = 4;
|
||||||
|
layout (constant_id = 8) const int TN = 2;
|
||||||
|
|
||||||
|
shared float buf_a[BM * (BK+1)];
|
||||||
|
shared float buf_b[BN * (BK+1)];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const int blocks_x = (p.M + BM - 1) / BM;
|
||||||
|
const int ir = int(gl_WorkGroupID.x) % blocks_x;
|
||||||
|
const int ik = int(gl_WorkGroupID.x) / blocks_x;
|
||||||
|
const int ic = int(gl_WorkGroupID.y);
|
||||||
|
|
||||||
|
const int warp_i = int(gl_LocalInvocationID.x / WARP);
|
||||||
|
const int warp_r = warp_i % (BM / WM);
|
||||||
|
const int warp_c = warp_i / (BM / WM);
|
||||||
|
|
||||||
|
const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||||
|
const int WSUBM = WM / WMITER;
|
||||||
|
const int WSUBN = WN / WNITER;
|
||||||
|
|
||||||
|
const int tiw = int(gl_LocalInvocationID.x % WARP);
|
||||||
|
const int tiwr = tiw % (WSUBM / TM);
|
||||||
|
const int tiwc = tiw / (WSUBM / TM);
|
||||||
|
|
||||||
|
const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
|
||||||
|
const int loadc = int(gl_LocalInvocationID.x / (BK / 4));
|
||||||
|
|
||||||
|
const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;
|
||||||
|
|
||||||
|
const int start_k = ik * p.k_split;
|
||||||
|
const int end_k = (ik + 1) * p.k_split;
|
||||||
|
|
||||||
|
int pos_a = ir * BM * p.stride_a / 4 + start_k / 4;
|
||||||
|
int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;
|
||||||
|
|
||||||
|
float sums[WMITER * TM * WNITER * TN];
|
||||||
|
float cache_a[WMITER * TM];
|
||||||
|
float cache_b[WNITER * TN];
|
||||||
|
|
||||||
|
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||||
|
sums[i] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
|
||||||
|
[[unroll]] for (int l = 0; l < BM; l += loadstride) {
|
||||||
|
vec4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 4 + loadr];
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
|
||||||
|
buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
|
||||||
|
}
|
||||||
|
[[unroll]] for (int l = 0; l < BN; l += loadstride) {
|
||||||
|
vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
|
||||||
|
buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
pos_a += BK / 4;
|
||||||
|
pos_b += BK / 4;
|
||||||
|
|
||||||
|
for (int i = 0; i < min(BK, p.K - block); i++) {
|
||||||
|
// Load from shared into cache
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
[[unroll]] for (int j = 0; j < TM; j++) {
|
||||||
|
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||||
|
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int dr = ir * BM + warp_r * WM;
|
||||||
|
const int dc = ic * BN + warp_c * WN;
|
||||||
|
|
||||||
|
const int k_split_offset = ik * p.M * p.N;
|
||||||
|
|
||||||
|
[[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
|
||||||
|
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
|
||||||
|
const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||||
|
const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||||
|
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||||
|
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||||
|
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||||
|
data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue