From bff3a20944a851fe443ed8cc326e19861f6d0cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 14 Jun 2024 16:16:44 +0200 Subject: [PATCH] fix data race --- ggml-cuda/mmq.cuh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index c9019242b..774083249 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -850,16 +850,18 @@ template static __device__ __forceinlin const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx); - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = 0; - #pragma unroll - for (int l = 0; l < QR3_K; ++l) { + for (int l = 0; l < QR2_K; ++l) { const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4; int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4)); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE); + if (kqsx % QR2_K != 0) { + continue; + } + x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; } @@ -1011,6 +1013,10 @@ template static __device__ __forceinlin int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2)); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); + if (kqsx % 2 != 0) { + continue; + } + x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; } }