compile option

This commit is contained in:
JohannesGaessler 2023-06-18 09:23:37 +02:00
parent 8ac993bd0a
commit cde81f91ca
2 changed files with 41 additions and 9 deletions

View file

@ -169,6 +169,9 @@ ifdef LLAMA_CUDA_DMMV_Y
else
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
endif # LLAMA_CUDA_DMMV_Y
ifdef LLAMA_CUDA_DMMV_F16
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
endif # LLAMA_CUDA_DMMV_F16
ifdef LLAMA_CUDA_KQUANTS_ITER
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
else

View file

@ -50,6 +50,28 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} while (0)
#endif // CUDART_VERSION >= 11
#ifdef GGML_CUDA_DMMV_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
a = __hadd2(a, {b, b});
}
static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
a = __hsub2(a, {b, b});
}
static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) {
a = __hmul2(a, {b, b});
}
static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) {
a = __hmul2(a, b);
}
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
@ -63,11 +85,17 @@ static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b)
a.y -= b;
}
static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) {
a.x *= b;
a.y *= b;
}
static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) {
a.x *= b.x;
a.y *= b.y;
}
#endif //GGML_CUDA_DMMV_F16
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
@ -263,7 +291,7 @@ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int
v.y = vui >> 4;
dsub_inplace(v, 8.0f);
dmul_inplace(v, d);
dmul_scalar_inplace(v, d);
}
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -277,7 +305,7 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int
v.x = vui & 0xF;
v.y = vui >> 4;
dmul_inplace(v, d);
dmul_scalar_inplace(v, d);
dadd_inplace(v, m);
}
@ -296,7 +324,7 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
dsub_inplace(v, 16.0f);
dmul_inplace(v, d);
dmul_scalar_inplace(v, d);
}
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -314,7 +342,7 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
dmul_inplace(v, d);
dmul_scalar_inplace(v, d);
dadd_inplace(v, m);
}
@ -326,7 +354,7 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
dmul_inplace(v, d);
dmul_scalar_inplace(v, d);
}
//================================== k-quants
@ -904,7 +932,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;
float tmp = 0.0f; // partial sum for thread in warp
dfloat tmp = 0.0f; // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
@ -923,8 +951,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
// matrix multiplication
tmp += v.x * __half2float(y[iybs + iqs + j/qr + 0]);
tmp += v.y * __half2float(y[iybs + iqs + j/qr + y_offset]);
dfloat2 yi = {y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset]};
tmp += v.x * yi.x;
tmp += v.y * yi.y;
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
}
}