Use `cuda_pool_alloc` (RAII pool allocator) in `ggml_cuda_op_mul_mat`

This commit is contained in:
slaren 2023-12-25 20:44:10 +01:00
parent 865d042d56
commit 32304d796f

View file

@ -6641,10 +6641,8 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
return ptr; return ptr;
} }
static void ggml_cuda_pool_free_leg(void * ptr, size_t size) { static void ggml_cuda_pool_free_leg(int id, void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock); scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i]; cuda_buffer& b = g_cuda_buffer_pool[id][i];
@ -6731,10 +6729,8 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
return ptr; return ptr;
} }
static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) { static void ggml_cuda_pool_free_vmm(int id, void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock); scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC #ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr); printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
@ -6756,13 +6752,11 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
} }
} }
static void ggml_cuda_pool_free(void * ptr, size_t size) { static void ggml_cuda_pool_free(int id, void * ptr, size_t size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) { if (g_device_caps[id].vmm) {
ggml_cuda_pool_free_vmm(ptr, size); ggml_cuda_pool_free_vmm(id, ptr, size);
} else { } else {
ggml_cuda_pool_free_leg(ptr, size); ggml_cuda_pool_free_leg(id, ptr, size);
} }
} }
#else #else
@ -6772,12 +6766,14 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
template<typename T> template<typename T>
struct cuda_pool_alloc { struct cuda_pool_alloc {
int device = -1;
T * ptr = nullptr; T * ptr = nullptr;
size_t actual_size = 0; size_t actual_size = 0;
// size is in number of elements // size is in number of elements
T * alloc(size_t size) { T * alloc(size_t size) {
GGML_ASSERT(ptr == nullptr); GGML_ASSERT(ptr == nullptr);
CUDA_CHECK(cudaGetDevice(&device));
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size); ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
return ptr; return ptr;
} }
@ -6788,7 +6784,7 @@ struct cuda_pool_alloc {
~cuda_pool_alloc() { ~cuda_pool_alloc() {
if (ptr != nullptr) { if (ptr != nullptr) {
ggml_cuda_pool_free(ptr, actual_size); ggml_cuda_pool_free(device, ptr, actual_size);
} }
} }
@ -8068,27 +8064,29 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12)); GGML_ASSERT(!(split && ne02 < ne12));
// dd = data device struct dev_data {
char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; cuda_pool_alloc<char> src0_dd_alloc;
float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float cuda_pool_alloc<float> src1_ddf_alloc;
char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 cuda_pool_alloc<char> src1_ddq_alloc;
float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; cuda_pool_alloc<float> dst_dd_alloc;
// as = actual size char * src0_dd = nullptr;
size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; float * src1_ddf = nullptr; // float
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; char * src1_ddq = nullptr; // q8_1
size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; float * dst_dd = nullptr;
size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0};
int64_t row_low[GGML_CUDA_MAX_DEVICES]; int64_t row_low;
int64_t row_high[GGML_CUDA_MAX_DEVICES]; int64_t row_high;
};
dev_data dev[GGML_CUDA_MAX_DEVICES];
int used_devices = 0; int used_devices = 0;
for (int64_t id = 0; id < g_device_count; ++id) { for (int64_t id = 0; id < g_device_count; ++id) {
// by default, use all rows // by default, use all rows
row_low[id] = 0; dev[id].row_low = 0;
row_high[id] = ne01; dev[id].row_high = ne01;
// for multi GPU, get the row boundaries from tensor split // for multi GPU, get the row boundaries from tensor split
// and round to mul_mat_q tile sizes // and round to mul_mat_q tile sizes
@ -8096,23 +8094,23 @@ static void ggml_cuda_op_mul_mat(
const int64_t rounding = get_row_rounding(src0->type); const int64_t rounding = get_row_rounding(src0->type);
if (id != 0) { if (id != 0) {
row_low[id] = ne01*g_tensor_split[id]; dev[id].row_low = ne01*g_tensor_split[id];
if (row_low[id] < ne01) { if (dev[id].row_low < ne01) {
row_low[id] -= row_low[id] % rounding; dev[id].row_low -= dev[id].row_low % rounding;
} }
} }
if (id != g_device_count - 1) { if (id != g_device_count - 1) {
row_high[id] = ne01*g_tensor_split[id + 1]; dev[id].row_high = ne01*g_tensor_split[id + 1];
if (row_high[id] < ne01) { if (dev[id].row_high < ne01) {
row_high[id] -= row_high[id] % rounding; dev[id].row_high -= dev[id].row_high % rounding;
} }
} }
} }
} }
for (int64_t id = 0; id < g_device_count; ++id) { for (int id = 0; id < g_device_count; ++id) {
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) {
continue; continue;
} }
@ -8122,35 +8120,34 @@ static void ggml_cuda_op_mul_mat(
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
ggml_cuda_set_device(id); ggml_cuda_set_device(id);
const cudaStream_t stream = g_cudaStreams[id][0]; cudaStream_t stream = g_cudaStreams[id][0];
if (src0_on_device && src0_is_contiguous) { if (src0_on_device && src0_is_contiguous) {
src0_dd[id] = (char *) src0_extra->data_device[id]; dev[id].src0_dd = (char *) src0_extra->data_device[id];
} else { } else {
// const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0));
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
} }
if (src1_on_device && src1_is_contiguous) { if (src1_on_device && src1_is_contiguous) {
src1_ddf[id] = (float *) src1_extra->data_device[id]; dev[id].src1_ddf = (float *) src1_extra->data_device[id];
} else { } else {
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ggml_nelements(src1));
} }
if (convert_src1_to_q8_1) { if (convert_src1_to_q8_1) {
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
if (src1_on_device && src1_is_contiguous) { if (src1_on_device && src1_is_contiguous) {
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }
} }
if (dst_on_device) { if (dst_on_device) {
dst_dd[id] = (float *) dst_extra->data_device[id]; dev[id].dst_dd = (float *) dst_extra->data_device[id];
} else { } else {
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); const size_t size_dst_ddf = split ? (dev[id].row_high-dev[id].row_low)*ne1 : ggml_nelements(dst);
dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(size_dst_ddf);
} }
} }
@ -8167,16 +8164,16 @@ static void ggml_cuda_op_mul_mat(
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
for (int64_t id = 0; id < g_device_count; ++id) { for (int64_t id = 0; id < g_device_count; ++id) {
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) {
continue; continue;
} }
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
const int64_t row_diff = row_high[id] - row_low[id]; const int64_t row_diff = dev[id].row_high - dev[id].row_low;
ggml_cuda_set_device(id); ggml_cuda_set_device(id);
const cudaStream_t stream = g_cudaStreams[id][is]; cudaStream_t stream = g_cudaStreams[id][is];
// wait for main GPU data if necessary // wait for main GPU data if necessary
if (split && (id != g_main_device || is != 0)) { if (split && (id != g_main_device || is != 0)) {
@ -8190,22 +8187,22 @@ static void ggml_cuda_op_mul_mat(
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
// for split tensors the data begins at i0 == i0_offset_low // for split tensors the data begins at i0 == i0_offset_low
char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
// the main device memory buffer can be on VRAM scratch, with space for all partial results // the main device memory buffer can be on VRAM scratch, with space for all partial results
// in that case an offset on dst_ddf_i is needed // in that case an offset on dst_ddf_i is needed
if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
dst_dd_i += row_low[id]; // offset is 0 if no tensor split dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
} }
// copy src0, src1 to device if necessary // copy src0, src1 to device if necessary
if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (id != g_main_device) { if (id != g_main_device) {
if (convert_src1_to_q8_1) { if (convert_src1_to_q8_1) {
char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, g_main_device, CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, g_main_device,
src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
} else { } else {
@ -8228,12 +8225,12 @@ static void ggml_cuda_op_mul_mat(
} }
if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
} }
// do the computation // do the computation
op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
// copy dst to host or other device if necessary // copy dst to host or other device if necessary
@ -8257,7 +8254,7 @@ static void ggml_cuda_op_mul_mat(
// If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0 + row_low[id]; dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
#if !defined(GGML_USE_HIPBLAS) #if !defined(GGML_USE_HIPBLAS)
if (kind == cudaMemcpyDeviceToDevice && id != g_main_device) { if (kind == cudaMemcpyDeviceToDevice && id != g_main_device) {
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
@ -8292,27 +8289,6 @@ static void ggml_cuda_op_mul_mat(
} }
} }
for (int64_t id = 0; id < g_device_count; ++id) {
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
continue;
}
CUDA_CHECK(ggml_cuda_set_device(id));
// free buffers again when done
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
}
// main device waits for all other devices to be finished // main device waits for all other devices to be finished
if (split && g_device_count > 1) { if (split && g_device_count > 1) {
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@ -8320,7 +8296,7 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(ggml_cuda_set_device(g_main_device));
for (int64_t id = 0; id < g_device_count; ++id) { for (int64_t id = 0; id < g_device_count; ++id) {
if (row_low[id] == row_high[id]) { if (dev[id].row_low == dev[id].row_high) {
continue; continue;
} }
for (int64_t is = 0; is < is_max; ++is) { for (int64_t is = 0; is < is_max; ++is) {