Merge: Testing speed of tensor cores vs MMQ
This commit is contained in:
commit
2ea3b567cf
5 changed files with 154 additions and 28 deletions
|
@ -1502,7 +1502,7 @@ struct llama_server_context
|
|||
{
|
||||
for (auto & slot : slots)
|
||||
{
|
||||
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
|
||||
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
|
||||
|
|
|
@ -95,13 +95,8 @@ int main(int argc, char ** argv) {
|
|||
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = tokens_list[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
for (size_t i = 0; i < tokens_list.size(); i++) {
|
||||
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
|
@ -148,15 +143,10 @@ int main(int argc, char ** argv) {
|
|||
fflush(stdout);
|
||||
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
llama_batch_clear(batch);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = 0;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
|
|
@ -8,6 +8,9 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct seq_draft {
|
||||
bool active = false;
|
||||
bool drafting = false;
|
||||
|
@ -64,6 +67,33 @@ int main(int argc, char ** argv) {
|
|||
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||
|
||||
{
|
||||
const int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||
const int n_vocab_dft = llama_n_vocab(model_dft);
|
||||
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||
? n_vocab_tgt - n_vocab_dft
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
||||
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
||||
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
|
||||
llama_token_to_piece(ctx_tgt, i).c_str(),
|
||||
llama_token_to_piece(ctx_dft, i).c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
|
||||
|
@ -227,6 +257,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode (ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
|
@ -370,7 +401,7 @@ int main(int argc, char ** argv) {
|
|||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
|
||||
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
++n_past_tgt;
|
||||
}
|
||||
|
|
125
ggml-cuda.cu
125
ggml-cuda.cu
|
@ -87,6 +87,24 @@
|
|||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||
|
||||
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
||||
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
||||
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
||||
// - 7B quantum model: +100-200 MB
|
||||
// - 13B quantum model: +200-400 MB
|
||||
//
|
||||
//#define GGML_CUDA_FORCE_MMQ
|
||||
|
||||
// TODO: improve this to be correct for more hardware
|
||||
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
|
||||
// probably other such cases, and not sure what happens on AMD hardware
|
||||
#if !defined(GGML_CUDA_FORCE_MMQ)
|
||||
#define CUDA_USE_TENSOR_CORES
|
||||
#endif
|
||||
|
||||
// max batch size to use MMQ kernels when tensor cores are available
|
||||
#define MMQ_MAX_BATCH_SIZE 32
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
|
@ -3554,9 +3572,15 @@ static __device__ __forceinline__ void mul_mat_q(
|
|||
#define MMQ_X_Q4_0_RDNA1 64
|
||||
#define MMQ_Y_Q4_0_RDNA1 64
|
||||
#define NWARPS_Q4_0_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q4_0_AMPERE 4
|
||||
#define MMQ_Y_Q4_0_AMPERE 32
|
||||
#define NWARPS_Q4_0_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q4_0_AMPERE 64
|
||||
#define MMQ_Y_Q4_0_AMPERE 128
|
||||
#define NWARPS_Q4_0_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q4_0_PASCAL 64
|
||||
#define MMQ_Y_Q4_0_PASCAL 64
|
||||
#define NWARPS_Q4_0_PASCAL 8
|
||||
|
@ -3615,9 +3639,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q4_1_RDNA1 64
|
||||
#define MMQ_Y_Q4_1_RDNA1 64
|
||||
#define NWARPS_Q4_1_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q4_1_AMPERE 4
|
||||
#define MMQ_Y_Q4_1_AMPERE 32
|
||||
#define NWARPS_Q4_1_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q4_1_AMPERE 64
|
||||
#define MMQ_Y_Q4_1_AMPERE 128
|
||||
#define NWARPS_Q4_1_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q4_1_PASCAL 64
|
||||
#define MMQ_Y_Q4_1_PASCAL 64
|
||||
#define NWARPS_Q4_1_PASCAL 8
|
||||
|
@ -3678,9 +3708,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q5_0_RDNA1 64
|
||||
#define MMQ_Y_Q5_0_RDNA1 64
|
||||
#define NWARPS_Q5_0_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q5_0_AMPERE 4
|
||||
#define MMQ_Y_Q5_0_AMPERE 32
|
||||
#define NWARPS_Q5_0_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q5_0_AMPERE 128
|
||||
#define MMQ_Y_Q5_0_AMPERE 64
|
||||
#define NWARPS_Q5_0_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q5_0_PASCAL 64
|
||||
#define MMQ_Y_Q5_0_PASCAL 64
|
||||
#define NWARPS_Q5_0_PASCAL 8
|
||||
|
@ -3739,9 +3775,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q5_1_RDNA1 64
|
||||
#define MMQ_Y_Q5_1_RDNA1 64
|
||||
#define NWARPS_Q5_1_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q5_1_AMPERE 4
|
||||
#define MMQ_Y_Q5_1_AMPERE 32
|
||||
#define NWARPS_Q5_1_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q5_1_AMPERE 128
|
||||
#define MMQ_Y_Q5_1_AMPERE 64
|
||||
#define NWARPS_Q5_1_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q5_1_PASCAL 64
|
||||
#define MMQ_Y_Q5_1_PASCAL 64
|
||||
#define NWARPS_Q5_1_PASCAL 8
|
||||
|
@ -3800,9 +3842,15 @@ mul_mat_q5_1(
|
|||
#define MMQ_X_Q8_0_RDNA1 64
|
||||
#define MMQ_Y_Q8_0_RDNA1 64
|
||||
#define NWARPS_Q8_0_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q8_0_AMPERE 4
|
||||
#define MMQ_Y_Q8_0_AMPERE 32
|
||||
#define NWARPS_Q8_0_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q8_0_AMPERE 128
|
||||
#define MMQ_Y_Q8_0_AMPERE 64
|
||||
#define NWARPS_Q8_0_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q8_0_PASCAL 64
|
||||
#define MMQ_Y_Q8_0_PASCAL 64
|
||||
#define NWARPS_Q8_0_PASCAL 8
|
||||
|
@ -3861,9 +3909,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q2_K_RDNA1 128
|
||||
#define MMQ_Y_Q2_K_RDNA1 32
|
||||
#define NWARPS_Q2_K_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q2_K_AMPERE 4
|
||||
#define MMQ_Y_Q2_K_AMPERE 32
|
||||
#define NWARPS_Q2_K_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q2_K_AMPERE 64
|
||||
#define MMQ_Y_Q2_K_AMPERE 128
|
||||
#define NWARPS_Q2_K_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q2_K_PASCAL 64
|
||||
#define MMQ_Y_Q2_K_PASCAL 64
|
||||
#define NWARPS_Q2_K_PASCAL 8
|
||||
|
@ -3922,9 +3976,15 @@ mul_mat_q2_K(
|
|||
#define MMQ_X_Q3_K_RDNA1 32
|
||||
#define MMQ_Y_Q3_K_RDNA1 128
|
||||
#define NWARPS_Q3_K_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q3_K_AMPERE 4
|
||||
#define MMQ_Y_Q3_K_AMPERE 32
|
||||
#define NWARPS_Q3_K_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q3_K_AMPERE 128
|
||||
#define MMQ_Y_Q3_K_AMPERE 128
|
||||
#define NWARPS_Q3_K_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q3_K_PASCAL 64
|
||||
#define MMQ_Y_Q3_K_PASCAL 64
|
||||
#define NWARPS_Q3_K_PASCAL 8
|
||||
|
@ -3985,9 +4045,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q4_K_RDNA1 32
|
||||
#define MMQ_Y_Q4_K_RDNA1 64
|
||||
#define NWARPS_Q4_K_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q4_K_AMPERE 4
|
||||
#define MMQ_Y_Q4_K_AMPERE 32
|
||||
#define NWARPS_Q4_K_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q4_K_AMPERE 64
|
||||
#define MMQ_Y_Q4_K_AMPERE 128
|
||||
#define NWARPS_Q4_K_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q4_K_PASCAL 64
|
||||
#define MMQ_Y_Q4_K_PASCAL 64
|
||||
#define NWARPS_Q4_K_PASCAL 8
|
||||
|
@ -4048,9 +4114,15 @@ template <bool need_check> static __global__ void
|
|||
#define MMQ_X_Q5_K_RDNA1 32
|
||||
#define MMQ_Y_Q5_K_RDNA1 64
|
||||
#define NWARPS_Q5_K_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q5_K_AMPERE 4
|
||||
#define MMQ_Y_Q5_K_AMPERE 32
|
||||
#define NWARPS_Q5_K_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q5_K_AMPERE 64
|
||||
#define MMQ_Y_Q5_K_AMPERE 128
|
||||
#define NWARPS_Q5_K_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q5_K_PASCAL 64
|
||||
#define MMQ_Y_Q5_K_PASCAL 64
|
||||
#define NWARPS_Q5_K_PASCAL 8
|
||||
|
@ -4109,9 +4181,15 @@ mul_mat_q5_K(
|
|||
#define MMQ_X_Q6_K_RDNA1 32
|
||||
#define MMQ_Y_Q6_K_RDNA1 64
|
||||
#define NWARPS_Q6_K_RDNA1 8
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
#define MMQ_X_Q6_K_AMPERE 4
|
||||
#define MMQ_Y_Q6_K_AMPERE 32
|
||||
#define NWARPS_Q6_K_AMPERE 4
|
||||
#else
|
||||
#define MMQ_X_Q6_K_AMPERE 64
|
||||
#define MMQ_Y_Q6_K_AMPERE 64
|
||||
#define NWARPS_Q6_K_AMPERE 4
|
||||
#endif
|
||||
#define MMQ_X_Q6_K_PASCAL 64
|
||||
#define MMQ_Y_Q6_K_PASCAL 64
|
||||
#define NWARPS_Q6_K_PASCAL 8
|
||||
|
@ -5659,6 +5737,16 @@ void ggml_init_cublas() {
|
|||
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
|
||||
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
int64_t total_vram = 0;
|
||||
#if defined(GGML_CUDA_FORCE_MMQ)
|
||||
fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
|
||||
#else
|
||||
fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
|
||||
#endif
|
||||
#if defined(CUDA_USE_TENSOR_CORES)
|
||||
fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
|
||||
#else
|
||||
fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
|
||||
#endif
|
||||
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
|
||||
for (int id = 0; id < g_device_count; ++id) {
|
||||
cudaDeviceProp prop;
|
||||
|
@ -6343,7 +6431,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ddf_i, ne00,
|
||||
src1_ddf_i, ne10,
|
||||
src1_ddf_i, ne10,
|
||||
&beta, dst_dd_i, ldc));
|
||||
|
||||
if (src0_as != 0) {
|
||||
|
@ -7044,9 +7132,10 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
|||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
@ -7198,17 +7287,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
|
||||
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
|
||||
const bool all_on_device =
|
||||
(src0->backend == GGML_BACKEND_GPU) &&
|
||||
(src1->backend == GGML_BACKEND_GPU) &&
|
||||
( dst->backend == GGML_BACKEND_GPU);
|
||||
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (min_compute_capability > g_compute_capabilities[id]
|
||||
&& g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
||||
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
||||
min_compute_capability = g_compute_capabilities[id];
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef CUDA_USE_TENSOR_CORES
|
||||
const bool use_tensor_cores = true;
|
||||
#else
|
||||
const bool use_tensor_cores = false;
|
||||
#endif
|
||||
|
||||
// debug helpers
|
||||
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
||||
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
|
||||
|
@ -7217,20 +7313,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
} else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||
|
||||
#ifdef GGML_CUDA_FORCE_DMMV
|
||||
const bool use_mul_mat_vec_q = false;
|
||||
#else
|
||||
|
@ -7243,7 +7338,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
||||
}
|
||||
} else {
|
||||
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
|
||||
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
||||
|
||||
// when tensor cores are available, use them for large batch size
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
||||
if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
|
||||
use_mul_mat_q = false;
|
||||
}
|
||||
|
||||
if (use_mul_mat_q) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
||||
} else {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
|
|
|
@ -1583,12 +1583,14 @@ static void llama_kv_cache_seq_shift(
|
|||
enum llama_fver {
|
||||
GGUF_FILE_VERSION_V1 = 1,
|
||||
GGUF_FILE_VERSION_V2 = 2,
|
||||
GGUF_FILE_VERSION_V3 = 3,
|
||||
};
|
||||
|
||||
static const char * llama_file_version_name(llama_fver version) {
|
||||
switch (version) {
|
||||
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
|
||||
case GGUF_FILE_VERSION_V2: return "GGUF V2 (latest)";
|
||||
case GGUF_FILE_VERSION_V2: return "GGUF V2";
|
||||
case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
|
||||
}
|
||||
|
||||
return "unknown";
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue