WIP better data types

This commit is contained in:
JohannesGaessler 2023-06-18 10:37:13 +02:00
parent 28e2762324
commit e6c9b5f567

View file

@ -78,31 +78,6 @@ static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const df
typedef float dfloat; // dequantize float typedef float dfloat; // dequantize float
typedef float2 dfloat2; typedef float2 dfloat2;
static __device__ __forceinline__ void dadd_scalar_inplace(dfloat2 & a, const dfloat b) {
a.x += b;
a.y += b;
}
static __device__ __forceinline__ void dadd_vector_inplace(dfloat2 & a, const dfloat2 b) {
a.x += b.x;
a.y += b.y;
}
static __device__ __forceinline__ void dsub_scalar_inplace(dfloat2 & a, const dfloat b) {
a.x -= b;
a.y -= b;
}
static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) {
a.x *= b;
a.y *= b;
}
static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) {
a.x *= b.x;
a.y *= b.y;
}
#endif //GGML_CUDA_DMMV_F16 #endif //GGML_CUDA_DMMV_F16
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
@ -294,13 +269,16 @@ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int
const dfloat d = x[ib].d; const dfloat d = x[ib].d;
const uint8_t vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
v.x = vui & 0xF; v.x = vui & 0xF;
v.y = vui >> 4; v.y = vui >> 4;
dsub_scalar_inplace(v, 8.0f); v.x -= 8.0f;
dmul_scalar_inplace(v, d); v.y -= 8.0f;
v.x *= d;
v.y *= d;
} }
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -309,13 +287,13 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int
const dfloat d = x[ib].d; const dfloat d = x[ib].d;
const dfloat m = x[ib].m; const dfloat m = x[ib].m;
const uint8_t vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
v.x = vui & 0xF; v.x = vui & 0xF;
v.y = vui >> 4; v.y = vui >> 4;
dmul_scalar_inplace(v, d); v.x = (v.x * d) + m;
dadd_scalar_inplace(v, m); v.y = (v.y * d) + m;
} }
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -326,14 +304,17 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1); v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
dsub_scalar_inplace(v, 16.0f); v.x -= 16.0f;
dmul_scalar_inplace(v, d); v.y -= 16.0f;
v.x *= d;
v.y *= d;
} }
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -345,14 +326,17 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1); v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
dmul_scalar_inplace(v, d); v.x *= d;
dadd_scalar_inplace(v, m); v.y *= d;
v.x += m;
v.y += m;
} }
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@ -363,7 +347,8 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
v.x = x[ib].qs[iqs + 0]; v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1]; v.y = x[ib].qs[iqs + 1];
dmul_scalar_inplace(v, d); v.x *= d;
v.y *= d;
} }
//================================== k-quants //================================== k-quants
@ -942,7 +927,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
dfloat2 tmp = {0.0f, 0.0f}; // partial sum for thread in warp dfloat tmp = 0.0f; // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) { for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid; const int col = i + vals_per_iter*tid;
@ -961,11 +946,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
// matrix multiplication // matrix multiplication
const dfloat2 yi = {y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset]}; tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
dmul_vector_inplace(v, yi);
dadd_vector_inplace(tmp, v);
} }
} }
@ -973,16 +956,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) { for (int mask = 16; mask > 0; mask >>= 1) {
#ifdef GGML_CUDA_DMMV_F16
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#else
tmp.x += __shfl_xor_sync(0xffffffff, tmp.x, mask, 32);
tmp.y += __shfl_xor_sync(0xffffffff, tmp.y, mask, 32);
#endif // GGML_CUDA_DMMV_F16
} }
if (tid == 0) { if (tid == 0) {
dst[row] = tmp.x + tmp.y; dst[row] = tmp;
} }
} }