dfloat2
This commit is contained in:
parent
98b8596b2a
commit
8ac993bd0a
1 changed files with 62 additions and 43 deletions
105
ggml-cuda.cu
105
ggml-cuda.cu
|
@ -50,7 +50,25 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
} while (0)
|
} while (0)
|
||||||
#endif // CUDART_VERSION >= 11
|
#endif // CUDART_VERSION >= 11
|
||||||
|
|
||||||
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
|
typedef float dfloat; // dequantize float
|
||||||
|
typedef float2 dfloat2;
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
|
||||||
|
a.x += b;
|
||||||
|
a.y += b;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
|
||||||
|
a.x -= b;
|
||||||
|
a.y -= b;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
|
||||||
|
a.x *= b;
|
||||||
|
a.y *= b;
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
||||||
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||||
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
|
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
|
||||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||||
|
@ -234,39 +252,39 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
const uint8_t vui = x[ib].qs[iqs];
|
const uint8_t vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
const int8_t vi0 = vui & 0xF;
|
v.x = vui & 0xF;
|
||||||
const int8_t vi1 = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
v0 = (vi0 - 8)*d;
|
dsub_inplace(v, 8.0f);
|
||||||
v1 = (vi1 - 8)*d;
|
dmul_inplace(v, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
const float m = x[ib].m;
|
const dfloat m = x[ib].m;
|
||||||
|
|
||||||
const uint8_t vui = x[ib].qs[iqs];
|
const uint8_t vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
const int8_t vi0 = vui & 0xF;
|
v.x = vui & 0xF;
|
||||||
const int8_t vi1 = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
v0 = vi0*d + m;
|
dmul_inplace(v, d);
|
||||||
v1 = vi1*d + m;
|
dadd_inplace(v, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q5_0 * x = (const block_q5_0 *) vx;
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
uint32_t qh;
|
uint32_t qh;
|
||||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
@ -274,18 +292,18 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
|
||||||
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
|
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
v0 = x0*d;
|
dsub_inplace(v, 16.0f);
|
||||||
v1 = x1*d;
|
dmul_inplace(v, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
const float m = x[ib].m;
|
const dfloat m = x[ib].m;
|
||||||
|
|
||||||
uint32_t qh;
|
uint32_t qh;
|
||||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
@ -293,23 +311,22 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int
|
||||||
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
v0 = x0*d + m;
|
dmul_inplace(v, d);
|
||||||
v1 = x1*d + m;
|
dadd_inplace(v, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
const int8_t vi0 = x[ib].qs[iqs + 0];
|
v.x = x[ib].qs[iqs + 0];
|
||||||
const int8_t vi1 = x[ib].qs[iqs + 1];
|
v.y = x[ib].qs[iqs + 1];
|
||||||
|
|
||||||
v0 = vi0*d;
|
dmul_inplace(v, d);
|
||||||
v1 = vi1*d;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//================================== k-quants
|
//================================== k-quants
|
||||||
|
@ -843,11 +860,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const half * x = (const half *) vx;
|
const half * x = (const half *) vx;
|
||||||
|
|
||||||
v0 = __half2float(x[ib + iqs + 0]);
|
v.x = __half2float(x[ib + iqs + 0]);
|
||||||
v1 = __half2float(x[ib + iqs + 1]);
|
v.y = __half2float(x[ib + iqs + 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
|
@ -864,9 +881,11 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float & v0 = y[iybs + iqs + 0];
|
dfloat2 v;
|
||||||
float & v1 = y[iybs + iqs + y_offset];
|
dequantize_kernel(vx, ib, iqs, v);
|
||||||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
|
||||||
|
y[iybs + iqs + 0] = v.x;
|
||||||
|
y[iybs + iqs + y_offset] = v.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
|
@ -899,13 +918,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
|
||||||
// process 2 vals per j iter
|
// process 2 vals per j iter
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float v0, v1;
|
dfloat2 v;
|
||||||
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
|
dequantize_kernel(vx, ib, iqs + j/qr, v);
|
||||||
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
tmp += v0 * __half2float(y[iybs + iqs + j/qr + 0]);
|
tmp += v.x * __half2float(y[iybs + iqs + j/qr + 0]);
|
||||||
tmp += v1 * __half2float(y[iybs + iqs + j/qr + y_offset]);
|
tmp += v.y * __half2float(y[iybs + iqs + j/qr + y_offset]);
|
||||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue