ggml : add q4_1 normalized quants
This commit is contained in:
parent
675425563c
commit
a4d1eb72c6
3 changed files with 43 additions and 29 deletions
20
ggml-cuda.cu
20
ggml-cuda.cu
|
@ -86,15 +86,19 @@ typedef struct {
|
|||
} block_q4_0;
|
||||
static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||
|
||||
#define Q4_1DM (2.0f/15.0f)
|
||||
#define Q4_1MM (2.0f )
|
||||
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
|
||||
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
|
||||
|
||||
#define QK4_1 32
|
||||
#define QR4_1 2
|
||||
#define QI4_1 (QK4_1 / (4 * QR4_1))
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
uint16_t dm; // 8-bit delta + 8-bit min (can be adjusted easily)
|
||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||
} block_q4_1;
|
||||
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||
|
||||
#define QK5_0 32
|
||||
#define QR5_0 2
|
||||
|
@ -386,8 +390,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
|
|||
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
const dfloat m = x[ib].m;
|
||||
const dfloat d = Q4_1D(x[ib].dm);
|
||||
const dfloat m = Q4_1M(x[ib].dm);
|
||||
|
||||
const int vui = x[ib].qs[iqs];
|
||||
|
||||
|
@ -1368,8 +1372,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
|
|||
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
|
||||
|
||||
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
|
||||
const float m = bq4_1->m;
|
||||
const float d = Q4_1D(bq4_1->dm) * __half2float(bq8_1->d);
|
||||
const float m = Q4_1M(bq4_1->dm);
|
||||
const float s = bq8_1->s;
|
||||
|
||||
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue