Q3_K now working on CUDA and AVX2/scalar
The CUDA implementation is not ideal: ~50% slower than Q4_0 for single-token prediction, and about the same in batch mode (perplexity). On the CPU a single token takes ~55 ms (Ryzen 7950X).
This commit is contained in:
parent b4f71347ff
commit c93cce3a45
8 changed files with 381 additions and 6 deletions

Makefile | 4
@@ -229,7 +229,7 @@ clean:
 # Examples
 #
 
-main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+main: examples/main/main.cpp build-info.h ggml.o k_quants.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
@@ -269,7 +269,7 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 	./$@
 
-vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
+vdot: pocs/vdot/vdot.cpp ggml.o k_quants.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
 .PHONY: tests clean
@@ -282,8 +282,9 @@ int main(int argc, char ** argv) {
                 break;
             }
             int j;
-            for (j = 0; j < GGML_TYPE_COUNT && strcmp(argv[i], ggml_type_name((ggml_type) j)) != 0; j++) {
-                // find match
+            for (j = 0; j < GGML_TYPE_COUNT; ++j) {
+                const auto * name = ggml_type_name((ggml_type) j);
+                if (name && strcmp(argv[i], name) == 0) break;
             }
             if (j < GGML_TYPE_COUNT) {
                 params.include_types.push_back((ggml_type) j);
@@ -12,6 +12,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
     {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
     {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
     {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
+    {"q3_K", LLAMA_FTYPE_MOSTLY_Q3_K},
 };
 
 bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
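
With this map entry in place, the quantize example should accept q3_K as the target type string on the command line, e.g. something like the following (model paths are illustrative):

    ./quantize ./models/7B/ggml-model-f16.bin ./models/7B/ggml-model-q3_K.bin q3_K
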

ggml-cuda.cu | 142
@@ -3,6 +3,7 @@
 #include <stdint.h>
 #include <stdio.h>
 #include <atomic>
+#include <assert.h>
 
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
@@ -35,6 +36,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -83,6 +85,18 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+//================================= k-quants
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t hmask[QK_K/8];
+    uint8_t qs[QK_K/4]; // nibbles / quants
+    uint8_t scales[3*QK_K/64];
+    half d;
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+
 #define WARP_SIZE 32
 
 #define CUDA_MUL_BLOCK_SIZE 256
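
For reference, the footprint this struct implies: each super-block of 256 weights stores 2 low bits per weight (qs), 1 high bit per weight (hmask), 16 packed 6-bit sub-block scales (scales), and one fp16 super-block scale (d). A standalone sketch (not part of the commit) spelling out the arithmetic:

    /* back-of-the-envelope check of the Q3_K block size (standalone sketch, not commit code) */
    #include <stdio.h>

    int main(void) {
        const int QK_K   = 256;
        const int hmask  = QK_K/8;     /* 32 bytes: one high bit per weight           */
        const int qs     = QK_K/4;     /* 64 bytes: two low bits per weight           */
        const int scales = 3*QK_K/64;  /* 12 bytes: 16 packed 6-bit sub-block scales  */
        const int d      = 2;          /* fp16 super-block scale                      */
        const int total  = hmask + qs + scales + d;   /* 110 bytes, matching the static_assert */
        printf("%d bytes per %d weights = %.4f bits per weight\n", total, QK_K, 8.0*total/QK_K);
        return 0;
    }

which works out to 3.4375 bits per weight.
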
@@ -184,6 +198,80 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
     v1 = vi1*d;
 }
 
+//================================== k-quants
+
+static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    int i = blockIdx.x;
+    int tid = threadIdx.x;
+    int n = tid / 4;
+    int j = tid - 4*n;
+
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    float * y = yy + i*QK_K + 128*n + 32*j;
+
+    float d_all = x[i].d;
+
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    uint32_t aux[4];
+    const int8_t * scales = (const int8_t*)aux;
+
+    memcpy(aux, x[i].scales, 12);
+    uint32_t tmp = aux[2];
+    aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+    aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+    aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+    aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j;
+    float dl;
+    int shift = 2*j;
+
+    dl = d_all * (scales[is++] - 32);
+    for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+
+    dl = d_all * (scales[is++] - 32);
+    for (int l = 16; l < 32; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+
+}
+
+static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * y, float & result) {
+
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    int n = iqs / 128;
+    int iqsn = iqs - 128*n;
+    int j = iqsn / 32;
+    int l = iqsn - 32*j;
+    int shift = 2*j;
+    int is = 8*n + 2*j + l/16;
+    uint8_t m = 1 << (4*n + j);
+
+    const float d = x[ib].d;
+    const uint8_t * q = x[ib].qs + 32*n;
+    const uint8_t * hm = x[ib].hmask;
+
+    int8_t us = is <  4 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[ib].scales[is-0] & 0xF) | (((x[ib].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[ib].scales[is-8] >>  4) | (((x[ib].scales[is+0] >> 4) & 3) << 4) :
+                          (x[ib].scales[is-8] >>  4) | (((x[ib].scales[is-4] >> 6) & 3) << 4);
+    float scale = d * (us - 32);
+
+    float sum = 0;
+    for (int k = 0; k < 8; ++k) {
+        int8_t ql = (q[l + k] >> shift) & 3;
+        sum += y[iqs + k] * (ql - ((hm[l + k] & m) ? 0 : 4));
+    }
+    result = sum * scale;
+}
+
 static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const half * x = (const half *) vx;
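
To spell out the index math used by dequantize_block_q3_K and vec_dot_q3_K above, here is a plain-C scalar sketch of the same dequantization (illustration only, not part of the commit; the helper name and its signature are made up). It unpacks the 16 6-bit sub-block scales the same way the kmask1/kmask2 shuffling does, then reconstructs each weight as its low 2 bits minus 4 when the corresponding high bit is clear:

    #include <stdint.h>

    #define QK_K 256

    /* hypothetical scalar mirror of dequantize_block_q3_K; the arguments follow the
       block_q3_K fields, with the fp16 super-block scale already converted to float */
    static void dequantize_q3_K_sketch(const uint8_t hmask[QK_K/8], const uint8_t qs[QK_K/4],
                                       const uint8_t scales[12], float d_all, float y[QK_K]) {
        /* unpack the 16 6-bit sub-block scales from the 12 packed bytes
           (equivalent to the kmask1/kmask2 manipulation in the kernel) */
        int8_t sc[16];
        for (int i = 0; i < 16; ++i) {
            const int lo = i < 8 ? (scales[i] & 0xF) : (scales[i-8] >> 4);
            const int hi = (scales[8 + (i % 4)] >> (2 * (i / 4))) & 3;
            sc[i] = (int8_t)(lo | (hi << 4));          /* 0..63, centered by the -32 below */
        }
        for (int n = 0; n < 2; ++n) {                  /* two halves of 128 weights        */
            for (int j = 0; j < 4; ++j) {              /* four 2-bit planes per half       */
                const uint8_t m = 1 << (4*n + j);      /* high bit belonging to this plane */
                for (int l = 0; l < 32; ++l) {
                    const float dl = d_all * (sc[8*n + 2*j + l/16] - 32);
                    const int   q  = (qs[32*n + l] >> (2*j)) & 3;
                    y[128*n + 32*j + l] = dl * (q - ((hmask[l] & m) ? 0 : 4));
                }
            }
        }
    }
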
@@ -258,6 +346,43 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
+template <dot_kernel_k_t dot_kernel>
+static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const int iter_stride = QK_K;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/QK_K; // x block index
+        const int iqs = col%QK_K; // x quant index
+        const int iybs = col - col%QK_K; // y block start index
+
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 8) {
+            float v;
+            dot_kernel(vx, ib, iqs + j, y + iybs, v);
+            tmp += v;
+        }
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
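
The kernel above runs one 32-thread warp per output row: each thread accumulates vals_per_iter = QK_K/WARP_SIZE = 8 values per iteration into a private partial sum, and the partial sums are then combined with an XOR-butterfly shuffle. A small sketch (illustration only, not part of the commit) of that butterfly pattern on a plain array, emulating what __shfl_xor_sync does across the 32 lanes:

    #include <string.h>

    /* emulate the warp-wide XOR-butterfly reduction on 32 per-lane partial sums;
       after log2(32) = 5 steps every lane holds the full sum, and thread 0 writes dst[row] */
    static float warp_reduce_sketch(float lane[32]) {
        for (int mask = 16; mask > 0; mask >>= 1) {
            float prev[32];
            memcpy(prev, lane, sizeof(prev));          /* all lanes read before any lane writes */
            for (int l = 0; l < 32; ++l) {
                lane[l] = prev[l] + prev[l ^ mask];    /* add the partner lane's value          */
            }
        }
        return lane[0];
    }
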
@@ -288,6 +413,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
+static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 8, 0, stream>>>(vx, y);
+}
+
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
@@ -328,6 +458,12 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat_vec_k<vec_dot_q3_K><<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -353,6 +489,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q5_1_cuda;
         case GGML_TYPE_Q8_0:
             return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
         case GGML_TYPE_F16:
             return convert_fp16_to_fp32_cuda;
         default:
@@ -372,6 +510,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
             return dequantize_mul_mat_vec_q5_1_cuda;
         case GGML_TYPE_Q8_0:
             return dequantize_mul_mat_vec_q8_0_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_mul_mat_vec_q3_K_cuda;
         case GGML_TYPE_F16:
            return convert_mul_mat_vec_f16_cuda;
         default:
@@ -790,12 +930,14 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
             // compute
+            //printf("Calling dmmv\n");
             dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
             CUDA_CHECK(cudaGetLastError());
 
         } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
             float * c_X = d_X + i * x_ne;
 
+            //typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
             // convert src0 to fp32 on device
             to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
             CUDA_CHECK(cudaGetLastError());

ggml.c | 16
@@ -1570,8 +1570,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K,
         .quantize_row_q = quantize_row_q3_K,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
-        .quantize_row_q_dot = NULL, //quantize_row_q8_K,
-        .vec_dot_q = NULL, //ggml_vec_dot_q3_K_q8_K,
+        .quantize_row_q_dot = quantize_row_q8_K,
+        .vec_dot_q = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
     },
 };
@@ -7602,6 +7602,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q3_K:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7905,6 +7906,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q3_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8027,6 +8029,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q3_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10120,6 +10123,7 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q3_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10303,6 +10307,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q3_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10468,6 +10473,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q3_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -16092,6 +16098,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
+        case GGML_TYPE_Q3_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q3_K * block = (block_q3_K*)dst + start / QK_K;
+                result = ggml_quantize_q3_K(src + start, block, n, n, hist);
+            } break;
         default:
             assert(false);
     }
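
With this case in place, Q3_K data goes through the same ggml_quantize_chunk entry point as the other quantized types. A hedged usage sketch (the buffer names are illustrative, and the 16-entry histogram mirrors what the callers elsewhere in the code base pass; neither is taken from this commit):

    #include <stdint.h>
    #include "ggml.h"

    /* quantize n float values (n a multiple of QK_K = 256) into a Q3_K buffer;
       hist receives the per-bucket histogram, as for the other quantized types */
    size_t quantize_to_q3_K_sketch(const float * f32_data, void * q3_buf, int n) {
        int64_t hist[16] = {0};
        return ggml_quantize_chunk(GGML_TYPE_Q3_K, f32_data, q3_buf, /*start =*/ 0, n, hist);
    }
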

k_quants.c | 214
@@ -5,6 +5,8 @@
 #include <string.h>
 #include <assert.h>
 
+#include <immintrin.h>
+
 #undef MIN
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -456,3 +458,215 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
     quantize_row_q8_K_reference(x, y, k);
 }
 
+//===================================== Dot ptoducts =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// shuffle to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,  2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,  6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+#endif
+
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef z__ARM_NEON
+    // TODO
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+            // and 2 if the high bit was set)
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+    // scalar version
+    // This function is written like this so the compiler can manage to vectorize most of it
+    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+    // The ideal situation would be if we could just write the code once, and the compiler would
+    // automatically produce the best possible set of machine instructions, instead of us having to manually
+    // write vectorized versions for AVX, ARM_NEON, etc.
+
+    int8_t aux8[QK_K];
+    int16_t aux16[8];
+    float sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint32_t auxs[4];
+    const int8_t * scales = (const int8_t*)auxs;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            q3 += 32;
+        }
+        a = aux8;
+
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
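
A scalar restatement (illustration only, not part of the commit) of the identity the AVX2 path above relies on: a Q3_K weight is its low 2 bits minus 4 when its high bit is clear, so the dot product against Q8 splits into two unsigned-by-signed products that map onto _mm256_maddubs_epi16, followed by a subtraction:

    #include <stdint.h>

    /* q3l[k] in [0,3] are the low 2 bits; q3h[k] is 4 where the high bit is clear, else 0 */
    static int dot16_q3_sketch(const uint8_t * q3l, const uint8_t * q3h, const int8_t * q8) {
        int p16 = 0, q8s = 0;
        for (int k = 0; k < 16; ++k) {
            p16 += q3l[k] * q8[k];   /* what maddubs(q3l, q8) accumulates (pairwise, into 16-bit lanes) */
            q8s += q3h[k] * q8[k];   /* what maddubs(q3h, q8) accumulates                               */
        }
        return p16 - q8s;            /* == sum over k of q8[k] * (q3l[k] - q3h[k])                      */
    }
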
@@ -507,6 +507,7 @@ struct llama_file_loader {
             case GGML_TYPE_Q5_0:
             case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
+            case GGML_TYPE_Q3_K:
                 break;
             default: {
                 throw format("unrecognized tensor type %u\n", shard.type);
@@ -582,6 +583,7 @@ struct llama_file_saver {
             case GGML_TYPE_Q5_0:
             case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
+            case GGML_TYPE_Q3_K:
                 break;
             default: LLAMA_ASSERT(false);
         }
@@ -898,6 +900,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0";
         case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0";
+        case LLAMA_FTYPE_MOSTLY_Q3_K: return "mostly Q3_K";
         default: return "unknown, may not work";
     }
 }
@@ -2063,6 +2066,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
         case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q3_K: quantized_type = GGML_TYPE_Q3_K; break;
         default: throw format("invalid output file type %d\n", ftype);
     };
 

llama.h | 1
@@ -94,6 +94,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q3_K = 10,// except 1d tensors
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();