Merge 8dfe3d8e97
into 53debe6f3c
This commit is contained in:
commit
58d5154e7d
6 changed files with 45 additions and 13 deletions
|
@ -150,6 +150,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||||
"ggml: max. batch size for using peer access")
|
"ggml: max. batch size for using peer access")
|
||||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||||
|
option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
|
||||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||||
|
|
||||||
|
|
|
@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
|
||||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||||
|
|
||||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||||
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
|
if (GGML_CUDA_FA)
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
|
||||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
|
||||||
|
|
||||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
|
||||||
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
|
||||||
else()
|
else()
|
||||||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
|
||||||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
# message(FATAL_ERROR ${GGML_SOURCES_CUDA})
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
endif()
|
||||||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
if (NOT GGML_CUDA_FORCE_CUBLAS)
|
||||||
|
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (GGML_CUDA_FA)
|
||||||
|
add_compile_definitions(GGML_CUDA_FA)
|
||||||
|
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||||
|
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
||||||
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||||
|
else()
|
||||||
|
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||||
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||||
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
||||||
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
ggml_add_backend_library(ggml-cuda
|
ggml_add_backend_library(ggml-cuda
|
||||||
${GGML_HEADERS_CUDA}
|
${GGML_HEADERS_CUDA}
|
||||||
${GGML_SOURCES_CUDA}
|
${GGML_SOURCES_CUDA}
|
||||||
|
|
|
@ -155,6 +155,10 @@ typedef float2 dfloat2;
|
||||||
#define FLASH_ATTN_AVAILABLE
|
#define FLASH_ATTN_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||||
|
|
||||||
|
#if !defined(GGML_CUDA_FA)
|
||||||
|
#undef FLASH_ATTN_AVAILABLE
|
||||||
|
#endif
|
||||||
|
|
||||||
static constexpr bool fast_fp16_available(const int cc) {
|
static constexpr bool fast_fp16_available(const int cc) {
|
||||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
#include "ggml-cuda/cpy.cuh"
|
#include "ggml-cuda/cpy.cuh"
|
||||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||||
#include "ggml-cuda/diagmask.cuh"
|
#include "ggml-cuda/diagmask.cuh"
|
||||||
|
#ifdef FLASH_ATTN_AVAILABLE
|
||||||
#include "ggml-cuda/fattn.cuh"
|
#include "ggml-cuda/fattn.cuh"
|
||||||
|
#endif
|
||||||
#include "ggml-cuda/getrows.cuh"
|
#include "ggml-cuda/getrows.cuh"
|
||||||
#include "ggml-cuda/im2col.cuh"
|
#include "ggml-cuda/im2col.cuh"
|
||||||
#include "ggml-cuda/mmq.cuh"
|
#include "ggml-cuda/mmq.cuh"
|
||||||
|
@ -2286,8 +2288,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
ggml_cuda_op_argsort(ctx, dst);
|
ggml_cuda_op_argsort(ctx, dst);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
#ifdef FLASH_ATTN_AVAILABLE
|
||||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
#else
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -1,5 +1,12 @@
|
||||||
#include "mmq.cuh"
|
#include "mmq.cuh"
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||||
|
void ggml_cuda_op_mul_mat_q(
|
||||||
|
ggml_backend_cuda_context &,
|
||||||
|
const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
|
||||||
|
const char *, float *, const int64_t, const int64_t, const int64_t,
|
||||||
|
const int64_t, cudaStream_t) {}
|
||||||
|
#else
|
||||||
void ggml_cuda_op_mul_mat_q(
|
void ggml_cuda_op_mul_mat_q(
|
||||||
ggml_backend_cuda_context & ctx,
|
ggml_backend_cuda_context & ctx,
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||||
|
@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||||
GGML_UNUSED(dst);
|
GGML_UNUSED(dst);
|
||||||
GGML_UNUSED(src1_ddf_i);
|
GGML_UNUSED(src1_ddf_i);
|
||||||
}
|
}
|
||||||
|
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||||
|
|
||||||
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
#ifdef GGML_CUDA_FORCE_CUBLAS
|
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||||
|
|
|
@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||||
#define DECL_MMQ_CASE(type) \
|
#define DECL_MMQ_CASE(type) \
|
||||||
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
|
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
|
||||||
|
|
||||||
|
#if !defined(GGML_CUDA_FORCE_CUBLAS)
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
|
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
||||||
|
@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||||
|
#endif
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------------------
|
// -------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue