Merge branch 'master' into concedo_experimental
should fix multigpu
This commit is contained in:
commit
a62468ec4c
4 changed files with 158 additions and 86 deletions
100
cmake/FindSIMD.cmake
Normal file
100
cmake/FindSIMD.cmake
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
include(CheckCSourceRuns)
|
||||||
|
|
||||||
|
set(AVX_CODE "
|
||||||
|
#include <immintrin.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
__m256 a;
|
||||||
|
a = _mm256_set1_ps(0);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
|
||||||
|
set(AVX512_CODE "
|
||||||
|
#include <immintrin.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
__m512i b = a;
|
||||||
|
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
|
||||||
|
set(AVX2_CODE "
|
||||||
|
#include <immintrin.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
__m256i a = {0};
|
||||||
|
a = _mm256_abs_epi16(a);
|
||||||
|
__m256i x;
|
||||||
|
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
|
||||||
|
set(FMA_CODE "
|
||||||
|
#include <immintrin.h>
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
const __m256 d = _mm256_setzero_ps();
|
||||||
|
const __m256 p = _mm256_setzero_ps();
|
||||||
|
acc = _mm256_fmadd_ps( d, p, acc );
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
")
|
||||||
|
|
||||||
|
macro(check_sse type flags)
|
||||||
|
set(__FLAG_I 1)
|
||||||
|
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||||
|
foreach (__FLAG ${flags})
|
||||||
|
if (NOT ${type}_FOUND)
|
||||||
|
set(CMAKE_REQUIRED_FLAGS ${__FLAG})
|
||||||
|
check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
|
||||||
|
if (HAS_${type}_${__FLAG_I})
|
||||||
|
set(${type}_FOUND TRUE CACHE BOOL "${type} support")
|
||||||
|
set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
|
||||||
|
endif()
|
||||||
|
math(EXPR __FLAG_I "${__FLAG_I}+1")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||||
|
|
||||||
|
if (NOT ${type}_FOUND)
|
||||||
|
set(${type}_FOUND FALSE CACHE BOOL "${type} support")
|
||||||
|
set(${type}_FLAGS "" CACHE STRING "${type} flags")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
mark_as_advanced(${type}_FOUND ${type}_FLAGS)
|
||||||
|
endmacro()
|
||||||
|
|
||||||
|
# flags are for MSVC only!
|
||||||
|
check_sse("AVX" " ;/arch:AVX")
|
||||||
|
if (NOT ${AVX_FOUND})
|
||||||
|
set(LLAMA_AVX OFF)
|
||||||
|
else()
|
||||||
|
set(LLAMA_AVX ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
check_sse("AVX2" " ;/arch:AVX2")
|
||||||
|
check_sse("FMA" " ;/arch:AVX2")
|
||||||
|
if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
|
||||||
|
set(LLAMA_AVX2 OFF)
|
||||||
|
else()
|
||||||
|
set(LLAMA_AVX2 ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
check_sse("AVX512" " ;/arch:AVX512")
|
||||||
|
if (NOT ${AVX512_FOUND})
|
||||||
|
set(LLAMA_AVX512 OFF)
|
||||||
|
else()
|
||||||
|
set(LLAMA_AVX512 ON)
|
||||||
|
endif()
|
131
ggml-cuda.cu
131
ggml-cuda.cu
|
@ -39,10 +39,6 @@
|
||||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
#define cudaDeviceGetMemPool hipDeviceGetMemPool
|
|
||||||
#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
|
|
||||||
#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
|
|
||||||
#define cudaMemPool_t hipMemPool_t
|
|
||||||
#define cudaDeviceProp hipDeviceProp_t
|
#define cudaDeviceProp hipDeviceProp_t
|
||||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||||
#define cudaError_t hipError_t
|
#define cudaError_t hipError_t
|
||||||
|
@ -52,7 +48,6 @@
|
||||||
#define cudaEvent_t hipEvent_t
|
#define cudaEvent_t hipEvent_t
|
||||||
#define cudaEventDestroy hipEventDestroy
|
#define cudaEventDestroy hipEventDestroy
|
||||||
#define cudaFree hipFree
|
#define cudaFree hipFree
|
||||||
#define cudaFreeAsync hipFreeAsync
|
|
||||||
#define cudaFreeHost hipHostFree
|
#define cudaFreeHost hipHostFree
|
||||||
#define cudaGetDevice hipGetDevice
|
#define cudaGetDevice hipGetDevice
|
||||||
#define cudaGetDeviceCount hipGetDeviceCount
|
#define cudaGetDeviceCount hipGetDeviceCount
|
||||||
|
@ -60,7 +55,6 @@
|
||||||
#define cudaGetErrorString hipGetErrorString
|
#define cudaGetErrorString hipGetErrorString
|
||||||
#define cudaGetLastError hipGetLastError
|
#define cudaGetLastError hipGetLastError
|
||||||
#define cudaMalloc hipMalloc
|
#define cudaMalloc hipMalloc
|
||||||
#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
|
|
||||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||||
#define cudaMemcpy hipMemcpy
|
#define cudaMemcpy hipMemcpy
|
||||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||||
|
@ -187,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
do { \
|
do { \
|
||||||
cudaError_t err_ = (err); \
|
cudaError_t err_ = (err); \
|
||||||
if (err_ != cudaSuccess) { \
|
if (err_ != cudaSuccess) { \
|
||||||
int dev_id; \
|
int id; \
|
||||||
cudaGetDevice(&dev_id); \
|
cudaGetDevice(&id); \
|
||||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||||
cudaGetErrorString(err_)); \
|
cudaGetErrorString(err_)); \
|
||||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
fprintf(stderr, "current device: %d\n", id); \
|
||||||
exit(1); \
|
exit(1); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
@ -201,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
do { \
|
do { \
|
||||||
cublasStatus_t err_ = (err); \
|
cublasStatus_t err_ = (err); \
|
||||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||||
int dev_id; \
|
int id; \
|
||||||
cudaGetDevice(&dev_id); \
|
cudaGetDevice(&id); \
|
||||||
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
|
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
|
||||||
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
|
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
|
||||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
fprintf(stderr, "current device: %d\n", id); \
|
||||||
exit(1); \
|
exit(1); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
@ -471,7 +465,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
|
||||||
|
|
||||||
#define MAX_STREAMS 8
|
#define MAX_STREAMS 8
|
||||||
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
|
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
|
||||||
static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
|
|
||||||
|
|
||||||
struct ggml_tensor_extra_gpu {
|
struct ggml_tensor_extra_gpu {
|
||||||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
||||||
|
@ -5777,16 +5770,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
|
|
||||||
if (g_cudaMemPools[id] == nullptr) {
|
|
||||||
return ggml_cuda_pool_malloc(size, actual_size);
|
|
||||||
}
|
|
||||||
void *ptr;
|
|
||||||
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
|
|
||||||
*actual_size = size;
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
int id;
|
int id;
|
||||||
|
@ -5805,13 +5788,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
|
|
||||||
if (g_cudaMemPools[id] == nullptr) {
|
|
||||||
return ggml_cuda_pool_free(ptr, actual_size);
|
|
||||||
}
|
|
||||||
CUDA_CHECK(cudaFreeAsync(ptr, stream));
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_init_cublas() {
|
void ggml_init_cublas() {
|
||||||
static bool initialized = false;
|
static bool initialized = false;
|
||||||
|
|
||||||
|
@ -5858,13 +5834,6 @@ void ggml_init_cublas() {
|
||||||
// create cublas handle
|
// create cublas handle
|
||||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||||
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
|
|
||||||
// configure memory pool
|
|
||||||
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
|
||||||
if (err == cudaSuccess) {
|
|
||||||
size_t treshold = UINT64_MAX;
|
|
||||||
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
|
@ -6458,7 +6427,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = row_diff*ne00;
|
size_t ne = row_diff*ne00;
|
||||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
|
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
|
||||||
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
||||||
|
@ -6469,12 +6438,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = src1_ncols*ne10;
|
size_t ne = src1_ncols*ne10;
|
||||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
||||||
size_t dst_f16_as = 0;
|
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
size_t dst_as = 0;
|
||||||
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
const half alpha_f16 = 1.0f;
|
||||||
const half beta_f16 = 0.0f;
|
const half beta_f16 = 0.0f;
|
||||||
|
@ -6492,15 +6462,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
|
|
||||||
if (dst_f16_as != 0) {
|
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||||
ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (src0_as != 0) {
|
if (src0_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
|
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src1_as != 0) {
|
if (src1_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
|
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -6510,7 +6479,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
if (src0->type != GGML_TYPE_F32) {
|
if (src0->type != GGML_TYPE_F32) {
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
|
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
||||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||||
}
|
}
|
||||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||||
|
@ -6527,7 +6496,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
&beta, dst_dd_i, ldc));
|
&beta, dst_dd_i, ldc));
|
||||||
|
|
||||||
if (src0_as != 0) {
|
if (src0_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
|
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6950,22 +6919,21 @@ static void ggml_cuda_op_mul_mat(
|
||||||
src0_dd[id] = (char *) src0_extra->data_device[id];
|
src0_dd[id] = (char *) src0_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
||||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
|
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src1_on_device && src1_is_contiguous) {
|
if (src1_on_device && src1_is_contiguous) {
|
||||||
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
|
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (convert_src1_to_q8_1) {
|
if (convert_src1_to_q8_1) {
|
||||||
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
|
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
|
||||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
|
|
||||||
|
|
||||||
if (src1_on_device && src1_is_contiguous) {
|
if (src1_on_device && src1_is_contiguous) {
|
||||||
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
||||||
// CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6973,7 +6941,7 @@ static void ggml_cuda_op_mul_mat(
|
||||||
dst_dd[id] = (float *) dst_extra->data_device[id];
|
dst_dd[id] = (float *) dst_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
||||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
|
dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7099,6 +7067,24 @@ static void ggml_cuda_op_mul_mat(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||||
|
CUDA_CHECK(ggml_cuda_set_device(id));
|
||||||
|
|
||||||
|
// free buffers again when done
|
||||||
|
if (src0_as[id] > 0) {
|
||||||
|
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
|
||||||
|
}
|
||||||
|
if (src1_asf[id] > 0) {
|
||||||
|
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
|
||||||
|
}
|
||||||
|
if (src1_asq[id] > 0) {
|
||||||
|
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
|
||||||
|
}
|
||||||
|
if (dst_as[id] > 0) {
|
||||||
|
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// main device waits for all other devices to be finished
|
// main device waits for all other devices to be finished
|
||||||
if (split && g_device_count > 1) {
|
if (split && g_device_count > 1) {
|
||||||
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
|
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
|
||||||
|
@ -7116,21 +7102,6 @@ static void ggml_cuda_op_mul_mat(
|
||||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
|
||||||
if (src0_as[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (src1_asf[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (src1_asq[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (dst_as[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
@ -7317,11 +7288,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
|
|
||||||
size_t src1_as = 0;
|
size_t src1_as = 0;
|
||||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
|
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||||
|
|
||||||
size_t dst_as = 0;
|
size_t dst_as = 0;
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
GGML_ASSERT(ne13 % ne03 == 0);
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
@ -7375,8 +7346,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
size_t ptrs_src_s = 0;
|
size_t ptrs_src_s = 0;
|
||||||
size_t ptrs_dst_s = 0;
|
size_t ptrs_dst_s = 0;
|
||||||
|
|
||||||
ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
|
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
|
||||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
|
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
|
||||||
|
|
||||||
dim3 block_dims(ne13, ne12);
|
dim3 block_dims(ne13, ne12);
|
||||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||||
|
@ -7389,6 +7360,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
dst->nb[2], dst->nb[3],
|
dst->nb[2], dst->nb[3],
|
||||||
r2, r3);
|
r2, r3);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
|
@ -7400,22 +7372,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
if (ptrs_src_s != 0) {
|
if (ptrs_src_s != 0) {
|
||||||
ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
|
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||||
}
|
}
|
||||||
if (ptrs_dst_s != 0) {
|
if (ptrs_dst_s != 0) {
|
||||||
ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
|
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||||
if (src1_as != 0) {
|
|
||||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
|
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||||
}
|
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||||
if (dst_as != 0) {
|
|
||||||
ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
|
|
@ -393,6 +393,7 @@ class TensorNameMap:
|
||||||
"layers.{bid}.attention_norm", # llama-pth
|
"layers.{bid}.attention_norm", # llama-pth
|
||||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||||
|
"model.layers.{bid}.ln1", # yi
|
||||||
),
|
),
|
||||||
|
|
||||||
# Attention norm 2
|
# Attention norm 2
|
||||||
|
@ -464,6 +465,7 @@ class TensorNameMap:
|
||||||
"layers.{bid}.ffn_norm", # llama-pth
|
"layers.{bid}.ffn_norm", # llama-pth
|
||||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||||
|
"model.layers.{bid}.ln2", # yi
|
||||||
),
|
),
|
||||||
|
|
||||||
# Feed-forward up
|
# Feed-forward up
|
||||||
|
|
11
llama.cpp
11
llama.cpp
|
@ -5196,11 +5196,12 @@ static int llama_decode_internal(
|
||||||
|
|
||||||
// If all tensors can be run on the GPU then using more than 1 thread is detrimental.
|
// If all tensors can be run on the GPU then using more than 1 thread is detrimental.
|
||||||
const bool full_offload_supported =
|
const bool full_offload_supported =
|
||||||
model.arch == LLM_ARCH_LLAMA ||
|
model.arch == LLM_ARCH_LLAMA ||
|
||||||
model.arch == LLM_ARCH_BAICHUAN ||
|
model.arch == LLM_ARCH_BAICHUAN ||
|
||||||
model.arch == LLM_ARCH_FALCON ||
|
model.arch == LLM_ARCH_FALCON ||
|
||||||
model.arch == LLM_ARCH_REFACT ||
|
model.arch == LLM_ARCH_REFACT ||
|
||||||
model.arch == LLM_ARCH_MPT;
|
model.arch == LLM_ARCH_MPT ||
|
||||||
|
model.arch == LLM_ARCH_STARCODER;
|
||||||
|
|
||||||
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
|
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
|
||||||
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
|
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue