backend : offload large batches to GPU (#6083)
* backend : offload large batches to GPU
* fix hip
* code cleanup
* fix CUDA split buffers
* Update ggml-backend-impl.h
  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* cuda : fix memset without set_device
* imatrix : remove sched affix from weight names
* sched : add a new split if the current one has too many inputs
  reduce max inputs per split
  more cleanup
* update backends
  ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
parent 496bc79bc2
commit 2bf8d0f7c4
14 changed files with 349 additions and 396 deletions
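The central mechanism of this change is the new offload_op backend hook: instead of requiring a tensor's weights to already be resident in VRAM, the scheduler asks the backend whether an operation is worth offloading, and the CUDA backend answers based on the batch size (see the new ggml_backend_cuda_offload_op further down in the diff, wired into ggml_backend_cuda_interface as .offload_op). Below is a minimal, self-contained sketch of that policy; toy_tensor and should_offload_to_gpu are illustrative stand-ins, not ggml symbols.

    // Sketch of the batch-size offload policy added by this PR. Assumption: toy types
    // stand in for ggml_tensor; the real check is ggml_backend_cuda_offload_op in the diff.
    #include <cstdint>
    #include <cstdio>

    enum toy_op { TOY_OP_MUL_MAT, TOY_OP_GET_ROWS };

    struct toy_tensor {
        toy_op  op;
        int64_t ne[4]; // ne[1] = number of columns/tokens in the batch
    };

    static bool should_offload_to_gpu(const toy_tensor & t) {
        const int min_batch_size = 32; // same threshold as in the diff
        return t.ne[1] >= min_batch_size && t.op != TOY_OP_GET_ROWS;
    }

    int main() {
        toy_tensor small = { TOY_OP_MUL_MAT, {4096,   8, 1, 1} };
        toy_tensor large = { TOY_OP_MUL_MAT, {4096, 512, 1, 1} };
        std::printf("small batch offloaded: %d\n", should_offload_to_gpu(small)); // 0
        std::printf("large batch offloaded: %d\n", should_offload_to_gpu(large)); // 1
        return 0;
    }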
ggml-cuda.cu (297 changes)
@@ -82,6 +82,10 @@
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged
@@ -7787,11 +7791,7 @@ struct cuda_pool_alloc {
static bool g_cublas_loaded = false;
GGML_CALL bool ggml_cublas_loaded(void) {
return g_cublas_loaded;
}
GGML_CALL void ggml_init_cublas() {
static void ggml_init_cublas() {
static bool initialized = false;
if (!initialized) {
@@ -7880,7 +7880,7 @@ GGML_CALL void ggml_init_cublas() {
}
}
GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
static void * ggml_cuda_host_malloc(size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
return nullptr;
}
@@ -7890,7 +7890,7 @@ GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
size/1024.0/1024.0, cudaGetErrorString(err));
return nullptr;
}
@@ -7898,7 +7898,7 @@ GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
return ptr;
}
GGML_CALL void ggml_cuda_host_free(void * ptr) {
static void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
}
@@ -9036,21 +9036,13 @@ static void ggml_cuda_op_soft_max(
// positions tensor
float * src2_dd = nullptr;
cuda_pool_alloc<float> src2_f;
ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr;
if (use_src2) {
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
if (src2_on_device) {
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
src2_dd = (float *) src2_extra->data_device[g_main_device];
} else {
src2_dd = src2_f.alloc(ggml_nelements(src2));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
}
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
src2_dd = (float *) src2_extra->data_device[g_main_device];
}
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
@@ -9107,55 +9099,24 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;
cuda_pool_alloc<float> src0_f;
cuda_pool_alloc<float> src1_f;
cuda_pool_alloc<float> dst_f;
ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (src0_on_device) {
src0_ddf = (float *) src0_extra->data_device[g_main_device];
} else {
src0_ddf = src0_f.alloc(ggml_nelements(src0));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}
src0_ddf = (float *) src0_extra->data_device[g_main_device];
if (use_src1) {
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
src1_ddf = src1_f.alloc(ggml_nelements(src1));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
}
}
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
dst_ddf = dst_f.alloc(ggml_nelements(dst));
src1_ddf = (float *) src1_extra->data_device[g_main_device];
}
dst_ddf = (float *) dst_extra->data_device[g_main_device];
// do the computation
op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
CUDA_CHECK(cudaGetLastError());
// copy dst to host if necessary
if (!dst_on_device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
}
static void ggml_cuda_set_peer_access(const int n_tokens) {
@@ -9251,7 +9212,6 @@ static void ggml_cuda_op_mul_mat(
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1);
@@ -9322,13 +9282,13 @@ static void ggml_cuda_op_mul_mat(
used_devices++;
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool src1_on_device = id == g_main_device; // TODO: check from buffer
const bool dst_on_device = id == g_main_device;
ggml_cuda_set_device(id);
cudaStream_t stream = g_cudaStreams[id][0];
if (src0_on_device && src0_is_contiguous) {
if (src0_is_contiguous) {
dev[id].src0_dd = (char *) src0_extra->data_device[id];
} else {
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0));
@@ -9374,8 +9334,8 @@ static void ggml_cuda_op_mul_mat(
continue;
}
const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
const bool src1_on_device = id == g_main_device; // TODO: check from buffer
const bool dst_on_device = id == g_main_device;
const int64_t row_diff = dev[id].row_high - dev[id].row_low;
ggml_cuda_set_device(id);
@@ -9400,12 +9360,12 @@ static void ggml_cuda_op_mul_mat(
// the main device memory buffer can be on VRAM scratch, with space for all partial results
// in that case an offset on dst_ddf_i is needed
if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device) {
if (id == g_main_device) {
dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
}
// copy src0, src1 to device if necessary
if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) {
if (src1_is_contiguous) {
if (id != g_main_device) {
if (convert_src1_to_q8_1) {
char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
@@ -9418,19 +9378,19 @@ static void ggml_cuda_op_mul_mat(
src1_ncols*ne10*sizeof(float), stream));
}
}
} else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) {
} else if (src1_on_device && !src1_is_contiguous) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
} else {
GGML_ASSERT(false);
}
if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) {
if (convert_src1_to_q8_1 && !src1_is_contiguous) {
quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
}
if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
}
@@ -9441,17 +9401,7 @@ static void ggml_cuda_op_mul_mat(
// copy dst to host or other device if necessary
if (!dst_on_device) {
void * dst_off_device;
cudaMemcpyKind kind;
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
dst_off_device = dst->data;
kind = cudaMemcpyDeviceToHost;
} else if (dst->backend == GGML_BACKEND_TYPE_GPU) {
dst_off_device = dst_extra->data_device[g_main_device];
kind = cudaMemcpyDeviceToDevice;
} else {
GGML_ASSERT(false);
}
void * dst_off_device = dst_extra->data_device[g_main_device];
if (split) {
// src0 = weight matrix is saved as a transposed matrix for better memory layout.
// dst is NOT transposed.
@@ -9462,28 +9412,26 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
#if !defined(GGML_USE_HIPBLAS)
if (kind == cudaMemcpyDeviceToDevice) {
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {};
p.dstDevice = g_main_device;
p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
p.srcDevice = id;
p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
} else
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {};
p.dstDevice = g_main_device;
p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
p.srcDevice = id;
p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
#else
// HIP does not support cudaMemcpy3DPeerAsync or vmm pools
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
dst_dd_i, row_diff*sizeof(float),
row_diff*sizeof(float), src1_ncols,
cudaMemcpyDeviceToDevice, stream));
#endif
{
CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
dst_dd_i, row_diff*sizeof(float),
row_diff*sizeof(float), src1_ncols,
kind, stream));
}
} else {
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0;
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream));
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
}
}
@@ -9510,11 +9458,6 @@ static void ggml_cuda_op_mul_mat(
}
}
}
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
ggml_cuda_set_device(g_main_device);
CUDA_CHECK(cudaDeviceSynchronize());
}
}
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9599,36 +9542,19 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;
cuda_pool_alloc<float> dst_f;
ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
dst_ddf = dst_f.alloc(ggml_nelements(dst));
}
dst_ddf = (float *) dst_extra->data_device[g_main_device];
// do the computation
ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
CUDA_CHECK(cudaGetLastError());
// copy dst to host if necessary
if (!dst_on_device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
}
static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9639,21 +9565,6 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
if (!g_cublas_loaded) return false;
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}
static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
@@ -9891,11 +9802,6 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool all_on_device =
(src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) &&
(src1->backend == GGML_BACKEND_TYPE_GPU) &&
( dst->backend == GGML_BACKEND_TYPE_GPU);
const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
int64_t min_compute_capability = INT_MAX;
@@ -9972,13 +9878,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ single-batch
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
} else if (!split && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// KQ + KQV multi-batch
ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
@@ -10178,6 +10084,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_cuda_mul_mat_id_cublas(dst);
// TODO: mmq/mmv support
#endif
cudaStream_t stream = g_cudaStreams[g_main_device][0];
const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];
@@ -10187,16 +10094,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
const int32_t n_as = ((int32_t *) dst->op_params)[1];
std::vector<char> ids_host(ggml_nbytes(ids));
cudaStream_t stream = g_cudaStreams[g_main_device][0];
if (ids->backend == GGML_BACKEND_TYPE_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
} else {
memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
}
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
@@ -10213,20 +10113,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
char * src1_original = (char *) src1_extra->data_device[g_main_device];
char * dst_original = (char *) dst_extra->data_device[g_main_device];
if (src1->ne[1] == 1) {
GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
//CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
@@ -10248,11 +10139,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_TYPE_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_TYPE_CPU ?
cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
@@ -10267,7 +10153,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, src1_kind, stream));
nb11, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
@@ -10299,15 +10185,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, dst_kind, stream));
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
}
}
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaStreamSynchronize(stream));
}
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10435,7 +10317,7 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl
return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
}
GGML_CALL static void ggml_cuda_set_main_device(const int main_device) {
static void ggml_cuda_set_main_device(const int main_device) {
if (main_device >= g_device_count) {
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
main_device, g_device_count, g_main_device);
@@ -10450,18 +10332,9 @@ GGML_CALL static void ggml_cuda_set_main_device(const int main_device) {
}
}
GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
static bool ggml_cuda_compute_forward(struct ggml_tensor * tensor) {
if (!g_cublas_loaded) return false;
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
return false;
}
if (tensor->op == GGML_OP_MUL_MAT) {
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
#ifndef NDEBUG
@@ -10471,6 +10344,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
}
}
ggml_cuda_func_t func;
switch (tensor->op) {
case GGML_OP_REPEAT:
func = ggml_cuda_repeat;
@@ -10548,15 +10423,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
func = ggml_cuda_rms_norm;
break;
case GGML_OP_MUL_MAT:
if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
return false;
}
func = ggml_cuda_mul_mat;
break;
case GGML_OP_MUL_MAT_ID:
if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
return false;
}
func = ggml_cuda_mul_mat_id;
break;
case GGML_OP_SCALE:
@@ -10613,17 +10482,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
ggml_cuda_set_peer_access(tensor->src[1]->ne[1]);
}
if (params->ith != 0) {
return true;
}
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return true;
}
func(tensor->src[0], tensor->src[1], tensor);
return true;
}
GGML_CALL int ggml_cuda_get_device_count() {
static int ggml_cuda_get_device_count() {
int device_count;
if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
return 0;
@@ -10631,7 +10494,7 @@ GGML_CALL int ggml_cuda_get_device_count() {
return device_count;
}
GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size) {
static void ggml_cuda_get_device_description(int device, char * description, size_t description_size) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
snprintf(description, description_size, "%s", prop.name);
@@ -10736,6 +10599,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
if (padded_size > original_size && tensor->view_src == nullptr) {
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
}
}
@@ -10873,6 +10737,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
};
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
ggml_init_cublas();
// FIXME: this is not thread safe
if (device >= ggml_backend_cuda_get_device_count()) {
return nullptr;
@@ -11157,6 +11023,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface
};
GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
ggml_init_cublas();
// FIXME: this is not thread safe
static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
@@ -11348,9 +11216,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
ggml_cuda_set_main_device(cuda_ctx->device);
ggml_compute_params params = {};
params.type = GGML_TASK_TYPE_COMPUTE;
params.ith = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -11372,7 +11237,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}
#endif
bool ok = ggml_cuda_compute_forward(&params, node);
bool ok = ggml_cuda_compute_forward(node);
if (!ok) {
fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
@@ -11509,6 +11374,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
UNUSED(backend);
}
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
UNUSED(backend);
}
static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -11571,6 +11444,7 @@ static ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .supports_op = */ ggml_backend_cuda_supports_op,
/* .offload_op = */ ggml_backend_cuda_offload_op,
/* .event_new = */ ggml_backend_cuda_event_new,
/* .event_free = */ ggml_backend_cuda_event_free,
/* .event_record = */ ggml_backend_cuda_event_record,
@@ -11584,7 +11458,7 @@ static ggml_guid_t ggml_backend_cuda_guid() {
}
GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
ggml_init_cublas(); // TODO: remove from ggml.c
ggml_init_cublas();
if (device < 0 || device >= ggml_cuda_get_device_count()) {
fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
@@ -11627,6 +11501,31 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
CUDA_CHECK(cudaMemGetInfo(free, total));
}
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
return false;
}
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__,
size/1024.0/1024.0, cudaGetErrorString(err));
return false;
}
return true;
}
GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
cudaError_t err = cudaHostUnregister(buffer);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
}
}
// backend registry
GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
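The final hunk also introduces ggml_backend_cuda_register_host_buffer / ggml_backend_cuda_unregister_host_buffer, which pin an existing host buffer so uploads of large batches can take the faster pinned-memory copy path. A hedged usage sketch follows; it calls the CUDA runtime directly to mirror what those helpers do, and the buffer size and surrounding main() are illustrative only.

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    int main() {
        // stand-in for a large host-side buffer (e.g. a memory-mapped model file)
        std::vector<char> weights(64 * 1024 * 1024);

        // pin the buffer; portable + read-only matches the flags used in the diff
        cudaError_t err = cudaHostRegister(weights.data(), weights.size(),
                                           cudaHostRegisterPortable | cudaHostRegisterReadOnly);
        if (err != cudaSuccess) {
            cudaGetLastError(); // clear the error and fall back to pageable copies
            std::fprintf(stderr, "warning: failed to pin host buffer: %s\n", cudaGetErrorString(err));
            return 1;
        }

        // ... cudaMemcpyAsync() from weights.data() now uses the pinned-memory path ...

        cudaHostUnregister(weights.data());
        return 0;
    }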