cuda : optimize argmax (#10441)

* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Diego Devesa 2024-11-21 18:18:50 +01:00 committed by GitHub
parent 1bb30bf28c
commit a5e47592b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 110 additions and 67 deletions

View file

@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
}
float sum;
@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
// Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
}
}