Use F16 kernel for most things, replace q_f32 with mul_mat_q_f16 function

parent 1b2ec1aa72
commit d0bd120814

9 changed files with 890 additions and 182 deletions
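In short: src0 is kept in (or dequantized to) F16 on the GPU, src1 is either converted to F16 as well or fed as F32 into the new matmul_f16_f32 shader, and the old q_f32 path is replaced by a single ggml_vk_mul_mat_q_f16 function. The helper below is a hypothetical condensation of the dispatch flags introduced in ggml-vulkan.cpp further down; it is not part of the commit, and only ggml_type and the GGML_TYPE_* constants from ggml.h are real names:

    #include "ggml.h"   // ggml_type, GGML_TYPE_F16, GGML_TYPE_F32
    #include <cstdint>

    // Hypothetical summary of the flags computed inside ggml_vk_mul_mat_q_f16 (sketch only).
    struct vk_mul_mat_plan {
        bool use_dequant_mul_mat_vec; // ne11 == 1 and src0 is not F16: dedicated dequant+matvec kernel
        bool use_f16_f32_shader;      // F16 (or dequantized) src0 multiplied with raw F32 src1
        bool dequant_src0_to_f16;     // run a to_fp16 pipeline on src0 before the matmul
        bool dequant_src1_to_f16;     // run a to_fp16 pipeline on src1 before the matmul
    };

    static vk_mul_mat_plan vk_plan_mul_mat(ggml_type src0_type, ggml_type src1_type, int64_t ne11) {
        vk_mul_mat_plan p{};
        p.use_dequant_mul_mat_vec = ne11 == 1 && src0_type != GGML_TYPE_F16;
        p.use_f16_f32_shader      = src1_type == GGML_TYPE_F32 && !p.use_dequant_mul_mat_vec;
        p.dequant_src0_to_f16     = src0_type != GGML_TYPE_F16 && !p.use_dequant_mul_mat_vec;
        p.dequant_src1_to_f16     = src1_type != GGML_TYPE_F16 && !p.use_f16_f32_shader;
        return p;
    }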
Makefile (3 changes)
@@ -236,8 +236,11 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32_aligned.glsl -o vk_shaders/matmul_f32_aligned.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_aligned.glsl -o vk_shaders/matmul_f16_aligned.spv & \
+glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32.glsl -o vk_shaders/matmul_f16_f32.spv & \
+glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32_aligned.glsl -o vk_shaders/matmul_f16_f32_aligned.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv & \
+glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f32_to_f16.glsl -o vk_shaders/f32_to_f16.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16.glsl -o vk_shaders/dequant_mul_mat_vec_f16.spv & \
 glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0.spv & \
ggml-vulkan.cpp (547 changes)
@@ -2,7 +2,6 @@

 #ifdef VK_CHK_KERNEL
 #include <cblas.h>
-#include <cmath>
 #include <chrono>
 #endif

@@ -21,6 +20,7 @@
 #include <vulkan/vulkan.hpp>

 #include <atomic>
+#include <cmath>
 #include <fstream>
 #include <iostream>
 #include <limits>
@@ -130,12 +130,16 @@ typedef std::vector<vk_submission> vk_sequence;

 vk::Instance vk_instance;
 vk_device vk_device;
-vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s, vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
-vk_pipeline vk_pipeline_matmul_f32_aligned_l, vk_pipeline_matmul_f32_aligned_m, vk_pipeline_matmul_f32_aligned_s, vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m, vk_pipeline_matmul_f16_aligned_s;
+vk_pipeline vk_pipeline_matmul_f32_l, vk_pipeline_matmul_f32_m, vk_pipeline_matmul_f32_s;
+vk_pipeline vk_pipeline_matmul_f32_aligned_l, vk_pipeline_matmul_f32_aligned_m, vk_pipeline_matmul_f32_aligned_s;
+vk_pipeline vk_pipeline_matmul_f16_l, vk_pipeline_matmul_f16_m, vk_pipeline_matmul_f16_s;
+vk_pipeline vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m, vk_pipeline_matmul_f16_aligned_s;
+vk_pipeline vk_pipeline_matmul_f16_f32_l, vk_pipeline_matmul_f16_f32_m, vk_pipeline_matmul_f16_f32_s;
+vk_pipeline vk_pipeline_matmul_f16_f32_aligned_l, vk_pipeline_matmul_f16_f32_aligned_m, vk_pipeline_matmul_f16_f32_aligned_s;
 vk_pipeline vk_pipeline_matmul_split_k_reduce;
 vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16, vk_pipeline_dequant_mul_mat_vec_q4_0;
 vk_pipeline vk_pipeline_mul_f32;
-vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
+vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0;

 void * vk_pinned_workspace;
 size_t vk_pinned_workspace_size;
@@ -322,18 +326,18 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_queue& q) {
 return buf;
 }

-static vk_submission ggml_vk_create_submission(vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_submission ggml_vk_create_submission(vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_create_submission()" << std::endl;
 #endif
 vk_submission s;
 s.buffer = ggml_vk_create_cmd_buffer(q);
-s.wait_semaphores = wait_semaphores;
+s.wait_semaphores = std::move(wait_semaphores);
-s.signal_semaphores = signal_semaphores;
+s.signal_semaphores = std::move(signal_semaphores);
 return s;
 }

-static vk_sequence ggml_vk_create_sequence_1(vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_create_sequence_1(vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_create_sequence_1()" << std::endl;
 #endif
@@ -585,6 +589,7 @@ void ggml_vk_test_transfer(size_t ne);
 void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size);
 void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size);
 void ggml_vk_test_buffer_write_zeropad(size_t m, size_t k, size_t align);
+void ggml_vk_test_f32_to_f16(size_t m, size_t k);

 void ggml_vk_init(void) {
 #ifdef VK_DEBUG
@@ -738,10 +743,17 @@ void ggml_vk_init(void) {
 vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
 vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
 vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+
+vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
+vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
+vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
+vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
 }
 vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);

-vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
+vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline("vk_shaders/f32_to_f16.spv", "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
 vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);

 vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
@@ -766,6 +778,11 @@ void ggml_vk_init(void) {
 ggml_vk_test_buffer_write_zeropad(233, 97, 1);
 ggml_vk_test_buffer_write_zeropad(256, 128, 1);

+ggml_vk_test_f32_to_f16(214, 256);
+ggml_vk_test_f32_to_f16(256, 2048);
+ggml_vk_test_f32_to_f16(24, 1000);
+ggml_vk_test_f32_to_f16(24, 24);
+
 int step = 16;
 for (size_t m = step; m < 64; m += step) {
 ggml_vk_test_transfer(1024 * 1024 * m);
@@ -809,23 +826,15 @@ void ggml_vk_init(void) {
 #endif
 }

-static vk_pipeline* ggml_vk_get_to_fp32(ggml_type type) {
+static vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
 #ifdef VK_DEBUG
-std::cerr << "ggml_vk_get_to_fp32()" << std::endl;
+std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
 #endif
 switch (type) {
 case GGML_TYPE_Q4_0:
 return &vk_pipeline_dequant_q4_0;
-// case GGML_TYPE_Q4_1:
+case GGML_TYPE_F32:
-// return &dequantize_row_q4_1_cl;
+return &vk_pipeline_f32_to_f16;
-// case GGML_TYPE_Q5_0:
-// return &dequantize_row_q5_0_cl;
-// case GGML_TYPE_Q5_1:
-// return &dequantize_row_q5_1_cl;
-// case GGML_TYPE_Q8_0:
-// return &dequantize_row_q8_0_cl;
-case GGML_TYPE_F16:
-return &vk_pipeline_f16_to_f32;
 default:
 return nullptr;
 }
@@ -838,26 +847,8 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type) {
 switch (type) {
 case GGML_TYPE_Q4_0:
 return &vk_pipeline_dequant_mul_mat_vec_q4_0;
-// case GGML_TYPE_Q4_1:
-// return &dequantize_mul_mat_vec_q4_1_cl;
-// case GGML_TYPE_Q5_0:
-// return &dequantize_mul_mat_vec_q5_0_cl;
-// case GGML_TYPE_Q5_1:
-// return &dequantize_mul_mat_vec_q5_1_cl;
-// case GGML_TYPE_Q8_0:
-// return &dequantize_mul_mat_vec_q8_0_cl;
 case GGML_TYPE_F16:
 return &vk_pipeline_dequant_mul_mat_vec_f16;
-// case GGML_TYPE_Q2_K:
-// return &dequantize_mul_mat_vec_q2_K_cl;
-// case GGML_TYPE_Q3_K:
-// return &dequantize_mul_mat_vec_q3_K_cl;
-// case GGML_TYPE_Q4_K:
-// return &dequantize_mul_mat_vec_q4_K_cl;
-// case GGML_TYPE_Q5_K:
-// return &dequantize_mul_mat_vec_q5_K_cl;
-// case GGML_TYPE_Q6_K:
-// return &dequantize_mul_mat_vec_q6_K_cl;
 default:
 return nullptr;
 }
@@ -1019,6 +1010,7 @@ static void ggml_vk_dispatch_pipeline(vk_submission& s, vk_pipeline& pipeline, s
 std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
 std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
 vk::DescriptorSet& descriptor_set = pipeline.descriptor_sets[pipeline.descriptor_set_index++];
+GGML_ASSERT(descriptor_set != nullptr);
 for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
 descriptor_buffer_infos.push_back({buffers[i].buffer.buffer, buffers[i].offset, buffers[i].size});
 }
@@ -1038,14 +1030,14 @@ static void ggml_vk_dispatch_pipeline(vk_submission& s, vk_pipeline& pipeline, s
 s.buffer.dispatch(wg0, wg1, wg2);
 }

-static void ggml_vk_end_submission(vk_submission& s, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static void ggml_vk_end_submission(vk_submission& s, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 s.buffer.end();

-s.wait_semaphores = wait_semaphores;
+s.wait_semaphores = std::move(wait_semaphores);
-s.signal_semaphores = signal_semaphores;
+s.signal_semaphores = std::move(signal_semaphores);
 }

-static vk_sequence ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_buffer_write_2d_async(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")" << std::endl;
 #endif
@@ -1143,7 +1135,7 @@ static inline size_t ggml_vk_align_size(size_t width, size_t align) {
 return CEIL_DIV(width, align) * align;
 }

-static vk_sequence ggml_vk_buffer_write_2d_async_zeropad(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, size_t align, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_buffer_write_2d_async_zeropad(vk_buffer* dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, size_t align, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_buffer_write_2d_async_zeropad(" << offset << ", " << spitch << ", " << width << ", " << height << ", " << align << ")" << std::endl;
 #endif
@@ -1238,7 +1230,7 @@ static vk_sequence ggml_vk_buffer_write_2d_async_zeropad(vk_buffer* dst, size_t
 return { s };
 }

-static vk_sequence ggml_vk_buffer_write_async(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_buffer_write_async(vk_buffer* dst, size_t offset, const void * src, size_t size, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_buffer_write_async(" << size << ")" << std::endl;
 #endif
@@ -1252,7 +1244,7 @@ static void ggml_vk_buffer_write(vk_buffer* dst, size_t offset, const void * src
 ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1, q);
 }

-static vk_sequence ggml_vk_buffer_read_async(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_buffer_read_async(vk_buffer* src, size_t offset, void * dst, size_t size, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_buffer_read_async(" << size << ")" << std::endl;
 #endif
@@ -1360,7 +1352,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_
 }
 }

-static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_h2d_tensor_2d()" << std::endl;
 #endif
@@ -1393,14 +1385,87 @@ static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const st
 }
 }

+static vk_sequence ggml_vk_h2d_tensor_2d_f32_to_f16(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
+#ifdef VK_DEBUG
+std::cerr << "ggml_vk_h2d_tensor_2d()" << std::endl;
+#endif
+GGML_ASSERT(src->type == GGML_TYPE_F32);
+
+const uint64_t ne0 = src->ne[0];
+const uint64_t ne1 = src->ne[1];
+const uint64_t nb0 = src->nb[0];
+const uint64_t nb1 = src->nb[1];
+const uint64_t nb2 = src->nb[2];
+const uint64_t nb3 = src->nb[3];
+const enum ggml_type type = src->type;
+const size_t ts = ggml_type_size(type);
+const size_t bs = ggml_blck_size(type);
+const size_t row_length = ts*ne0/bs;
+
+const uint32_t copy_size = sizeof(ggml_fp16_t) * ne0 * ne1;
+
+if (dst->sb_write == nullptr) {
+dst->sb_write = new vk_buffer;
+*dst->sb_write = ggml_vk_create_buffer(dst->size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+}
+
+ggml_fp16_t * tmp = (ggml_fp16_t *) ((uint8_t *) dst->sb_write->ptr + offset);
+const uint8_t * x = (const uint8_t *) src->data + i2*nb2 + i3*nb3;
+if (nb0 == ts && nb1 == row_length) {
+ggml_fp32_to_fp16_row((const float *) x, tmp, ne0*ne1);
+
+vk_submission s = ggml_vk_create_submission(q, std::move(wait_semaphores), std::move(signal_semaphores));
+
+vk::BufferCopy buf_copy = {
+offset,
+offset,
+copy_size,
+};
+
+s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
+ggml_vk_sync_buffers(s.buffer, { { *dst, (uint32_t)offset, copy_size } }, q, vk::AccessFlagBits::eMemoryRead, vk::AccessFlagBits::eTransferWrite, false);
+s.buffer.copyBuffer(dst->sb_write->buffer, dst->buffer, { buf_copy });
+s.buffer.end();
+
+return { s };
+}
+if (nb0 == ts) {
+for (uint64_t i1 = 0; i1 < ne1; i1++) {
+ggml_fp32_to_fp16_row((const float *) (x + i1*nb1), tmp + i1*ne0, ne0);
+}
+
+vk_submission s = ggml_vk_create_submission(q, std::move(wait_semaphores), std::move(signal_semaphores));
+
+vk::BufferCopy buf_copy = {
+offset,
+offset,
+copy_size,
+};
+
+s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
+ggml_vk_sync_buffers(s.buffer, { { *dst, (uint32_t)offset, copy_size } }, q, vk::AccessFlagBits::eMemoryRead, vk::AccessFlagBits::eTransferWrite, false);
+s.buffer.copyBuffer(dst->sb_write->buffer, dst->buffer, { buf_copy });
+s.buffer.end();
+
+return { s };
+}
+GGML_ASSERT(false);
+}
+
 static int ggml_vk_guess_split_k(int m, int n, int k) {
 #ifdef VK_DEBUG
-std::cerr << "ggml_vk_guess_split_k()" << std::endl;
+std::cerr << "ggml_vk_guess_split_k()";
 #endif
 if (k > 128 && (m < 128 || n < 128)) {
+#ifdef VK_DEBUG
+std::cerr << " = 4" << std::endl;
+#endif
 return 4;
 }

+#ifdef VK_DEBUG
+std::cerr << " = 1" << std::endl;
+#endif
 return 1;
 }

@@ -1417,30 +1482,69 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(int m, int n) {
 return vk_pipeline_matmul_f32_l.align;
 }

-static vk_pipeline* ggml_vk_guess_matmul_pipeline(bool bit16, int m, int n, bool aligned) {
+static vk_pipeline* ggml_vk_guess_matmul_pipeline(bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
 #ifdef VK_DEBUG
-std::cerr << "ggml_vk_guess_matmul_pipeline()" << std::endl;
+std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16 << ", " << m << ", " << n << ", " << aligned << ")";
 #endif
-if (bit16) {
+if (bit16_x && bit16_y) {
 if (m <= 32 || n <= 32) {
+#ifdef VK_DEBUG
+std::cerr << " S" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f16_aligned_s : &vk_pipeline_matmul_f16_s;
 }
 if (m <= 64 || n <= 64) {
+#ifdef VK_DEBUG
+std::cerr << " M" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f16_aligned_m : &vk_pipeline_matmul_f16_m;
 }
+#ifdef VK_DEBUG
+std::cerr << " L" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f16_aligned_l : &vk_pipeline_matmul_f16_l;
 }
+if (bit16_x && !bit16_y) {
+if (m <= 32 || n <= 32) {
+#ifdef VK_DEBUG
+std::cerr << " S" << std::endl;
+#endif
+return aligned ? &vk_pipeline_matmul_f16_f32_aligned_s : &vk_pipeline_matmul_f16_f32_s;
+}
+if (m <= 64 || n <= 64) {
+#ifdef VK_DEBUG
+std::cerr << " M" << std::endl;
+#endif
+return aligned ? &vk_pipeline_matmul_f16_f32_aligned_m : &vk_pipeline_matmul_f16_f32_m;
+}
+#ifdef VK_DEBUG
+std::cerr << " L" << std::endl;
+#endif
+return aligned ? &vk_pipeline_matmul_f16_f32_aligned_l : &vk_pipeline_matmul_f16_f32_l;
+}
+if (!bit16_x && bit16_y) {
+GGML_ASSERT(false);
+}

 if (m <= 32 || n <= 32) {
+#ifdef VK_DEBUG
+std::cerr << " S" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f32_aligned_s : &vk_pipeline_matmul_f32_s;
 }
 if (m <= 64 || n <= 64) {
+#ifdef VK_DEBUG
+std::cerr << " M" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f32_aligned_m : &vk_pipeline_matmul_f32_m;
 }
+#ifdef VK_DEBUG
+std::cerr << " L" << std::endl;
+#endif
 return aligned ? &vk_pipeline_matmul_f32_aligned_l : &vk_pipeline_matmul_f32_l;
 }

-static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, int m, int n, int k, int stride_a, int stride_b, int stride_d, int split_k, vk_queue& q, std::vector<vk::Semaphore>&& wait_semaphores, std::vector<vk::Semaphore>&& signal_semaphores) {
+static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, int m, int n, int k, int stride_a, int stride_b, int stride_d, int split_k, vk_queue& q, std::vector<vk::Semaphore> wait_semaphores, std::vector<vk::Semaphore> signal_semaphores) {
 #ifdef VK_DEBUG
 std::cerr << "ggml_vk_matmul(" << m << ", " << n << ", " << k << ")" << std::endl;
 #endif
@@ -1490,10 +1594,10 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr

 const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));

-vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11, ne10 == kpad);
+vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, false, ne01, ne11, ne10 == kpad);

-const uint32_t x_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t x_sz = ggml_vk_align_size(sizeof(float) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
-const uint32_t y_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t y_sz = ggml_vk_align_size(sizeof(float) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
 const uint32_t d_sz = ggml_vk_align_size(sizeof(float) * d_ne * split_k, vk_device.properties.limits.minStorageBufferOffsetAlignment);

 vk_buffer d_X;
@@ -1608,7 +1712,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

 const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));

-vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, ne01, ne11, ne10 == kpad);
+vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, true, ne01, ne11, ne10 == kpad);

 const uint32_t x_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
 const uint32_t y_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
@@ -1672,30 +1776,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
 }

 ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
+transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d_f32_to_f16(&d_Y, y_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { s_y }));
-// convert src1 to fp16
-// TODO: use multiple threads
-ggml_fp16_t * const tmp = fp16_staging + (ne11 * ne10) * (i03 * ne02 + i02);
-char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
-if (src1_cont_rows) {
-if (src1_cont_cols) {
-ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
-}
-else {
-for (int64_t i01 = 0; i01 < ne11; i01++) {
-ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
-}
-}
-} else {
-for (int64_t i01 = 0; i01 < ne11; i01++) {
-for (int64_t i00 = 0; i00 < ne10; i00++) {
-// very slow due to no inlining
-tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
-}
-}
-}
-
-transfer_1_seqs.push_back(ggml_vk_buffer_write_async(&d_Y, y_offset, tmp, sizeof(ggml_fp16_t) * y_ne, vk_device.transfer_queues[1], {}, { s_y }));

 // compute
 vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_device.compute_queue);
@@ -1703,10 +1784,33 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

 // copy dst to host
 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-transfer_0_seqs.push_back(ggml_vk_buffer_read_async(&d_D, d_offset, d, sizeof(float) * d_ne, vk_device.transfer_queues[0], { s_mm }, {}));
+float * d_chk = (float *) ggml_vk_host_malloc(sizeof(float) * d_ne);
+transfer_0_seqs.push_back(ggml_vk_buffer_read_async(&d_D, d_offset, d_chk, sizeof(float) * d_ne, vk_device.transfer_queues[0], { s_mm }, {}));

 ggml_vk_submit(vk_device.transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);
 ggml_vk_submit(vk_device.compute_queue, compute_seqs, VK_NULL_HANDLE);
+
+//DEBUG
+ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
+vk_device.transfer_queues[0].queue.waitIdle();
+
+double err = 0.0;
+
+for (int i = 0; i < d_ne; i++) {
+double abs_err = fabs(d[i] - d_chk[i]);
+err += abs_err;
+}
+
+err /= d_ne;
+
+if (err > 0.01) {
+std::cerr << "ggml_vk_mul_mat_f16((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
+std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
+std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
+std::cerr << "MUL_MAT_F16 i02=" << i02 << " i03=" << i03 << " avg_err=" << err << std::endl;
+}
+
+ggml_vk_host_free(d_chk);
 }
 }

@@ -1728,9 +1832,9 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
 ggml_vk_pool_free(d_D);
 }

-static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #ifdef VK_DEBUG
-std::cerr << "ggml_vk_mul_mat_q_f32((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
+std::cerr << "ggml_vk_mul_mat_q_f16((type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
 std::cerr << "), (type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
 std::cerr << "), (type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
 #endif
@@ -1744,8 +1848,16 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *

 const int nb2 = dst->nb[2];
 const int nb3 = dst->nb[3];
-const ggml_type type = src0->type;
-const bool mul_mat_vec = ne11 == 1;
+const bool mul_mat_vec = ne11 == 1 && src0->type != GGML_TYPE_F16;
+
+const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !mul_mat_vec;
+
+const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 && !mul_mat_vec;
+const bool qy_needs_dequant = src1->type != GGML_TYPE_F16 && !f16_f32_kernel;
+const bool dq = qx_needs_dequant || qy_needs_dequant;
+
+const bool load_x = src0->backend != GGML_BACKEND_GPU;
+const bool load_y = src1->backend != GGML_BACKEND_GPU;

 const int x_ne = ne01 * ne00;
 const int y_ne = ne11 * ne10;
@@ -1755,114 +1867,148 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *

 const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));

-vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, ne01, ne11, ne10 == kpad);
+vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, !f16_f32_kernel, ne01, ne11, ne10 == kpad);

-const uint32_t q_sz = ggml_vk_align_size(ggml_type_size(type) * x_ne / ggml_blck_size(type), vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), vk_device.properties.limits.minStorageBufferOffsetAlignment);
-const uint32_t x_sz = ggml_vk_align_size(sizeof(float) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t qy_sz = ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), vk_device.properties.limits.minStorageBufferOffsetAlignment);
-const uint32_t y_sz = ggml_vk_align_size(sizeof(float) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t x_sz = ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
+const uint32_t y_sz = ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
 const uint32_t d_sz = ggml_vk_align_size(sizeof(float) * d_ne * split_k, vk_device.properties.limits.minStorageBufferOffsetAlignment);

-vk_buffer d_Q;
+vk_buffer d_Qx;
-if (src0->backend == GGML_BACKEND_CPU) {
+if (load_x) {
-ggml_vk_pool_malloc(q_sz, &d_Q, {});
+ggml_vk_pool_malloc(qx_sz * ne02 * ne03, &d_Qx, {});
 } else {
-d_Q = *(vk_buffer *) src0->data;
+d_Qx = *(vk_buffer *) src0->data;
+}
+vk_buffer d_Qy;
+if (load_y) {
+ggml_vk_pool_malloc(qy_sz * ne02 * ne03, &d_Qy, {});
+} else {
+d_Qy = *(vk_buffer *) src1->data;
 }
 vk_buffer d_X;
 vk_buffer d_Y;
 vk_buffer d_D;
-if (!mul_mat_vec) {
+if (qx_needs_dequant) {
-ggml_vk_pool_malloc(x_sz, &d_X, {});
+ggml_vk_pool_malloc(x_sz * ne02 * ne03, &d_X, {});
+} else {
+d_X = d_Qx;
+GGML_ASSERT(qx_sz == x_sz || mul_mat_vec); // NOLINT
 }
-ggml_vk_pool_malloc(y_sz, &d_Y, {});
+if (qy_needs_dequant) {
-ggml_vk_pool_malloc(d_sz, &d_D, {});
+ggml_vk_pool_malloc(y_sz * ne02 * ne03, &d_Y, {});
+} else {
+d_Y = d_Qy;
+GGML_ASSERT(qy_sz == y_sz);
+}
+ggml_vk_pool_malloc(d_sz * ne02 * ne03, &d_D, {});

-vk_pipeline* to_fp32_vk = ggml_vk_get_to_fp32(type);
+vk_pipeline* to_fp16_vk_0 = ggml_vk_get_to_fp16(src0->type);
-vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(type);
+vk_pipeline* to_fp16_vk_1 = ggml_vk_get_to_fp16(src1->type);
-GGML_ASSERT(to_fp32_vk != nullptr);
+vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(src0->type);
+GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
 GGML_ASSERT(dmmv != nullptr);

 std::vector<vk_sequence> compute_seqs;
 std::vector<vk_sequence> transfer_0_seqs;
 std::vector<vk_sequence> transfer_1_seqs;

-const bool load_x = src0->backend != GGML_BACKEND_GPU;
-
 // Allocate descriptor sets
 ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, ne02 * ne03);
-ggml_vk_pipeline_allocate_descriptor_sets(*to_fp32_vk, ne02 * ne03);
+if (qx_needs_dequant) {
-ggml_vk_pipeline_allocate_descriptor_sets(*dmmv, ne02 * ne03);
+ggml_vk_pipeline_allocate_descriptor_sets(*to_fp16_vk_0, ne02 * ne03);
+}
+if (qy_needs_dequant) {
+ggml_vk_pipeline_allocate_descriptor_sets(*to_fp16_vk_1, ne02 * ne03);
+}
+if (mul_mat_vec) {
+ggml_vk_pipeline_allocate_descriptor_sets(*dmmv, ne02 * ne03);
+}
 if (split_k > 1) {
 ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline_matmul_split_k_reduce, ne02 * ne03);
 }

 for (int64_t i03 = 0; i03 < ne03; i03++) {
 for (int64_t i02 = 0; i02 < ne02; i02++) {
-const uint32_t q_offset = load_x ? q_sz * (i03 * ne02 + i02) : 0;
+const uint32_t qx_offset = load_x ? qx_sz * (i03 * ne02 + i02) : 0;
+const uint32_t qy_offset = load_y ? qy_sz * (i03 * ne02 + i02) : 0;
 const uint32_t x_offset = x_sz * (i03 * ne02 + i02);
 const uint32_t y_offset = y_sz * (i03 * ne02 + i02);
 const uint32_t d_offset = d_sz * (i03 * ne02 + i02);

 vk::Semaphore s_x;
-vk::Semaphore s_y = ggml_vk_create_semaphore(vk_device.transfer_queues[0]);
+vk::Semaphore s_y;
-vk::Semaphore s_q = ggml_vk_create_semaphore(vk_device.transfer_queues[0]);
+vk::Semaphore s_q;
+
+const vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_device.compute_queue);

 std::vector<vk::Semaphore> q_semaphores;
+std::vector<vk::Semaphore> mm_semaphores;

-vk::Semaphore s_mm = ggml_vk_create_semaphore(vk_device.compute_queue);
-
-// copy src0 to device if necessary
 if (load_x) {
-s_x = ggml_vk_create_semaphore(vk_device.compute_queue);
+s_x = ggml_vk_create_semaphore(vk_device.transfer_queues[0]);
-q_semaphores.push_back(s_x);
+if (qx_needs_dequant) {
-transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Q, q_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s_x }));
+q_semaphores.push_back(s_x);
+} else {
+mm_semaphores.push_back(s_x);
+}
+transfer_0_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Qx, qx_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s_x }));
+}
+if (load_y) {
+s_y = ggml_vk_create_semaphore(vk_device.transfer_queues[1]);
+if (qy_needs_dequant) {
+q_semaphores.push_back(s_y);
+} else {
+mm_semaphores.push_back(s_y);
+}
+transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Qy, qy_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { s_y }));
 }

 ggml_vk_submit(vk_device.transfer_queues[0], transfer_0_seqs, VK_NULL_HANDLE);
+ggml_vk_submit(vk_device.transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);

-// copy src1 to device
+if (dq) {
-transfer_1_seqs.push_back(ggml_vk_h2d_tensor_2d(&d_Y, y_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { s_y }));
+s_q = ggml_vk_create_semaphore(vk_device.compute_queue);

-if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-// // compute
-// const size_t global = ne01 * VK_DMMV_BLOCK_SIZE;
-// const size_t local = VK_DMMV_BLOCK_SIZE;
-// const vk_int ncols = ne00;
-// events.emplace_back();
-// VK_CHECK(vkSetKernelArg(*dmmv, 0, sizeof(vk_buffer), &d_Q));
-// VK_CHECK(vkSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
-// VK_CHECK(vkSetKernelArg(*dmmv, 2, sizeof(vk_buffer), &d_Y));
-// VK_CHECK(vkSetKernelArg(*dmmv, 3, sizeof(vk_buffer), &d_D));
-// VK_CHECK(vkSetKernelArg(*dmmv, 4, sizeof(vk_int), &ncols));
-// VK_CHECK(vkEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
-q_semaphores.push_back(s_y);
-const int ncols = ne00;
 vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
-ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_Q), ggml_vk_subbuffer(d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+if (qx_needs_dequant) {
-ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
+const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
-ggml_vk_dispatch_pipeline(s, *dmmv, { { d_Q, q_offset, q_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz } }, sizeof(int), &ncols, { (uint32_t)ne01, 1, 1});
+ggml_vk_sync_buffers(s.buffer, { { d_Qx, qx_offset, qx_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
-ggml_vk_end_submission(s, std::move(q_semaphores), { s_mm });
+ggml_vk_sync_buffers(s.buffer, { { d_X, x_offset, x_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
-compute_seqs.push_back({ s });
+ggml_vk_dispatch_pipeline(s, *to_fp16_vk_0, { { d_Qx, qx_offset, qx_sz }, { d_X, x_offset, x_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1});
-} else { // general dequantization kernel + VK matrix matrix multiplication
+}

-// convert src0 to fp32 on device
+if (qy_needs_dequant) {
-vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
+const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
-const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
+ggml_vk_sync_buffers(s.buffer, { { d_Qy, qy_offset, qy_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
-ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_Q) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+ggml_vk_sync_buffers(s.buffer, { { d_Y, y_offset, y_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
-ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_X) }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
+ggml_vk_dispatch_pipeline(s, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
-ggml_vk_dispatch_pipeline(s, *to_fp32_vk, { { d_Q, q_offset, q_sz }, { d_X, x_offset, x_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1});
+}
 ggml_vk_end_submission(s, std::move(q_semaphores), { s_q });
 compute_seqs.push_back({ s });
+
+mm_semaphores.push_back(s_q);
+}
+
+if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
 // compute
-compute_seqs.push_back(ggml_vk_matmul(*pipeline, { d_X, x_offset, x_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_device.compute_queue, { s_q, s_y }, { s_mm }));
+const int ncols = ne00;
+vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
+ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_Qx), ggml_vk_subbuffer(d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eMemoryWrite, vk::AccessFlagBits::eShaderRead, false);
+ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
+ggml_vk_dispatch_pipeline(s, *dmmv, { { d_Qx, qx_offset, qx_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz } }, sizeof(int), &ncols, { (uint32_t)ne01, 1, 1});
+ggml_vk_end_submission(s, std::move(mm_semaphores), { s_mm });
+compute_seqs.push_back({ s });
+} else {
+// compute
+compute_seqs.push_back(ggml_vk_matmul(*pipeline, { d_X, x_offset, x_sz }, { d_Y, y_offset, y_sz }, { d_D, d_offset, d_sz }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_device.compute_queue, std::move(mm_semaphores), { s_mm }));
 }

 // copy dst to host
 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 transfer_0_seqs.push_back(ggml_vk_buffer_read_async(&d_D, d_offset, d, sizeof(float) * d_ne, vk_device.transfer_queues[0], { s_mm }, {}));

-ggml_vk_submit(vk_device.transfer_queues[1], transfer_1_seqs, VK_NULL_HANDLE);
 ggml_vk_submit(vk_device.compute_queue, compute_seqs, VK_NULL_HANDLE);
 }
 }
@@ -1876,17 +2022,27 @@ static void ggml_vk_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
 ggml_vk_queue_cleanup(vk_device.compute_queue);

 ggml_vk_pipeline_cleanup(*pipeline);
-ggml_vk_pipeline_cleanup(*to_fp32_vk);
+if (qx_needs_dequant) {
+ggml_vk_pipeline_cleanup(*to_fp16_vk_0);
+}
+if (qy_needs_dequant) {
+ggml_vk_pipeline_cleanup(*to_fp16_vk_1);
+}
 ggml_vk_pipeline_cleanup(*dmmv);
 ggml_vk_pipeline_cleanup(vk_pipeline_matmul_split_k_reduce);

-if (!mul_mat_vec) {
+if (qx_needs_dequant) {
 ggml_vk_pool_free(d_X);
 }
-ggml_vk_pool_free(d_Y);
+if (qy_needs_dequant) {
+ggml_vk_pool_free(d_Y);
+}
 ggml_vk_pool_free(d_D);
-if (src0->backend == GGML_BACKEND_CPU) {
+if (load_x) {
-ggml_vk_pool_free(d_Q);
+ggml_vk_pool_free(d_Qx);
+}
+if (load_y) {
+ggml_vk_pool_free(d_Qy);
 }
 }

@@ -1899,7 +2055,7 @@ static bool ggml_vk_can_mul_mat(const struct ggml_tensor * src0, const struct gg

 // TODO: find the optimal values for these
 if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-src1->type == GGML_TYPE_F32 &&
+(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || ggml_is_quantized(src1->type)) &&
 dst->type == GGML_TYPE_F32 &&
 ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
 return true;
@@ -1939,19 +2095,9 @@ static void ggml_vk_mul_mat(const struct ggml_tensor * src0, const struct ggml_t

 if (src0->type == GGML_TYPE_F32) {
 ggml_vk_mul_mat_f32(src0, src1, dst);
-}
+} else if (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) {
-else if (src0->type == GGML_TYPE_F16) {
+ggml_vk_mul_mat_q_f16(src0, src1, dst);
-if (ggml_vk_mul_mat_use_f16(src0, src1, dst)) {
+} else {
-ggml_vk_mul_mat_f16(src0, src1, dst);
-}
-else {
-ggml_vk_mul_mat_q_f32(src0, src1, dst);
-}
-}
-else if (ggml_is_quantized(src0->type)) {
-ggml_vk_mul_mat_q_f32(src0, src1, dst);
-}
-else {
 GGML_ASSERT(false);
 }
 }
@@ -2055,7 +2201,7 @@ static void ggml_vk_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
 }

 static void ggml_vk_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); // NOLINT
 ggml_vk_mul_f32(src0, src1, dst);
 }

@@ -2068,7 +2214,7 @@ void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
 const int64_t ne2 = tensor->ne[2];
 const int64_t ne3 = tensor->ne[3];

-GGML_ASSERT(ne2 == 1 && ne3 == 1);
+GGML_ASSERT(ne2 == 1 && ne3 == 1); // NOLINT

 const ggml_type type = tensor->type;
 const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
@@ -2175,6 +2321,77 @@ void ggml_vk_test_transfer(size_t ne) {
     free(x);
     free(y);
 }
+void ggml_vk_test_f32_to_f16(size_t m, size_t k) {
+#ifdef VK_DEBUG
+    std::cerr << "ggml_vk_test_f32_to_f16(" << m << ", " << k << ")" << std::endl;
+#endif
+    // Check that the GPU f32 -> f16 conversion is correct
+    const uint32_t ne = m * k;
+    vk_buffer d_X = ggml_vk_create_buffer(sizeof(float) * ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer d_Y = ggml_vk_create_buffer(sizeof(ggml_fp16_t) * ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+
+    float* x = (float *) malloc(sizeof(float) * ne);
+    ggml_fp16_t* y = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * ne);
+
+    for (size_t i = 0; i < ne; i++) {
+        x[i] = rand() / (float)RAND_MAX;
+    }
+
+    ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline_f32_to_f16, 1);
+
+    auto begin = std::chrono::high_resolution_clock::now();
+
+    ggml_vk_buffer_write(&d_X, 0, x, sizeof(float) * ne, vk_device.transfer_queues[0]);
+
+    vk_device.transfer_queues[0].queue.waitIdle();
+
+    auto end = std::chrono::high_resolution_clock::now();
+
+    double ms_to_gpu = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+
+    begin = std::chrono::high_resolution_clock::now();
+
+    std::vector<vk_sequence> seqs;
+    vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
+    const std::vector<int> pc = { (int)m, (int)k, (int)k, (int)k };
+    ggml_vk_sync_buffers(s.buffer, { { d_X, 0, (uint32_t)sizeof(float) * ne } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+    ggml_vk_sync_buffers(s.buffer, { { d_Y, 0, (uint32_t)sizeof(ggml_fp16_t) * ne} }, vk_device.compute_queue, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
+    ggml_vk_dispatch_pipeline(s, vk_pipeline_f32_to_f16, { { d_X, 0, (uint32_t)sizeof(float) * ne }, { d_Y, 0, (uint32_t)sizeof(ggml_fp16_t) * ne } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
+    ggml_vk_end_submission(s, {}, {});
+    seqs.push_back({ s });
+
+    ggml_vk_submit(vk_device.compute_queue, seqs, VK_NULL_HANDLE);
+
+    vk_device.compute_queue.queue.waitIdle();
+
+    end = std::chrono::high_resolution_clock::now();
+
+    double ms_convert = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+
+    begin = std::chrono::high_resolution_clock::now();
+
+    ggml_vk_buffer_read(&d_Y, 0, y, sizeof(ggml_fp16_t) * ne, vk_device.transfer_queues[1]);
+
+    end = std::chrono::high_resolution_clock::now();
+
+    double ms_from_gpu = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+
+    double avg_err = 0.0;
+    for (size_t i = 0; i < ne; i++) {
+        avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(y[i]));
+    }
+
+    std::cerr << "TEST F32 TO F16 " << ms_to_gpu << "ms to_gpu " << ms_convert << "ms convert " << ms_from_gpu << "ms from gpu avg_err=" << avg_err / ne << std::endl;
+
+    ggml_vk_destroy_buffer(d_X);
+    ggml_vk_destroy_buffer(d_Y);
+
+    ggml_vk_pipeline_cleanup(vk_pipeline_f32_to_f16);
+
+    free(x);
+    free(y);
+}
+
 void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_test_matmul_f32(" << m << ", " << n << ", " << k << ", " << num_it << ", " << split_k << ", " << shader_size << ")" << std::endl;
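A note on the push constants the new test passes: pc = { m, k, k, k } lines up with the push-constant block of the new vk_shaders/f32_to_f16.glsl further down, i.e. a densely packed matrix where both strides equal K. A host-side mirror of that block, as a sketch (the struct name is illustrative and not part of the codebase):

struct vk_f32_to_f16_push_constants {  // hypothetical name, mirrors the GLSL push-constant block
    int M;         // number of columns (m in the test)
    int K;         // number of rows (k in the test)
    int stride_a;  // element stride of the f32 source (k here)
    int stride_b;  // element stride of the f16 destination (k here)
};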
vk_shaders/dequant_mul_mat_vec_f16.glsl

@@ -11,7 +11,7 @@
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A { float16_t x[]; };
-layout (binding = 1) readonly buffer B { float y[]; };
+layout (binding = 1) readonly buffer B { float16_t y[]; };
 layout (binding = 2) writeonly buffer D { float dst[]; };
 
 layout (push_constant) uniform parameter
@@ -19,7 +19,7 @@ layout (push_constant) uniform parameter
     int ncols;
 } p;
 
-shared float tmp[BLOCK_SIZE];
+shared float16_t tmp[BLOCK_SIZE];
 
 void main() {
     const int block_size = int(gl_WorkGroupSize.x);
@@ -28,7 +28,7 @@ void main() {
 
     const int y_offset = QUANT_K/2;
 
-    tmp[tid] = 0;
+    tmp[tid] = 0.0hf;
 
     [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
         const int col = i*block_size + 2*tid;
@@ -37,8 +37,8 @@ void main() {
         const int iybs = col - col%QUANT_K; // y block start index
 
         // dequantize
-        float v0 = float(x[ib + 0]);
-        float v1 = float(x[ib + 1]);
+        float16_t v0 = x[ib + 0];
+        float16_t v1 = x[ib + 1];
 
         // matrix multiplication
         tmp[tid] += v0 * y[iybs + iqs + 0];
@@ -54,6 +54,6 @@ void main() {
         barrier();
     }
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = float(tmp[0]);
    }
 }
vk_shaders/dequant_mul_mat_vec_q4_0.glsl

@@ -18,7 +18,7 @@ struct block_q4_0
 };
 
 layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
-layout (binding = 1) readonly buffer B { float y[]; };
+layout (binding = 1) readonly buffer B { float16_t y[]; };
 layout (binding = 2) writeonly buffer D { float dst[]; };
 
 layout (push_constant) uniform parameter
@@ -26,7 +26,7 @@ layout (push_constant) uniform parameter
     int ncols;
 } p;
 
-shared float tmp[BLOCK_SIZE];
+shared float16_t tmp[BLOCK_SIZE];
 
 void main() {
     const int block_size = int(gl_WorkGroupSize.x);
@@ -35,7 +35,7 @@ void main() {
 
     const int y_offset = QUANT_K/2;
 
-    tmp[tid] = 0;
+    tmp[tid] = 0.0hf;
 
     [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
         const int col = i*block_size + 2*tid;
@@ -44,15 +44,15 @@ void main() {
         const int iybs = col - col%QUANT_K; // y block start index
 
         // dequantize
-        const float d = float(x[ib].d);
+        const float16_t d = x[ib].d;
 
         const uint8_t vui = x[ib].qs[iqs];
 
         const int8_t vi0 = int8_t(vui & 0xF);
         const int8_t vi1 = int8_t(vui >> 4);
 
-        float v0 = (vi0 - 8)*d;
-        float v1 = (vi1 - 8)*d;
+        float16_t v0 = float16_t(vi0 - 8)*d;
+        float16_t v1 = float16_t(vi1 - 8)*d;
 
         // matrix multiplication
         tmp[tid] += v0 * y[iybs + iqs + 0];
@@ -68,6 +68,6 @@ void main() {
         barrier();
     }
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = float(tmp[0]);
    }
 }
vk_shaders/dequant_q4_0.glsl

@@ -17,7 +17,7 @@ struct block_q4_0
 };
 
 layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
-layout (binding = 1) writeonly buffer D { float y[]; };
+layout (binding = 1) writeonly buffer D { float16_t y[]; };
 
 layout (push_constant) uniform parameter
 {
@@ -41,11 +41,11 @@ void main() {
     const int stride_a = p.stride_a / QUANT_K;
 
     const block_q4_0 blk = x[col * stride_a + row];
-    const float d = float(blk.d);
+    const float16_t d = blk.d;
 
     [[unroll]] for (int j = 0; j < QUANT_K/2; ++j) {
-        const int x0 = (blk.qs[j] & 0x0F) - 8;
-        const int x1 = (blk.qs[j] >> 4) - 8;
+        const float16_t x0 = float16_t((blk.qs[j] & 0x0F) - 8);
+        const float16_t x1 = float16_t((blk.qs[j] >> 4) - 8);
 
         y[col * p.stride_b + row*QUANT_K + j + 0        ] = x0*d;
         y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d;
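For reference, the q4_0 dequantization that the two shaders above implement, written as a standalone CPU sketch. The block layout is copied from the GLSL struct; the fp16-to-float conversion of the scale is assumed to come from ggml (ggml_fp16_to_fp32), and the struct name is local to this sketch.

#include <stdint.h>
#include "ggml.h"  // for ggml_fp16_t and ggml_fp16_to_fp32

#define QUANT_K 32

// Mirrors the GLSL block_q4_0: one fp16 scale plus 32 4-bit quants packed two per byte.
typedef struct {
    ggml_fp16_t d;
    uint8_t     qs[QUANT_K / 2];
} block_q4_0_ref;

// Expands one block into 32 floats: low nibbles fill the first half of the output,
// high nibbles the second half, both shifted by -8 and scaled by d.
static void dequantize_q4_0_block(const block_q4_0_ref * blk, float * out) {
    const float d = ggml_fp16_to_fp32(blk->d);
    for (int j = 0; j < QUANT_K / 2; j++) {
        out[j]               = ((blk->qs[j] & 0x0F) - 8) * d;
        out[j + QUANT_K / 2] = ((blk->qs[j] >> 4) - 8) * d;
    }
}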
25 vk_shaders/f32_to_f16.glsl (new file)

@@ -0,0 +1,25 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require

layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { float data_a[]; };
layout (binding = 1) writeonly buffer D { float16_t data_b[]; };

layout (push_constant) uniform parameter
{
    int M;
    int K;
    int stride_a;
    int stride_b;
} p;

void main() {
    const int row = int(gl_GlobalInvocationID.x % p.K);
    const int col = int(gl_GlobalInvocationID.x / p.K);

    if (row < p.K && col < p.M) {
        data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
    }
}
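As a cross-check for the shader above, the same mapping written as a plain CPU reference. This is a sketch: it assumes ggml's ggml_fp32_to_fp16 helper and the same element strides as the push constants.

#include "ggml.h"  // for ggml_fp16_t and ggml_fp32_to_fp16

// Converts a matrix stored as a[col * stride_a + row] (col < M, row < K) to f16,
// mirroring what each invocation of f32_to_f16.glsl does for one element.
static void f32_to_f16_reference(const float * a, ggml_fp16_t * b,
                                 int M, int K, int stride_a, int stride_b) {
    for (int col = 0; col < M; col++) {
        for (int row = 0; row < K; row++) {
            b[col * stride_b + row] = ggml_fp32_to_fp16(a[col * stride_a + row]);
        }
    }
}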
145 vk_shaders/matmul_f16_f32.glsl (new file)

@@ -0,0 +1,145 @@
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { float16_t data_a[]; };
layout (binding = 1) readonly buffer B { float data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float16_t buf_a[BM * (BK+1)];
shared float16_t buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % BK);
    const int loadc = int(gl_LocalInvocationID.x / BK);

    const int loadstride = int(gl_WorkGroupSize.x);

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_a = ir * BM * p.stride_a + start_k;
    int pos_b = ic * BN * p.stride_b + start_k;

    float sums[WMITER * TM * WNITER * TN];
    float16_t cache_a[WMITER * TM];
    float16_t cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
            const int lr = l % BK;
            const int lc = l / BK;
            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
            } else {
                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
            }
        }
        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
            const int lr = l % BK;
            const int lc = l / BK;
            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = float16_t(data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]);
            } else {
                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
            }
        }

        barrier();

        pos_a += BK;
        pos_b += BK;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
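One consequence of the k_split handling above: each k-slice ik writes its partial result to its own M*N region at offset ik * p.M * p.N, so a separate pass has to sum those slices into the final matrix. The following CPU loop illustrates that reduction; it is not the actual Vulkan reduce shader and assumes the partials were written densely (stride_d == M).

// Sums num_k_splits partial M x N results, stored back to back, into the final matrix.
static void split_k_reduce_reference(const float * partials, float * dst,
                                     int M, int N, int num_k_splits) {
    for (int i = 0; i < M * N; i++) {
        float sum = 0.0f;
        for (int ik = 0; ik < num_k_splits; ik++) {
            sum += partials[ik * M * N + i];
        }
        dst[i] = sum;
    }
}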
149 vk_shaders/matmul_f16_f32_aligned.glsl (new file)

@@ -0,0 +1,149 @@
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
layout (binding = 1) readonly buffer B { mat2x4 data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float16_t buf_a[BM * (BK+1)];
shared float16_t buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
    const int loadc = int(gl_LocalInvocationID.x / (BK / 8));

    const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
    int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;

    float sums[WMITER * TM * WNITER * TN];
    float16_t cache_a[WMITER * TM];
    float16_t cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
            f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
        }
        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
            mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = float16_t(tmp[0].x);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = float16_t(tmp[0].y);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = float16_t(tmp[0].z);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = float16_t(tmp[0].w);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = float16_t(tmp[1].x);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = float16_t(tmp[1].y);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = float16_t(tmp[1].z);
            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = float16_t(tmp[1].w);
        }

        barrier();

        pos_a += BK / 8;
        pos_b += BK / 8;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}
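The aligned variant above differs from matmul_f16_f32.glsl mainly in how it fills shared memory: it loads eight elements at a time through f16mat2x4 / mat2x4 and divides the strides by 8. A plausible host-side guard for when this variant can be used (an assumption for illustration; the actual shader selection logic is not part of this diff):

// Hypothetical helper: the aligned shader only works if the k range and both strides
// are multiples of 8 elements, so every vectorized load stays within one row.
static bool ggml_vk_matmul_aligned_usable(int k, int stride_a, int stride_b) {
    return k % 8 == 0 && stride_a % 8 == 0 && stride_b % 8 == 0;
}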
169 vk_shaders/matmul_f32_q4_0.glsl (new file)

@@ -0,0 +1,169 @@
#version 450

#define WARP 32

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

#define QUANT_K 32
#define QUANT_R 2

struct block_q4_0
{
    float16_t d;
    uint8_t qs[16];
};

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { block_q4_0 data_a[]; };
layout (binding = 1) readonly buffer B { vec4 data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
    int stride_a;
    int stride_b;
    int stride_d;
    int k_split;
} p;

layout (constant_id = 1) const int BM = 64;
layout (constant_id = 2) const int BN = 64;
layout (constant_id = 3) const int BK = 16;
layout (constant_id = 4) const int WM = 32;
layout (constant_id = 5) const int WN = 32;
layout (constant_id = 6) const int WMITER = 2;
layout (constant_id = 7) const int TM = 4;
layout (constant_id = 8) const int TN = 2;

shared float buf_a[BM * (BK+1)];
shared float buf_b[BN * (BK+1)];

void main() {
    const int blocks_x = (p.M + BM - 1) / BM;
    const int ir = int(gl_WorkGroupID.x) % blocks_x;
    const int ik = int(gl_WorkGroupID.x) / blocks_x;
    const int ic = int(gl_WorkGroupID.y);

    const int stride_a = p.stride_a / QUANT_K;

    const int warp_i = int(gl_LocalInvocationID.x / WARP);
    const int warp_r = warp_i % (BM / WM);
    const int warp_c = warp_i / (BM / WM);

    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const int WSUBM = WM / WMITER;
    const int WSUBN = WN / WNITER;

    const int tiw = int(gl_LocalInvocationID.x % WARP);
    const int tiwr = tiw % (WSUBM / TM);
    const int tiwc = tiw / (WSUBM / TM);

    const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
    const int loadc = int(gl_LocalInvocationID.x / (BK / 4));

    const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;

    const int start_k = ik * p.k_split;
    const int end_k = (ik + 1) * p.k_split;

    int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;

    float sums[WMITER * TM * WNITER * TN];
    float cache_a[WMITER * TM];
    float cache_b[WNITER * TN];

    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = 0.0f;
    }

    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
            const int row = (block + loadr * 4) / QUANT_K;
            const int qi = (block + loadr * 4) % QUANT_K;
            const block_q4_0 blk = data_a[(ir * BM + loadc + l) * stride_a + row];
            const float d = float(blk.d);

            int x0, x1, x2, x3;
            if (qi < 16) {
                x0 = (blk.qs[qi + 0] & 0x0F) - 8;
                x1 = (blk.qs[qi + 1] & 0x0F) - 8;
                x2 = (blk.qs[qi + 2] & 0x0F) - 8;
                x3 = (blk.qs[qi + 3] & 0x0F) - 8;
            } else {
                x0 = (blk.qs[qi + 0] >> 4) - 8;
                x1 = (blk.qs[qi + 1] >> 4) - 8;
                x2 = (blk.qs[qi + 2] >> 4) - 8;
                x3 = (blk.qs[qi + 3] >> 4) - 8;
            }

            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = x0*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = x1*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = x2*d;
            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = x3*d;
        }
        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
            vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
        }

        barrier();

        pos_b += BK / 4;

        for (int i = 0; i < min(BK, p.K - block); i++) {
            // Load from shared into cache
            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (int j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
                }
            }
            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int j = 0; j < TN; j++) {
                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
                }
            }

            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
                        }
                    }
                }
            }
        }

        barrier();
    }

    const int dr = ir * BM + warp_r * WM;
    const int dc = ic * BN + warp_c * WN;

    const int k_split_offset = ik * p.M * p.N;

    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {

            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (int cc = 0; cc < TN; cc++) {
                [[unroll]] for (int cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
                    }
                }
            }
        }
    }
}