fix data race

This commit is contained in:
Johannes Gäßler 2024-06-14 16:16:44 +02:00
parent 80ba2aef4a
commit bff3a20944

View file

@ -850,16 +850,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = 0;
#pragma unroll
for (int l = 0; l < QR3_K; ++l) {
for (int l = 0; l < QR2_K; ++l) {
const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
if (kqsx % QR2_K != 0) {
continue;
}
x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
}
@ -1011,6 +1013,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
if (kqsx % 2 != 0) {
continue;
}
x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
}
}