align GEMM dispatch

Meng, Hengyu 2024-05-28 05:11:55 +08:00
parent 9335b969e8
commit 583c81c91c
2 changed files with 80 additions and 70 deletions

CMakeLists.txt

@@ -96,8 +96,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM"
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUDA "llama: use CUDA" OFF)
 option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
-option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
-option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
+option(LLAMA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
+option(LLAMA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
 set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
@@ -405,10 +405,10 @@ if (LLAMA_CUDA)
     add_compile_definitions(GGML_USE_CUDA)
     add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    if (LLAMA_CUDA_FORCE_DMMV)
+    if (LLAMA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
-    if (LLAMA_CUDA_FORCE_MMQ)
+    if (LLAMA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
     if (LLAMA_CUDA_NO_VMM)
@@ -578,11 +578,11 @@ if (LLAMA_HIPBLAS)
         add_compile_definitions(GGML_HIP_UMA)
     endif()
-    if (LLAMA_CUDA_FORCE_DMMV)
+    if (LLAMA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
-    if (LLAMA_CUDA_FORCE_MMQ)
+    if (LLAMA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -628,6 +628,13 @@ if (LLAMA_SYCL)
         add_compile_definitions(GGML_SYCL_F16)
     endif()
+    if (LLAMA_SYCL_FORCE_DMMV)
+        add_compile_definitions(GGML_SYCL_FORCE_DMMV)
+    endif()
+    if (LLAMA_SYCL_FORCE_MMQ)
+        add_compile_definitions(GGML_SYCL_FORCE_MMQ)
+    endif()
     add_compile_options(-I./) #include DPCT
     add_compile_options(-I/${SYCL_INCLUDE_DIR})
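
The SYCL block above only forwards CMake switches into compile definitions; the backend source then branches on `GGML_SYCL_FORCE_DMMV` and `GGML_SYCL_FORCE_MMQ` at preprocessing time. Below is a minimal, self-contained sketch (illustrative only, not llama.cpp source; the helper name is made up) of how such definitions are consumed:

```cpp
// Minimal sketch (illustrative only, not llama.cpp source): how the
// GGML_SYCL_FORCE_DMMV / GGML_SYCL_FORCE_MMQ compile definitions set by the
// CMake block above can be consumed in C++. The function name is hypothetical.
#include <cstdio>

static void report_sycl_force_flags() {
#ifdef GGML_SYCL_FORCE_DMMV
    std::printf("GGML_SYCL_FORCE_DMMV set: dequantize_mul_mat_vec is kept even when mmvq is usable\n");
#else
    std::printf("GGML_SYCL_FORCE_DMMV not set: mmvq is preferred when usable\n");
#endif
#ifdef GGML_SYCL_FORCE_MMQ
    std::printf("GGML_SYCL_FORCE_MMQ set: SYCL_USE_XMX stays undefined\n");
#else
    std::printf("GGML_SYCL_FORCE_MMQ not set: SYCL_USE_XMX is defined by default\n");
#endif
}

int main() {
    report_sycl_force_flags();
    return 0;
}
```

Compiling the sketch with `-DGGML_SYCL_FORCE_MMQ` (or letting CMake add the definition via the `LLAMA_SYCL_FORCE_MMQ` check above) flips it to the alternate branches.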

ggml-sycl.cpp

@@ -2971,7 +2971,7 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardward optimize.
+#define VER_4VEC 130 //todo for hardward optimize.
 #define VER_GEN9 700 //todo for hardward optimize.
 #define VER_GEN12 1000000 //todo for hardward optimize.
 #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
@@ -2984,7 +2984,7 @@ static int g_work_group_size = 0;
 #define SYCL_USE_XMX
 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32
 #if defined(_MSC_VER)
@@ -15193,6 +15193,25 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+    // switch (type) {
+    //     case GGML_TYPE_Q4_0:
+    //     case GGML_TYPE_Q4_1:
+    //     case GGML_TYPE_Q5_0:
+    //     case GGML_TYPE_Q5_1:
+    //     case GGML_TYPE_Q8_0:
+    //     case GGML_TYPE_Q2_K:
+    //     case GGML_TYPE_Q3_K:
+    //     case GGML_TYPE_Q4_K:
+    //     case GGML_TYPE_Q5_K:
+    //     case GGML_TYPE_Q6_K:
+    //         return true;
+    //     default:
+    //         return false;
+    // }
+}
 
 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
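
`ggml_sycl_supports_mmq` is introduced above as the gate for the MMQ path, but it currently returns `false` unconditionally because of the accuracy issues noted in the TODO; the intended type coverage survives only in the commented-out switch. For reference, a standalone sketch of that coverage (illustrative only: it uses a local enum rather than `ggml_type` and is not llama.cpp code):

```cpp
// Illustrative sketch of the MMQ type coverage aimed at by the commented-out
// switch above. Uses a local enum instead of ggml_type; not llama.cpp code.
#include <cstdio>

enum class quant_kind { q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k, f16, f32 };

static bool sketch_supports_mmq(quant_kind t) {
    switch (t) {
        case quant_kind::q4_0: case quant_kind::q4_1:
        case quant_kind::q5_0: case quant_kind::q5_1:
        case quant_kind::q8_0:
        case quant_kind::q2_k: case quant_kind::q3_k:
        case quant_kind::q4_k: case quant_kind::q5_k:
        case quant_kind::q6_k:
            return true;   // block-quantized types the MMQ path is meant to cover
        default:
            return false;  // f16/f32 (and anything else) go through other paths
    }
}

int main() {
    std::printf("q4_0 -> %s, f16 -> %s\n",
                sketch_supports_mmq(quant_kind::q4_0) ? "mmq" : "other",
                sketch_supports_mmq(quant_kind::f16)  ? "mmq" : "other");
    return 0;
}
```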
@@ -15209,76 +15228,60 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
-#ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
-#else
-    const bool use_xmx = false;
-#endif
-
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
-
-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
-        ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
-        ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
-        // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
-        ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+#if !defined(GGML_SYCL_FORCE_MMQ)
+    #define SYCL_USE_XMX
+#endif
+
+#ifdef SYCL_USE_XMX
+    bool use_xmx = true;
+#else
+    bool use_xmx = false;
+#endif
+
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // fp16 performance always better on gen12+
+    const bool fp16_performance_good = true;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    use_mul_mat_vec_q = use_mul_mat_vec_q; // Check dp4a
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
+
+#ifndef GGML_SYCL_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_SYCL_FORCE_DMMV
+
+    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ single-batch
+        ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
+    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV single-batch
+        ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch
+        ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }
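
Taken together, the rewritten `ggml_sycl_mul_mat` above tries the F16 special cases first (permuted single-batch, non-contiguous single-batch, multi-batch), then dequantize-mul-mat-vec, mul-mat-vec-q, mul-mat-q, and finally the generic `ggml_sycl_op_mul_mat_sycl` path. A small self-contained sketch of that priority order, reduced to plain booleans (hypothetical names, not llama.cpp code):

```cpp
// Self-contained sketch of the dispatch priority in the rewritten
// ggml_sycl_mul_mat above, reduced to plain booleans. Names are hypothetical;
// this only illustrates the branch order, not the real predicates.
#include <cstdio>

struct mm_case {
    bool split;
    bool f16_permuted_single_batch;   // !split, F16 src0, permuted src0/src1, src1->ne[1] == 1
    bool f16_noncontig_single_batch;  // !split, F16 src0, non-contiguous src0, src1->ne[1] == 1
    bool f16_multi_batch;             // !split, F16 src0, (F16 src1 || fast fp16), src1->ne[2]*ne[3] > 1
    bool fp16_performance_good;
    bool use_dequantize_mul_mat_vec;
    bool use_mul_mat_vec_q;
    bool use_mul_mat_q;
};

static const char * pick_mul_mat_path(const mm_case & c) {
    if (!c.split && !c.fp16_performance_good && c.f16_permuted_single_batch)  return "mul_mat_vec_p021";
    if (!c.split && !c.fp16_performance_good && c.f16_noncontig_single_batch) return "mul_mat_vec_nc";
    if (!c.split && c.f16_multi_batch)                                        return "mul_mat_batched_sycl";
    if (c.use_dequantize_mul_mat_vec)                                         return "dequantize_mul_mat_vec";
    if (c.use_mul_mat_vec_q)                                                  return "mul_mat_vec_q";
    if (c.use_mul_mat_q)                                                      return "mul_mat_q";
    return "mul_mat_sycl";
}

int main() {
    mm_case quantized_small_batch = {};
    quantized_small_batch.fp16_performance_good      = true;  // gen12+ default in the diff
    quantized_small_batch.use_dequantize_mul_mat_vec = false; // suppressed when mmvq is usable
    quantized_small_batch.use_mul_mat_vec_q          = true;
    std::printf("quantized, small batch -> %s\n", pick_mul_mat_path(quantized_small_batch));

    mm_case fallback = {};
    fallback.fp16_performance_good = true;
    std::printf("everything else -> %s\n", pick_mul_mat_path(fallback));
    return 0;
}
```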