Output FP32 in fp16 matmul shader

0cc4m 2023-06-29 20:15:39 +02:00
parent 40c8f843f2
commit df3cdbdac7
2 changed files with 8 additions and 8 deletions

@@ -318,7 +318,9 @@ void ggml_vk_init(void) {
     // Shaders
     vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {128, 128, 1});
+    if (vk_fp16_support) {
     vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {128, 128, 1});
+    }
     vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
@@ -816,7 +818,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
         ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
     }
     ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
-    ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * d_ne, &d_D, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
+    ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
@@ -873,10 +875,8 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
             vk_device.destroyFence(fence);
             // copy dst to host
-            ggml_vk_buffer_read(&d_D, 0, tmp, sizeof(ggml_fp16_t) * d_ne);
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            ggml_fp16_to_fp32_row(tmp, d, d_ne);
+            ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
         }
     }

@@ -13,7 +13,7 @@ layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1)
 layout (binding = 0) readonly buffer A { float16_t data_a[]; };
 layout (binding = 1) readonly buffer B { float16_t data_b[]; };
-layout (binding = 2) writeonly buffer D { float16_t data_d[]; };
+layout (binding = 2) writeonly buffer D { float data_d[]; };
 layout (push_constant) uniform parameter
 {
@@ -45,7 +45,7 @@ void main() {
     int pos_a = ir * BM * p.stride_a;
     int pos_b = ic * BN * p.stride_b;
-    float16_t sums[TM * TN];
+    float sums[TM * TN];
     float16_t cache_a[TM];
     float16_t cache_b[TN];
@@ -81,7 +81,7 @@ void main() {
         [[unroll]] for (int cc = 0; cc < TN; cc++) {
             [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                sums[cc * TM + cr] += cache_a[cr] * cache_b[cc];
+                sums[cc * TM + cr] += float(cache_a[cr]) * float(cache_b[cc]);
             }
         }
     }
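To make the net effect of the commit easier to see, here is a minimal sketch of the same pattern in a standalone compute shader: FP16 input buffers, an FP32 accumulator, and an FP32 output buffer that the host can read back directly with ggml_vk_buffer_read, without the old ggml_fp16_to_fp32_row conversion pass. This is not the repository's matmul_f16 shader: the naive per-element loop, the M/N/K push constants, and the assumed row-major layout are simplifications for illustration only.

#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// FP16 inputs, as in the hunks above
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
// Output is FP32, so the host reads it back without a conversion pass
layout (binding = 2) writeonly buffer D { float data_d[]; };

// Hypothetical push constants for this sketch only
layout (push_constant) uniform parameter
{
    int M;
    int N;
    int K;
} p;

void main() {
    const int r = int(gl_GlobalInvocationID.x); // row of D
    const int c = int(gl_GlobalInvocationID.y); // column of D
    if (r >= p.M || c >= p.N) {
        return;
    }

    // Accumulate in FP32: each FP16 product is widened before the add,
    // mirroring float(cache_a[cr]) * float(cache_b[cc]) in the diff above.
    float sum = 0.0;
    for (int k = 0; k < p.K; k++) {
        sum += float(data_a[r * p.K + k]) * float(data_b[c * p.K + k]);
    }

    // A is assumed M x K row-major, B is N x K (i.e. B transposed),
    // and D is written M x N row-major.
    data_d[r * p.N + c] = sum;
}

Widening each product before the add keeps the FP16 bandwidth savings on the inputs while avoiding FP16 accumulation error and the extra host-side conversion step.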