Output FP32 in fp16 matmul shader
This commit is contained in:
parent
40c8f843f2
commit
df3cdbdac7
2 changed files with 8 additions and 8 deletions
|
@@ -318,7 +318,9 @@ void ggml_vk_init(void) {
|
|||
|
||||
// Shaders
|
||||
vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {128, 128, 1});
|
||||
if (vk_fp16_support) {
|
||||
vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {128, 128, 1});
|
||||
}
|
||||
|
||||
vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
|
||||
|
||||
|
@@ -816,7 +818,7 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &d_X, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
|
||||
}
|
||||
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * y_ne, &d_Y, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
|
||||
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * d_ne, &d_D, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
|
||||
ggml_vk_pool_malloc(sizeof(float) * d_ne, &d_D, VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT);
|
||||
|
||||
bool src1_cont_rows = nb10 == sizeof(float);
|
||||
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
|
||||
|
@@ -873,10 +875,8 @@ static void ggml_vk_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
vk_device.destroyFence(fence);
|
||||
|
||||
// copy dst to host
|
||||
ggml_vk_buffer_read(&d_D, 0, tmp, sizeof(ggml_fp16_t) * d_ne);
|
||||
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
ggml_fp16_to_fp32_row(tmp, d, d_ne);
|
||||
ggml_vk_buffer_read(&d_D, 0, d, sizeof(float) * d_ne);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -13,7 +13,7 @@ layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1)
|
|||
|
||||
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||
layout (binding = 1) readonly buffer B { float16_t data_b[]; };
|
||||
layout (binding = 2) writeonly buffer D { float16_t data_d[]; };
|
||||
layout (binding = 2) writeonly buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
|
@@ -45,7 +45,7 @@ void main() {
|
|||
int pos_a = ir * BM * p.stride_a;
|
||||
int pos_b = ic * BN * p.stride_b;
|
||||
|
||||
float16_t sums[TM * TN];
|
||||
float sums[TM * TN];
|
||||
float16_t cache_a[TM];
|
||||
float16_t cache_b[TN];
|
||||
|
||||
|
@@ -81,7 +81,7 @@ void main() {
|
|||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[cc * TM + cr] += cache_a[cr] * cache_b[cc];
|
||||
sums[cc * TM + cr] += float(cache_a[cr]) * float(cache_b[cc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue