ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

This commit is contained in:
Johannes Gäßler 2024-10-03 17:29:59 +02:00 committed by Georgi Gerganov
parent eee39bdc96
commit fabdc3bda3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 389 additions and 8 deletions

View file

@ -175,6 +175,18 @@ static __device__ void no_device_code(
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {