ggml : poc for normalizing weights for better quantization (metal)
This commit is contained in:
parent
b532a69b2f
commit
253eab8ae1
5 changed files with 295 additions and 139 deletions
53
ggml-cuda.cu
53
ggml-cuda.cu
|
@ -204,23 +204,31 @@ typedef void (*ggml_cuda_op_t)(
|
|||
// QR = QK / number of values before dequantization
|
||||
// QI = number of 32 bit integers before dequantization
|
||||
|
||||
#define Q4_0DM (1.0f/8.0f)
|
||||
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
#define QI4_0 (QK4_0 / (4 * QR4_0))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
int8_t d; // delta
|
||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||
} block_q4_0;
|
||||
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define Q4_1DM (2.0f/15.0f)
|
||||
#define Q4_1MM (2.0f )
|
||||
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
|
||||
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
|
||||
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QI4_1 (QK4_1 / (4 * QR4_1))
|
||||
typedef struct {
|
||||
half2 dm; // dm.x = delta, dm.y = min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
uint16_t dm; // 8-bit delta + 8-bit min (can be adjusted easily)
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
|
@ -232,15 +240,20 @@ typedef struct {
|
|||
} block_q5_0;
|
||||
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
||||
|
||||
#define Q5_1DM (2.0f/31.0f)
|
||||
#define Q5_1MM (2.0f )
|
||||
#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
|
||||
#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
|
||||
|
||||
#define QK5_1 32
|
||||
#define QR5_1 2
|
||||
#define QI5_1 (QK5_1 / (4 * QR5_1))
|
||||
typedef struct {
|
||||
half2 dm; // dm.x = delta, dm.y = min
|
||||
uint8_t dm; // 4-bit delta + 4-bit min
|
||||
uint8_t qh[4]; // 5-th bit of quants
|
||||
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||
} block_q5_1;
|
||||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
|
@ -506,7 +519,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
|||
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
const dfloat d = Q4_0D(x[ib].d);
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
|
@ -525,8 +538,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
|
|||
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||
|
||||
const dfloat d = __low2half(x[ib].dm);
|
||||
const dfloat m = __high2half(x[ib].dm);
|
||||
const dfloat d = Q4_1D(x[ib].dm);
|
||||
const dfloat m = Q4_1M(x[ib].dm);
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
|
@ -568,8 +581,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
|
|||
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const dfloat d = __low2half(x[ib].dm);
|
||||
const dfloat m = __high2half(x[ib].dm);
|
||||
const dfloat d = Q5_1D(x[ib].dm);
|
||||
const dfloat m = Q5_1M(x[ib].dm);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||
|
@ -2041,7 +2054,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
|
|||
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
|
||||
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, Q4_0D(bq4_0->d), bq8_1->ds);
|
||||
}
|
||||
|
||||
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
@ -2135,7 +2148,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
|
|||
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
|
||||
}
|
||||
|
||||
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
|
||||
const float d = Q4_1D(bq4_1->dm);
|
||||
const float m = Q4_1M(bq4_1->dm);
|
||||
|
||||
const float2 dm = {d, m};
|
||||
|
||||
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, dm, bq8_1->ds);
|
||||
}
|
||||
|
||||
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
@ -2341,7 +2359,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
|
|||
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
|
||||
}
|
||||
|
||||
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
|
||||
const float d = Q5_1D(bq4_1->dm);
|
||||
const float m = Q5_1M(bq4_1->dm);
|
||||
|
||||
const float2 dm = {d, m};
|
||||
|
||||
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, dm, bq8_1->ds);
|
||||
}
|
||||
|
||||
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue