CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (#7921)

* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores

* try CI fix

* try CI fix

* try CI fix

* fix data race

* rever q2_K precision related changes
This commit is contained in:
Johannes Gäßler 2024-06-14 18:41:49 +02:00 committed by GitHub
parent 66ef1ceedf
commit 76d66ee0be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 468 additions and 330 deletions

View file

@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;