CUDA: mul_mat_q RDNA2 tunings
This commit is contained in:
parent
00d62adb79
commit
d0ef910f51
3 changed files with 426 additions and 45 deletions
|
@ -388,7 +388,6 @@ if (LLAMA_HIPBLAS)
|
||||||
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
||||||
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
|
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
|
||||||
target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
|
target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
|
||||||
target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000)
|
|
||||||
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
|
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
|
||||||
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
|
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
|
||||||
|
|
||||||
|
|
1
Makefile
1
Makefile
|
@ -358,7 +358,6 @@ ifdef LLAMA_HIPBLAS
|
||||||
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||||
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
||||||
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
|
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
|
||||||
HIPFLAGS += -DCC_TURING=1000000000
|
|
||||||
ifdef LLAMA_CUDA_FORCE_DMMV
|
ifdef LLAMA_CUDA_FORCE_DMMV
|
||||||
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
|
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
|
||||||
endif # LLAMA_CUDA_FORCE_DMMV
|
endif # LLAMA_CUDA_FORCE_DMMV
|
||||||
|
|
469
ggml-cuda.cu
469
ggml-cuda.cu
|
@ -13,7 +13,7 @@
|
||||||
#ifdef __HIP_PLATFORM_AMD__
|
#ifdef __HIP_PLATFORM_AMD__
|
||||||
// for rocblas_initialize()
|
// for rocblas_initialize()
|
||||||
#include "rocblas/rocblas.h"
|
#include "rocblas/rocblas.h"
|
||||||
#endif
|
#endif // __HIP_PLATFORM_AMD__
|
||||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||||
|
@ -68,19 +68,29 @@
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <cublas_v2.h>
|
#include <cublas_v2.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#endif
|
#endif // defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
#include "ggml-cuda.h"
|
#include "ggml-cuda.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||||
#ifndef CC_TURING
|
#define CC_TURING 700
|
||||||
#define CC_TURING 700
|
#define CC_OFFSET_AMD 1000000
|
||||||
#endif
|
// CC_RDNA2 is CC_OFFSET_AMD + 1030. The expansion is parenthesized so the macro behaves as a
// single operand in any surrounding expression (e.g. `2 * CC_RDNA2` or `CC_RDNA2 * x`).
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#define __CUDA_ARCH__ 1300
|
#define __CUDA_ARCH__ 1300
|
||||||
|
|
||||||
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
|
#define RDNA3
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Architectures for which RDNA2 is defined. Each line of the condition must end with `||` before
// the continuation backslash, otherwise the joined #if expression is missing a binary operator.
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||||
|
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||||
|
#define RDNA2
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef __has_builtin
|
#ifndef __has_builtin
|
||||||
#define __has_builtin(x) 0
|
#define __has_builtin(x) 0
|
||||||
#endif
|
#endif
|
||||||
|
@ -132,7 +142,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
||||||
#endif
|
#endif
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
#endif
|
#endif // defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
|
@ -3444,6 +3454,12 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q4_0_RDNA2 64
|
||||||
|
#define MMQ_Y_Q4_0_RDNA2 128
|
||||||
|
#define NWARPS_Q4_0_RDNA2 8
|
||||||
|
#define MMQ_X_Q4_0_RDNA1 64
|
||||||
|
#define MMQ_Y_Q4_0_RDNA1 64
|
||||||
|
#define NWARPS_Q4_0_RDNA1 8
|
||||||
#define MMQ_X_Q4_0_AMPERE 64
|
#define MMQ_X_Q4_0_AMPERE 64
|
||||||
#define MMQ_Y_Q4_0_AMPERE 128
|
#define MMQ_Y_Q4_0_AMPERE 128
|
||||||
#define NWARPS_Q4_0_AMPERE 4
|
#define NWARPS_Q4_0_AMPERE 4
|
||||||
|
@ -3451,11 +3467,32 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||||
#define MMQ_Y_Q4_0_PASCAL 64
|
#define MMQ_Y_Q4_0_PASCAL 64
|
||||||
#define NWARPS_Q4_0_PASCAL 8
|
#define NWARPS_Q4_0_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q4_0(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q4_0(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q4_0_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_0_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q4_0_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q4_0_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_0_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q4_0_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||||
|
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q4_0_AMPERE;
|
const int mmq_x = MMQ_X_Q4_0_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q4_0_AMPERE;
|
const int mmq_y = MMQ_Y_Q4_0_AMPERE;
|
||||||
const int nwarps = NWARPS_Q4_0_AMPERE;
|
const int nwarps = NWARPS_Q4_0_AMPERE;
|
||||||
|
@ -3478,6 +3515,12 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q4_1_RDNA2 64
|
||||||
|
#define MMQ_Y_Q4_1_RDNA2 128
|
||||||
|
#define NWARPS_Q4_1_RDNA2 8
|
||||||
|
#define MMQ_X_Q4_1_RDNA1 64
|
||||||
|
#define MMQ_Y_Q4_1_RDNA1 64
|
||||||
|
#define NWARPS_Q4_1_RDNA1 8
|
||||||
#define MMQ_X_Q4_1_AMPERE 64
|
#define MMQ_X_Q4_1_AMPERE 64
|
||||||
#define MMQ_Y_Q4_1_AMPERE 128
|
#define MMQ_Y_Q4_1_AMPERE 128
|
||||||
#define NWARPS_Q4_1_AMPERE 4
|
#define NWARPS_Q4_1_AMPERE 4
|
||||||
|
@ -3486,14 +3529,33 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
|
||||||
#define NWARPS_Q4_1_PASCAL 8
|
#define NWARPS_Q4_1_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void
|
template <bool need_check> static __global__ void
|
||||||
#if __CUDA_ARCH__ < CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#elif __CUDA_ARCH__ < CC_TURING
|
||||||
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
|
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
|
||||||
#endif // __CUDA_ARCH__ < CC_TURING
|
#endif // __CUDA_ARCH__ < CC_TURING
|
||||||
mul_mat_q4_1(
|
mul_mat_q4_1(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q4_1_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_1_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q4_1_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q4_1_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_1_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q4_1_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||||
|
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q4_1_AMPERE;
|
const int mmq_x = MMQ_X_Q4_1_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q4_1_AMPERE;
|
const int mmq_y = MMQ_Y_Q4_1_AMPERE;
|
||||||
const int nwarps = NWARPS_Q4_1_AMPERE;
|
const int nwarps = NWARPS_Q4_1_AMPERE;
|
||||||
|
@ -3516,6 +3578,12 @@ template <bool need_check> static __global__ void
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q5_0_RDNA2 64
|
||||||
|
#define MMQ_Y_Q5_0_RDNA2 128
|
||||||
|
#define NWARPS_Q5_0_RDNA2 8
|
||||||
|
#define MMQ_X_Q5_0_RDNA1 64
|
||||||
|
#define MMQ_Y_Q5_0_RDNA1 64
|
||||||
|
#define NWARPS_Q5_0_RDNA1 8
|
||||||
#define MMQ_X_Q5_0_AMPERE 128
|
#define MMQ_X_Q5_0_AMPERE 128
|
||||||
#define MMQ_Y_Q5_0_AMPERE 64
|
#define MMQ_Y_Q5_0_AMPERE 64
|
||||||
#define NWARPS_Q5_0_AMPERE 4
|
#define NWARPS_Q5_0_AMPERE 4
|
||||||
|
@ -3523,11 +3591,32 @@ template <bool need_check> static __global__ void
|
||||||
#define MMQ_Y_Q5_0_PASCAL 64
|
#define MMQ_Y_Q5_0_PASCAL 64
|
||||||
#define NWARPS_Q5_0_PASCAL 8
|
#define NWARPS_Q5_0_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q5_0(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q5_0(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q5_0_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_0_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q5_0_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q5_0_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_0_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q5_0_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||||
|
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q5_0_AMPERE;
|
const int mmq_x = MMQ_X_Q5_0_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q5_0_AMPERE;
|
const int mmq_y = MMQ_Y_Q5_0_AMPERE;
|
||||||
const int nwarps = NWARPS_Q5_0_AMPERE;
|
const int nwarps = NWARPS_Q5_0_AMPERE;
|
||||||
|
@ -3550,6 +3639,12 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q5_1_RDNA2 64
|
||||||
|
#define MMQ_Y_Q5_1_RDNA2 128
|
||||||
|
#define NWARPS_Q5_1_RDNA2 8
|
||||||
|
#define MMQ_X_Q5_1_RDNA1 64
|
||||||
|
#define MMQ_Y_Q5_1_RDNA1 64
|
||||||
|
#define NWARPS_Q5_1_RDNA1 8
|
||||||
#define MMQ_X_Q5_1_AMPERE 128
|
#define MMQ_X_Q5_1_AMPERE 128
|
||||||
#define MMQ_Y_Q5_1_AMPERE 64
|
#define MMQ_Y_Q5_1_AMPERE 64
|
||||||
#define NWARPS_Q5_1_AMPERE 4
|
#define NWARPS_Q5_1_AMPERE 4
|
||||||
|
@ -3557,11 +3652,32 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
|
||||||
#define MMQ_Y_Q5_1_PASCAL 64
|
#define MMQ_Y_Q5_1_PASCAL 64
|
||||||
#define NWARPS_Q5_1_PASCAL 8
|
#define NWARPS_Q5_1_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q5_1(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q5_1(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q5_1_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_1_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q5_1_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q5_1_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_1_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q5_1_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||||
|
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q5_1_AMPERE;
|
const int mmq_x = MMQ_X_Q5_1_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q5_1_AMPERE;
|
const int mmq_y = MMQ_Y_Q5_1_AMPERE;
|
||||||
const int nwarps = NWARPS_Q5_1_AMPERE;
|
const int nwarps = NWARPS_Q5_1_AMPERE;
|
||||||
|
@ -3584,6 +3700,12 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q8_0_RDNA2 64
|
||||||
|
#define MMQ_Y_Q8_0_RDNA2 128
|
||||||
|
#define NWARPS_Q8_0_RDNA2 8
|
||||||
|
#define MMQ_X_Q8_0_RDNA1 64
|
||||||
|
#define MMQ_Y_Q8_0_RDNA1 64
|
||||||
|
#define NWARPS_Q8_0_RDNA1 8
|
||||||
#define MMQ_X_Q8_0_AMPERE 128
|
#define MMQ_X_Q8_0_AMPERE 128
|
||||||
#define MMQ_Y_Q8_0_AMPERE 64
|
#define MMQ_Y_Q8_0_AMPERE 64
|
||||||
#define NWARPS_Q8_0_AMPERE 4
|
#define NWARPS_Q8_0_AMPERE 4
|
||||||
|
@ -3591,11 +3713,32 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
|
||||||
#define MMQ_Y_Q8_0_PASCAL 64
|
#define MMQ_Y_Q8_0_PASCAL 64
|
||||||
#define NWARPS_Q8_0_PASCAL 8
|
#define NWARPS_Q8_0_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q8_0(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q8_0(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q8_0_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q8_0_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q8_0_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q8_0_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q8_0_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q8_0_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||||
|
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q8_0_AMPERE;
|
const int mmq_x = MMQ_X_Q8_0_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q8_0_AMPERE;
|
const int mmq_y = MMQ_Y_Q8_0_AMPERE;
|
||||||
const int nwarps = NWARPS_Q8_0_AMPERE;
|
const int nwarps = NWARPS_Q8_0_AMPERE;
|
||||||
|
@ -3618,6 +3761,12 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q2_K_RDNA2 64
|
||||||
|
#define MMQ_Y_Q2_K_RDNA2 128
|
||||||
|
#define NWARPS_Q2_K_RDNA2 8
|
||||||
|
#define MMQ_X_Q2_K_RDNA1 128
|
||||||
|
#define MMQ_Y_Q2_K_RDNA1 32
|
||||||
|
#define NWARPS_Q2_K_RDNA1 8
|
||||||
#define MMQ_X_Q2_K_AMPERE 64
|
#define MMQ_X_Q2_K_AMPERE 64
|
||||||
#define MMQ_Y_Q2_K_AMPERE 128
|
#define MMQ_Y_Q2_K_AMPERE 128
|
||||||
#define NWARPS_Q2_K_AMPERE 4
|
#define NWARPS_Q2_K_AMPERE 4
|
||||||
|
@ -3625,11 +3774,32 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
|
||||||
#define MMQ_Y_Q2_K_PASCAL 64
|
#define MMQ_Y_Q2_K_PASCAL 64
|
||||||
#define NWARPS_Q2_K_PASCAL 8
|
#define NWARPS_Q2_K_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q2_K(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q2_K(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q2_K_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q2_K_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q2_K_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q2_K_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q2_K_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q2_K_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||||
|
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q2_K_AMPERE;
|
const int mmq_x = MMQ_X_Q2_K_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q2_K_AMPERE;
|
const int mmq_y = MMQ_Y_Q2_K_AMPERE;
|
||||||
const int nwarps = NWARPS_Q2_K_AMPERE;
|
const int nwarps = NWARPS_Q2_K_AMPERE;
|
||||||
|
@ -3652,6 +3822,12 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q3_K_RDNA2 128
|
||||||
|
#define MMQ_Y_Q3_K_RDNA2 64
|
||||||
|
#define NWARPS_Q3_K_RDNA2 8
|
||||||
|
#define MMQ_X_Q3_K_RDNA1 32
|
||||||
|
#define MMQ_Y_Q3_K_RDNA1 128
|
||||||
|
#define NWARPS_Q3_K_RDNA1 8
|
||||||
#define MMQ_X_Q3_K_AMPERE 128
|
#define MMQ_X_Q3_K_AMPERE 128
|
||||||
#define MMQ_Y_Q3_K_AMPERE 128
|
#define MMQ_Y_Q3_K_AMPERE 128
|
||||||
#define NWARPS_Q3_K_AMPERE 4
|
#define NWARPS_Q3_K_AMPERE 4
|
||||||
|
@ -3660,14 +3836,33 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
|
||||||
#define NWARPS_Q3_K_PASCAL 8
|
#define NWARPS_Q3_K_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void
|
template <bool need_check> static __global__ void
|
||||||
#if __CUDA_ARCH__ < CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#elif __CUDA_ARCH__ < CC_TURING
|
||||||
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
|
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
|
||||||
#endif // __CUDA_ARCH__ < CC_TURING
|
#endif // __CUDA_ARCH__ < CC_TURING
|
||||||
mul_mat_q3_K(
|
mul_mat_q3_K(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q3_K_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q3_K_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q3_K_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q3_K_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q3_K_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q3_K_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||||
|
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q3_K_AMPERE;
|
const int mmq_x = MMQ_X_Q3_K_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q3_K_AMPERE;
|
const int mmq_y = MMQ_Y_Q3_K_AMPERE;
|
||||||
const int nwarps = NWARPS_Q3_K_AMPERE;
|
const int nwarps = NWARPS_Q3_K_AMPERE;
|
||||||
|
@ -3690,6 +3885,12 @@ template <bool need_check> static __global__ void
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q4_K_RDNA2 64
|
||||||
|
#define MMQ_Y_Q4_K_RDNA2 128
|
||||||
|
#define NWARPS_Q4_K_RDNA2 8
|
||||||
|
#define MMQ_X_Q4_K_RDNA1 32
|
||||||
|
#define MMQ_Y_Q4_K_RDNA1 64
|
||||||
|
#define NWARPS_Q4_K_RDNA1 8
|
||||||
#define MMQ_X_Q4_K_AMPERE 64
|
#define MMQ_X_Q4_K_AMPERE 64
|
||||||
#define MMQ_Y_Q4_K_AMPERE 128
|
#define MMQ_Y_Q4_K_AMPERE 128
|
||||||
#define NWARPS_Q4_K_AMPERE 4
|
#define NWARPS_Q4_K_AMPERE 4
|
||||||
|
@ -3698,14 +3899,33 @@ template <bool need_check> static __global__ void
|
||||||
#define NWARPS_Q4_K_PASCAL 8
|
#define NWARPS_Q4_K_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void
|
template <bool need_check> static __global__ void
|
||||||
#if __CUDA_ARCH__ < CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#elif __CUDA_ARCH__ < CC_TURING
|
||||||
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
|
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
|
||||||
#endif // __CUDA_ARCH__ < CC_TURING
|
#endif // __CUDA_ARCH__ < CC_TURING
|
||||||
mul_mat_q4_K(
|
mul_mat_q4_K(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q4_K_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_K_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q4_K_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q4_K_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q4_K_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q4_K_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||||
|
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q4_K_AMPERE;
|
const int mmq_x = MMQ_X_Q4_K_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q4_K_AMPERE;
|
const int mmq_y = MMQ_Y_Q4_K_AMPERE;
|
||||||
const int nwarps = NWARPS_Q4_K_AMPERE;
|
const int nwarps = NWARPS_Q4_K_AMPERE;
|
||||||
|
@ -3728,6 +3948,12 @@ template <bool need_check> static __global__ void
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q5_K_RDNA2 64
|
||||||
|
#define MMQ_Y_Q5_K_RDNA2 128
|
||||||
|
#define NWARPS_Q5_K_RDNA2 8
|
||||||
|
#define MMQ_X_Q5_K_RDNA1 32
|
||||||
|
#define MMQ_Y_Q5_K_RDNA1 64
|
||||||
|
#define NWARPS_Q5_K_RDNA1 8
|
||||||
#define MMQ_X_Q5_K_AMPERE 64
|
#define MMQ_X_Q5_K_AMPERE 64
|
||||||
#define MMQ_Y_Q5_K_AMPERE 128
|
#define MMQ_Y_Q5_K_AMPERE 128
|
||||||
#define NWARPS_Q5_K_AMPERE 4
|
#define NWARPS_Q5_K_AMPERE 4
|
||||||
|
@ -3735,11 +3961,32 @@ template <bool need_check> static __global__ void
|
||||||
#define MMQ_Y_Q5_K_PASCAL 64
|
#define MMQ_Y_Q5_K_PASCAL 64
|
||||||
#define NWARPS_Q5_K_PASCAL 8
|
#define NWARPS_Q5_K_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void mul_mat_q5_K(
|
template <bool need_check> static __global__ void
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
mul_mat_q5_K(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q5_K_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_K_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q5_K_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q5_K_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q5_K_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q5_K_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||||
|
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q5_K_AMPERE;
|
const int mmq_x = MMQ_X_Q5_K_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q5_K_AMPERE;
|
const int mmq_y = MMQ_Y_Q5_K_AMPERE;
|
||||||
const int nwarps = NWARPS_Q5_K_AMPERE;
|
const int nwarps = NWARPS_Q5_K_AMPERE;
|
||||||
|
@ -3762,6 +4009,12 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
|
||||||
#endif // __CUDA_ARCH__ >= CC_TURING
|
#endif // __CUDA_ARCH__ >= CC_TURING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MMQ_X_Q6_K_RDNA2 64
|
||||||
|
#define MMQ_Y_Q6_K_RDNA2 128
|
||||||
|
#define NWARPS_Q6_K_RDNA2 8
|
||||||
|
#define MMQ_X_Q6_K_RDNA1 32
|
||||||
|
#define MMQ_Y_Q6_K_RDNA1 64
|
||||||
|
#define NWARPS_Q6_K_RDNA1 8
|
||||||
#define MMQ_X_Q6_K_AMPERE 64
|
#define MMQ_X_Q6_K_AMPERE 64
|
||||||
#define MMQ_Y_Q6_K_AMPERE 64
|
#define MMQ_Y_Q6_K_AMPERE 64
|
||||||
#define NWARPS_Q6_K_AMPERE 4
|
#define NWARPS_Q6_K_AMPERE 4
|
||||||
|
@ -3770,14 +4023,33 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
|
||||||
#define NWARPS_Q6_K_PASCAL 8
|
#define NWARPS_Q6_K_PASCAL 8
|
||||||
|
|
||||||
template <bool need_check> static __global__ void
|
template <bool need_check> static __global__ void
|
||||||
#if __CUDA_ARCH__ < CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
#elif __CUDA_ARCH__ < CC_TURING
|
||||||
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
|
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
|
||||||
#endif // __CUDA_ARCH__ < CC_TURING
|
#endif // __CUDA_ARCH__ < CC_TURING
|
||||||
mul_mat_q6_K(
|
mul_mat_q6_K(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_TURING
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(RDNA3) || defined(RDNA2)
|
||||||
|
const int mmq_x = MMQ_X_Q6_K_RDNA2;
|
||||||
|
const int mmq_y = MMQ_Y_Q6_K_RDNA2;
|
||||||
|
const int nwarps = NWARPS_Q6_K_RDNA2;
|
||||||
|
#else
|
||||||
|
const int mmq_x = MMQ_X_Q6_K_RDNA1;
|
||||||
|
const int mmq_y = MMQ_Y_Q6_K_RDNA1;
|
||||||
|
const int nwarps = NWARPS_Q6_K_RDNA1;
|
||||||
|
#endif // defined(RDNA3) || defined(RDNA2)
|
||||||
|
|
||||||
|
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||||
|
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
|
||||||
|
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||||
|
|
||||||
|
#elif __CUDA_ARCH__ >= CC_TURING
|
||||||
const int mmq_x = MMQ_X_Q6_K_AMPERE;
|
const int mmq_x = MMQ_X_Q6_K_AMPERE;
|
||||||
const int mmq_y = MMQ_Y_Q6_K_AMPERE;
|
const int mmq_y = MMQ_Y_Q6_K_AMPERE;
|
||||||
const int nwarps = NWARPS_Q6_K_AMPERE;
|
const int nwarps = NWARPS_Q6_K_AMPERE;
|
||||||
|
@ -4558,7 +4830,15 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q4_0_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q4_0_RDNA2;
|
||||||
|
nwarps = NWARPS_Q4_0_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q4_0_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q4_0_RDNA1;
|
||||||
|
nwarps = NWARPS_Q4_0_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q4_0_AMPERE;
|
mmq_x = MMQ_X_Q4_0_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
||||||
nwarps = NWARPS_Q4_0_AMPERE;
|
nwarps = NWARPS_Q4_0_AMPERE;
|
||||||
|
@ -4595,7 +4875,15 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q4_1_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q4_1_RDNA2;
|
||||||
|
nwarps = NWARPS_Q4_1_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q4_1_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q4_1_RDNA1;
|
||||||
|
nwarps = NWARPS_Q4_1_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q4_1_AMPERE;
|
mmq_x = MMQ_X_Q4_1_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q4_1_AMPERE;
|
mmq_y = MMQ_Y_Q4_1_AMPERE;
|
||||||
nwarps = NWARPS_Q4_1_AMPERE;
|
nwarps = NWARPS_Q4_1_AMPERE;
|
||||||
|
@ -4632,7 +4920,15 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q5_0_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q5_0_RDNA2;
|
||||||
|
nwarps = NWARPS_Q5_0_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q5_0_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q5_0_RDNA1;
|
||||||
|
nwarps = NWARPS_Q5_0_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q5_0_AMPERE;
|
mmq_x = MMQ_X_Q5_0_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q5_0_AMPERE;
|
mmq_y = MMQ_Y_Q5_0_AMPERE;
|
||||||
nwarps = NWARPS_Q5_0_AMPERE;
|
nwarps = NWARPS_Q5_0_AMPERE;
|
||||||
|
@ -4669,7 +4965,15 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q5_1_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q5_1_RDNA2;
|
||||||
|
nwarps = NWARPS_Q5_1_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q5_1_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q5_1_RDNA1;
|
||||||
|
nwarps = NWARPS_Q5_1_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q5_1_AMPERE;
|
mmq_x = MMQ_X_Q5_1_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q5_1_AMPERE;
|
mmq_y = MMQ_Y_Q5_1_AMPERE;
|
||||||
nwarps = NWARPS_Q5_1_AMPERE;
|
nwarps = NWARPS_Q5_1_AMPERE;
|
||||||
|
@ -4706,7 +5010,15 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q8_0_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q8_0_RDNA2;
|
||||||
|
nwarps = NWARPS_Q8_0_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q8_0_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q8_0_RDNA1;
|
||||||
|
nwarps = NWARPS_Q8_0_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q8_0_AMPERE;
|
mmq_x = MMQ_X_Q8_0_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q8_0_AMPERE;
|
mmq_y = MMQ_Y_Q8_0_AMPERE;
|
||||||
nwarps = NWARPS_Q8_0_AMPERE;
|
nwarps = NWARPS_Q8_0_AMPERE;
|
||||||
|
@ -4743,7 +5055,15 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q2_K_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q2_K_RDNA2;
|
||||||
|
nwarps = NWARPS_Q2_K_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q2_K_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q2_K_RDNA1;
|
||||||
|
nwarps = NWARPS_Q2_K_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q2_K_AMPERE;
|
mmq_x = MMQ_X_Q2_K_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q2_K_AMPERE;
|
mmq_y = MMQ_Y_Q2_K_AMPERE;
|
||||||
nwarps = NWARPS_Q2_K_AMPERE;
|
nwarps = NWARPS_Q2_K_AMPERE;
|
||||||
|
@ -4782,7 +5102,15 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q3_K_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q3_K_RDNA2;
|
||||||
|
nwarps = NWARPS_Q3_K_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q3_K_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q3_K_RDNA1;
|
||||||
|
nwarps = NWARPS_Q3_K_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q3_K_AMPERE;
|
mmq_x = MMQ_X_Q3_K_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q3_K_AMPERE;
|
mmq_y = MMQ_Y_Q3_K_AMPERE;
|
||||||
nwarps = NWARPS_Q3_K_AMPERE;
|
nwarps = NWARPS_Q3_K_AMPERE;
|
||||||
|
@ -4820,7 +5148,15 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q4_K_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q4_K_RDNA2;
|
||||||
|
nwarps = NWARPS_Q4_K_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q4_K_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q4_K_RDNA1;
|
||||||
|
nwarps = NWARPS_Q4_K_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q4_K_AMPERE;
|
mmq_x = MMQ_X_Q4_K_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q4_K_AMPERE;
|
mmq_y = MMQ_Y_Q4_K_AMPERE;
|
||||||
nwarps = NWARPS_Q4_K_AMPERE;
|
nwarps = NWARPS_Q4_K_AMPERE;
|
||||||
|
@ -4857,7 +5193,15 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q5_K_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q5_K_RDNA2;
|
||||||
|
nwarps = NWARPS_Q5_K_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q5_K_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q5_K_RDNA1;
|
||||||
|
nwarps = NWARPS_Q5_K_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q5_K_AMPERE;
|
mmq_x = MMQ_X_Q5_K_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q5_K_AMPERE;
|
mmq_y = MMQ_Y_Q5_K_AMPERE;
|
||||||
nwarps = NWARPS_Q5_K_AMPERE;
|
nwarps = NWARPS_Q5_K_AMPERE;
|
||||||
|
@ -4894,7 +5238,15 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||||
const int compute_capability = g_compute_capabilities[id];
|
const int compute_capability = g_compute_capabilities[id];
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
int mmq_x, mmq_y, nwarps;
|
||||||
if (compute_capability >= CC_TURING) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
|
mmq_x = MMQ_X_Q6_K_RDNA2;
|
||||||
|
mmq_y = MMQ_Y_Q6_K_RDNA2;
|
||||||
|
nwarps = NWARPS_Q6_K_RDNA2;
|
||||||
|
} else if (compute_capability >= CC_OFFSET_AMD) {
|
||||||
|
mmq_x = MMQ_X_Q6_K_RDNA1;
|
||||||
|
mmq_y = MMQ_Y_Q6_K_RDNA1;
|
||||||
|
nwarps = NWARPS_Q6_K_RDNA1;
|
||||||
|
} else if (compute_capability >= CC_TURING) {
|
||||||
mmq_x = MMQ_X_Q6_K_AMPERE;
|
mmq_x = MMQ_X_Q6_K_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q6_K_AMPERE;
|
mmq_y = MMQ_Y_Q6_K_AMPERE;
|
||||||
nwarps = NWARPS_Q6_K_AMPERE;
|
nwarps = NWARPS_Q6_K_AMPERE;
|
||||||
|
@ -5134,8 +5486,11 @@ void ggml_init_cublas() {
|
||||||
|
|
||||||
g_tensor_split[id] = total_vram;
|
g_tensor_split[id] = total_vram;
|
||||||
total_vram += prop.totalGlobalMem;
|
total_vram += prop.totalGlobalMem;
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
|
||||||
|
#else
|
||||||
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
|
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
}
|
}
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
g_tensor_split[id] /= total_vram;
|
g_tensor_split[id] /= total_vram;
|
||||||
|
@ -5475,14 +5830,41 @@ inline void ggml_cuda_op_mul_mat_q(
|
||||||
}
|
}
|
||||||
|
|
||||||
static int64_t get_row_rounding(ggml_type type) {
|
static int64_t get_row_rounding(ggml_type type) {
|
||||||
int max_compute_capability = INT_MIN;
|
int64_t min_compute_capability = INT_MAX;
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
int64_t max_compute_capability = INT_MIN;
|
||||||
if (max_compute_capability < g_compute_capabilities[id]
|
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||||
&& g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
||||||
max_compute_capability = g_compute_capabilities[id];
|
if (min_compute_capability > g_compute_capabilities[id]) {
|
||||||
|
min_compute_capability = g_compute_capabilities[id];
|
||||||
|
}
|
||||||
|
if (max_compute_capability < g_compute_capabilities[id]) {
|
||||||
|
max_compute_capability = g_compute_capabilities[id];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
switch(type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
return 1;
|
||||||
|
case GGML_TYPE_Q2_K:
|
||||||
|
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
|
||||||
|
case GGML_TYPE_Q3_K:
|
||||||
|
return min_compute_capability < CC_RDNA2 ? 128 : 64;
|
||||||
|
case GGML_TYPE_Q4_K:
|
||||||
|
case GGML_TYPE_Q5_K:
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
#else
|
||||||
switch(type) {
|
switch(type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
|
@ -5503,6 +5885,7 @@ static int64_t get_row_rounding(ggml_type type) {
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ggml_cuda_op_mul_mat_vec(
|
inline void ggml_cuda_op_mul_mat_vec(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue