diff --git a/Makefile b/Makefile index 9dd8a51c4..dbf9d52b7 100644 --- a/Makefile +++ b/Makefile @@ -171,11 +171,11 @@ ifdef LLAMA_CUDA_DMMV_Y else NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 endif # LLAMA_CUDA_DMMV_Y -#ifdef LLAMA_CUDA_KQUANTS_ITER +ifdef LLAMA_CUDA_KQUANTS_ITER NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) -#else +else NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 -#endif +endif ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 46f50e492..91c3a3b70 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -169,6 +169,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #ifndef K_QUANTS_PER_ITERATION #define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {