iq3_xxs: add CUDA dot product.
Measured performance: PP-512: 5891 t/s, TG-128: 143.9 t/s.
This commit is contained in:
parent
f120672964
commit
f1875b0a93
2 changed files with 46 additions and 1 deletions
42
ggml-cuda.cu
42
ggml-cuda.cu
|
@ -4418,6 +4418,36 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dot product of one 32-element sub-block of an IQ3_XXS block with the
// matching Q8_1 sub-block. Called per-thread from mul_mat_vec_q; `iqs` is the
// index of the 32-element sub-block inside the QK_K super-block.
// Only implemented for QK_K == 256; otherwise traps (device assert).
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;

    const int ib32 = iqs;
    // 8 one-byte indices into iq3xxs_grid for this sub-block (each grid entry
    // expands to 4 unsigned magnitudes).
    const uint8_t * q3 = bq2->qs + 8*ib32;
    // Per-sub-block 32-bit aux word stored after the grid indices: low 28 bits
    // are four 7-bit sign-group selectors, top 4 bits are the sub-block scale.
    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
    const int8_t   * q8 = bq8_1[ib32].qs;
    uint32_t aux32 = gas[0] | (gas[1] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        // Two grid rows of 4 values each -> 8 quant magnitudes per iteration.
        const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
        const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
        // Low 7 bits of aux32 select a precomputed 8-bit sign pattern.
        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 4; ++j) {
            // kmask_iq2xs[k] isolates the sign bit for position k.
            sumi += q8[j+0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
            sumi += q8[j+4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
        }
        q8 += 8;
        aux32 >>= 7; // advance to the next 7-bit sign-group selector
    }
    // After four >>7 shifts only the top 4 bits of the original aux word
    // remain: the sub-block scale. Combine with the block scale d, the Q8_1
    // scale (low half of ds), and the format's fixed 0.5f factor.
    const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
    return d * sumi;
#else
    assert(false);
    return 0.f;
#endif
}
|
|
||||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||||
static __device__ __forceinline__ void mul_mat_q(
|
static __device__ __forceinline__ void mul_mat_q(
|
||||||
|
@ -6744,6 +6774,15 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host-side launcher for the IQ3_XXS x Q8_1 matrix-vector product.
// Block layout: WARP_SIZE x GGML_CUDA_MMV_Y threads; grid covers `nrows` rows
// with GGML_CUDA_MMV_Y rows handled per block. Launched on `stream`.
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0); // only whole super-blocks are supported

    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    // ceil-div so a trailing partial group of rows still gets a block
    const int num_y_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(num_y_blocks, 1, 1);

    mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
|
||||||
|
|
||||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||||
|
@ -8389,6 +8428,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
||||||
case GGML_TYPE_IQ2_XS:
|
case GGML_TYPE_IQ2_XS:
|
||||||
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_IQ3_XXS:
|
||||||
|
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -9029,9 +9029,12 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
|
||||||
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
|
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
|
||||||
int i_layer = info.first, n_layer = info.second;
|
int i_layer = info.first, n_layer = info.second;
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {// || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
|
||||||
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
|
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
|
||||||
}
|
}
|
||||||
|
//else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||||
|
// if (i_layer < n_layer/8) new_type = GGML_TYPE_Q5_K;
|
||||||
|
//}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
||||||
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
|
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
|
||||||
: arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
|
: arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue