If cuda device doesnt support memory pool than use old implementation.
This commit is contained in:
parent
08868a4474
commit
7e6f41327a
1 changed files with 68 additions and 51 deletions
115
ggml-cuda.cu
115
ggml-cuda.cu
|
@ -5719,16 +5719,6 @@ struct cuda_buffer {
|
||||||
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
|
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
|
||||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc_async(size_t size, int id, cudaStream_t stream) {
|
|
||||||
void* ptr;
|
|
||||||
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
|
|
||||||
return ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_cuda_pool_free_async(void * ptr, cudaStream_t stream) {
|
|
||||||
CUDA_CHECK(cudaFreeAsync(ptr, stream));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
int id;
|
int id;
|
||||||
|
@ -5783,6 +5773,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
|
||||||
|
if (g_cudaMemPools[id] == nullptr) {
|
||||||
|
return ggml_cuda_pool_malloc(size, actual_size);
|
||||||
|
}
|
||||||
|
void *ptr;
|
||||||
|
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
|
||||||
|
*actual_size = size;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
int id;
|
int id;
|
||||||
|
@ -5801,6 +5801,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
|
||||||
|
if (g_cudaMemPools[id] == nullptr) {
|
||||||
|
return ggml_cuda_pool_free(ptr, actual_size);
|
||||||
|
}
|
||||||
|
CUDA_CHECK(cudaFreeAsync(ptr, stream));
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_init_cublas() {
|
void ggml_init_cublas() {
|
||||||
static bool initialized = false;
|
static bool initialized = false;
|
||||||
|
|
||||||
|
@ -5857,10 +5864,12 @@ void ggml_init_cublas() {
|
||||||
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
|
|
||||||
// configure memory pool
|
// configure memory pool
|
||||||
CUDA_CHECK(cudaDeviceGetMemPool(&g_cudaMemPools[id], id));
|
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
||||||
|
if (err == cudaSuccess) {
|
||||||
size_t treshold = UINT64_MAX;
|
size_t treshold = UINT64_MAX;
|
||||||
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
|
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
|
||||||
|
@ -6453,8 +6462,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = row_diff*ne00;
|
size_t ne = row_diff*ne00;
|
||||||
src0_as = ne * sizeof(half);
|
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
|
||||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(src0_as, id, stream);
|
|
||||||
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
||||||
|
@ -6465,13 +6473,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
size_t ne = src1_ncols*ne10;
|
size_t ne = src1_ncols*ne10;
|
||||||
src1_as = ne * sizeof(half);
|
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
||||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(src1_as, id, stream);
|
|
||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
||||||
|
size_t dst_f16_as = 0;
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), id, stream);
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
const half alpha_f16 = 1.0f;
|
||||||
const half beta_f16 = 0.0f;
|
const half beta_f16 = 0.0f;
|
||||||
|
@ -6489,12 +6496,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
|
|
||||||
ggml_cuda_pool_free_async(dst_f16, stream);
|
if (dst_f16_as != 0) {
|
||||||
|
ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
|
||||||
|
}
|
||||||
|
|
||||||
if (src0_as != 0) {
|
if (src0_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src0_as_f16, stream);
|
ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
|
||||||
}
|
}
|
||||||
if (src1_as != 0) {
|
if (src1_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src1_as_f16, stream);
|
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -6504,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
if (src0->type != GGML_TYPE_F32) {
|
if (src0->type != GGML_TYPE_F32) {
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), id, stream); // NOLINT
|
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
|
||||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||||
}
|
}
|
||||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||||
|
@ -6521,7 +6531,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||||
&beta, dst_dd_i, ldc));
|
&beta, dst_dd_i, ldc));
|
||||||
|
|
||||||
if (src0_as != 0) {
|
if (src0_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src0_ddq_as_f32, stream);
|
ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6944,21 +6954,18 @@ static void ggml_cuda_op_mul_mat(
|
||||||
src0_dd[id] = (char *) src0_extra->data_device[id];
|
src0_dd[id] = (char *) src0_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
||||||
src0_as[id] = ggml_nbytes(src0);
|
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
|
||||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(src0_as[id], id, stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src1_on_device && src1_is_contiguous) {
|
if (src1_on_device && src1_is_contiguous) {
|
||||||
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
|
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
|
||||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), id, stream);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (convert_src1_to_q8_1) {
|
if (convert_src1_to_q8_1) {
|
||||||
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
|
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
|
||||||
src1_asq[id] = size_dst_ddq;
|
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
|
||||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, id, stream);
|
|
||||||
|
|
||||||
if (src1_on_device && src1_is_contiguous) {
|
if (src1_on_device && src1_is_contiguous) {
|
||||||
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
||||||
|
@ -6970,8 +6977,7 @@ static void ggml_cuda_op_mul_mat(
|
||||||
dst_dd[id] = (float *) dst_extra->data_device[id];
|
dst_dd[id] = (float *) dst_extra->data_device[id];
|
||||||
} else {
|
} else {
|
||||||
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
||||||
dst_as[id] = size_dst_ddf;
|
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
|
||||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, id, stream);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7107,18 +7113,6 @@ static void ggml_cuda_op_mul_mat(
|
||||||
for (int64_t is = 0; is < is_max; ++is) {
|
for (int64_t is = 0; is < is_max; ++is) {
|
||||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
|
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
|
||||||
}
|
}
|
||||||
if (src0_as[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src0_dd[id], g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (src1_asf[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src1_ddf[id], g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (src1_asq[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(src1_ddq[id], g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
if (dst_as[id] > 0) {
|
|
||||||
ggml_cuda_pool_free_async(dst_dd[id], g_cudaStreams[id][0]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7126,6 +7120,21 @@ static void ggml_cuda_op_mul_mat(
|
||||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||||
|
if (src0_as[id] > 0) {
|
||||||
|
ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
|
||||||
|
}
|
||||||
|
if (src1_asf[id] > 0) {
|
||||||
|
ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
|
||||||
|
}
|
||||||
|
if (src1_asq[id] > 0) {
|
||||||
|
ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
|
||||||
|
}
|
||||||
|
if (dst_as[id] > 0) {
|
||||||
|
ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
@ -7311,10 +7320,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||||
|
|
||||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), id, main_stream);
|
size_t src1_as = 0;
|
||||||
|
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
|
||||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||||
|
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), id, main_stream);
|
size_t dst_as = 0;
|
||||||
|
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
|
||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
GGML_ASSERT(ne13 % ne03 == 0);
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
@ -7362,7 +7373,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
// use cublasGemmBatchedEx
|
// use cublasGemmBatchedEx
|
||||||
const int ne23 = ne12*ne13;
|
const int ne23 = ne12*ne13;
|
||||||
// allocate device memory for pointers
|
// allocate device memory for pointers
|
||||||
void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), id, main_stream);
|
size_t ptrs_s = 0;
|
||||||
|
void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
|
||||||
|
|
||||||
dim3 block_dims(ne13, ne12);
|
dim3 block_dims(ne13, ne12);
|
||||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||||
|
@ -7386,15 +7398,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_16F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
// free device memory for pointers
|
// free device memory for pointers
|
||||||
ggml_cuda_pool_free_async(ptrs_as, main_stream);
|
if (ptrs_s != 0) {
|
||||||
|
ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||||
|
if (src1_as != 0) {
|
||||||
ggml_cuda_pool_free_async(src1_as_f16, main_stream);
|
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
|
||||||
ggml_cuda_pool_free_async(dst_f16, main_stream);
|
}
|
||||||
|
if (dst_as != 0) {
|
||||||
|
ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue