diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7d08239b1..0a7f1d956 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -50,7 +50,25 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } while (0) #endif // CUDART_VERSION >= 11 -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); +typedef float dfloat; // dequantize float +typedef float2 dfloat2; + +static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) { + a.x += b; + a.y += b; +} + +static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) { + a.x -= b; + a.y -= b; +} + +static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) { + a.x *= b; + a.y *= b; +} + +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -234,39 +252,39 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } -static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; const uint8_t vui = x[ib].qs[iqs]; - const int8_t vi0 = vui & 0xF; - const int8_t vi1 = vui >> 4; + v.x = vui & 0xF; + v.y = vui >> 4; - v0 = (vi0 - 8)*d; - v1 = (vi1 - 8)*d; + dsub_inplace(v, 8.0f); + dmul_inplace(v, d); } -static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const float d = x[ib].d; - const float m = x[ib].m; + const dfloat d = x[ib].d; + const dfloat m = x[ib].m; const uint8_t vui = x[ib].qs[iqs]; - const int8_t vi0 = vui & 0xF; - const int8_t vi1 = vui >> 4; + v.x = vui & 0xF; + v.y = vui >> 4; - v0 = vi0*d + m; - v1 = vi1*d + m; + dmul_inplace(v, d); + dadd_inplace(v, m); } -static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -274,18 +292,18 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; - const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v0 = x0*d; - v1 = x1*d; + dsub_inplace(v, 16.0f); + dmul_inplace(v, d); } -static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const float d = x[ib].d; - const float m = x[ib].m; + const dfloat d = x[ib].d; + const dfloat m = x[ib].m; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -293,23 +311,22 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); - const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v0 = x0*d + m; - v1 = x1*d + m; + dmul_inplace(v, d); + dadd_inplace(v, m); } -static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; - const int8_t vi0 = x[ib].qs[iqs + 0]; - const int8_t vi1 = x[ib].qs[iqs + 1]; + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; - v0 = vi0*d; - v1 = vi1*d; + dmul_inplace(v, d); } //================================== k-quants @@ -843,11 +860,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; - v0 = __half2float(x[ib + iqs + 0]); - v1 = __half2float(x[ib + iqs + 1]); + v.x = __half2float(x[ib + iqs + 0]); + v.y = __half2float(x[ib + iqs + 1]); } template @@ -864,9 +881,11 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) const int y_offset = qr == 1 ? 1 : qk/2; // dequantize - float & v0 = y[iybs + iqs + 0]; - float & v1 = y[iybs + iqs + y_offset]; - dequantize_kernel(vx, ib, iqs, v0, v1); + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; } template @@ -899,13 +918,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f // process 2 vals per j iter // dequantize - float v0, v1; - dequantize_kernel(vx, ib, iqs + j/qr, v0, v1); + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val // matrix multiplication - tmp += v0 * __half2float(y[iybs + iqs + j/qr + 0]); - tmp += v1 * __half2float(y[iybs + iqs + j/qr + y_offset]); + tmp += v.x * __half2float(y[iybs + iqs + j/qr + 0]); + tmp += v.y * __half2float(y[iybs + iqs + j/qr + y_offset]); // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 } }