commit 04bd3ef801
Author: Johannes Gäßler
Date:   2024-01-12 16:40:14 +02:00 (committed by GitHub)


@@ -325,7 +325,7 @@ typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * s
 typedef void (*ggml_cuda_op_mul_mat_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream);
+    const int64_t src1_padded_row_size, const int64_t is, cudaStream_t stream);
 typedef void (*ggml_cuda_op_flatten_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream);
@@ -541,11 +541,16 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
-#define MAX_STREAMS 8
+#define MAX_STREAMS 4
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    int64_t is;
+    int64_t is_branch;
+    bool data_constant;
+    cudaEvent_t src0_done;
+    cudaEvent_t src1_done;
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
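
Note on the new fields, as a reviewer sketch rather than as part of the patch: is holds the index of the stream a tensor's data is produced on, is_branch is the next index to hand out to a consumer, data_constant marks tensors (e.g. weights) whose reads never advance is_branch, and src0_done/src1_done are per-tensor events used to fence inputs that arrive from another stream. The ops below all resolve their stream with the same dst -> src0 -> src1 fallback, which could be factored into a helper (hypothetical name):

    // Hypothetical helper illustrating the fallback used verbatim at several
    // call sites in this patch: prefer dst, then src0, then src1.
    static int64_t tensor_stream_index(const ggml_tensor_extra_gpu * dst_extra,
                                       const ggml_tensor_extra_gpu * src0_extra,
                                       const ggml_tensor_extra_gpu * src1_extra) {
        if (dst_extra  != nullptr) { return dst_extra->is;  }
        if (src0_extra != nullptr) { return src0_extra->is; }
        if (src1_extra != nullptr) { return src1_extra->is; }
        return 0; // default stream
    }
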
@@ -7278,10 +7283,10 @@ struct ggml_cuda_buffer {
     size_t size = 0;
 };
 
-static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
-static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
+static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_STREAMS][MAX_CUDA_BUFFERS];
+static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = {0};
 
-static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_leg(int device, int stream, size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 #ifdef DEBUG_CUDA_MALLOC
     int nnz = 0;
@@ -7290,7 +7295,7 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
     size_t best_diff = 1ull << 36;
     int ibest = -1;
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][stream][i];
         if (b.ptr != nullptr) {
 #ifdef DEBUG_CUDA_MALLOC
             ++nnz;
@@ -7313,7 +7318,7 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
         }
     }
     if (ibest >= 0) {
-        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][ibest];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][stream][ibest];
         void * ptr = b.ptr;
         *actual_size = b.size;
         b.ptr = nullptr;
@@ -7326,7 +7331,7 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
     ggml_cuda_set_device(device);
     CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
     *actual_size = look_ahead_size;
-    g_cuda_pool_size[device] += look_ahead_size;
+    g_cuda_pool_size[device][stream] += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
     fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
             (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
@@ -7334,11 +7339,11 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
     return ptr;
 }
 
-static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
+static void ggml_cuda_pool_free_leg(int device, int stream, void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i];
+        ggml_cuda_buffer& b = g_cuda_buffer_pool[device][stream][i];
         if (b.ptr == nullptr) {
             b.ptr = ptr;
             b.size = size;
@@ -7348,23 +7353,23 @@ static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
     ggml_cuda_set_device(device);
     CUDA_CHECK(cudaFree(ptr));
-    g_cuda_pool_size[device] -= size;
+    g_cuda_pool_size[device][stream] -= size;
 }
 
 #if !defined(GGML_USE_HIPBLAS)
 // pool with virtual memory
-static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
-static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
+static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = {0};
+static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = {0};
 static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
-static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_vmm(int device, int stream, size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
     const size_t alignment = 128;
     size = alignment * ((size + alignment - 1) / alignment);
 
-    size_t avail = g_cuda_pool_size[device] - g_cuda_pool_used[device];
+    size_t avail = g_cuda_pool_size[device][stream] - g_cuda_pool_used[device][stream];
 
     if (size > avail) {
         // round up to the next multiple of the granularity
@@ -7372,7 +7377,7 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
         const size_t granularity = g_device_caps[device].vmm_granularity;
         reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
-        GGML_ASSERT(g_cuda_pool_size[device] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+        GGML_ASSERT(g_cuda_pool_size[device][stream] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
 
         // allocate more physical memory
         CUmemAllocationProp prop = {};
@@ -7383,12 +7388,12 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
         CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
 
         // reserve virtual address space (if not already reserved)
-        if (g_cuda_pool_addr[device] == 0) {
-            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+        if (g_cuda_pool_addr[device][stream] == 0) {
+            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device][stream], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
         }
 
         // map at the end of the pool
-        CU_CHECK(cuMemMap(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, 0, handle, 0));
+        CU_CHECK(cuMemMap(g_cuda_pool_addr[device][stream] + g_cuda_pool_size[device][stream], reserve_size, 0, handle, 0));
 
         // the memory allocation handle is no longer needed after mapping
         CU_CHECK(cuMemRelease(handle));
@@ -7398,21 +7403,21 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
         access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
         access.location.id = device;
         access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, &access, 1));
+        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device][stream] + g_cuda_pool_size[device][stream], reserve_size, &access, 1));
 
         // add to the pool
-        g_cuda_pool_size[device] += reserve_size;
+        g_cuda_pool_size[device][stream] += reserve_size;
 
         //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
         //       id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
         //       (unsigned long long) (reserve_size/1024/1024));
     }
 
-    GGML_ASSERT(g_cuda_pool_addr[device] != 0);
+    GGML_ASSERT(g_cuda_pool_addr[device][stream] != 0);
 
-    void * ptr = (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]);
+    void * ptr = (void *) (g_cuda_pool_addr[device][stream] + g_cuda_pool_used[device][stream]);
     *actual_size = size;
-    g_cuda_pool_used[device] += size;
+    g_cuda_pool_used[device][stream] += size;
 
 #ifdef DEBUG_CUDA_MALLOC
     printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
@@ -7421,32 +7426,32 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
     return ptr;
 }
 
-static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
+static void ggml_cuda_pool_free_vmm(int device, int stream, void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
 #ifdef DEBUG_CUDA_MALLOC
     printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
 #endif
 
-    g_cuda_pool_used[device] -= size;
+    g_cuda_pool_used[device][stream] -= size;
 
     // all deallocations must be in reverse order of the allocations
-    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]));
+    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device][stream] + g_cuda_pool_used[device][stream]));
 }
 
-static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc(int device, int stream, size_t size, size_t * actual_size) {
     if (g_device_caps[device].vmm) {
-        return ggml_cuda_pool_malloc_vmm(device, size, actual_size);
+        return ggml_cuda_pool_malloc_vmm(device, stream, size, actual_size);
     } else {
-        return ggml_cuda_pool_malloc_leg(device, size, actual_size);
+        return ggml_cuda_pool_malloc_leg(device, stream, size, actual_size);
     }
 }
 
-static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
+static void ggml_cuda_pool_free(int device, int stream, void * ptr, size_t size) {
     if (g_device_caps[device].vmm) {
-        ggml_cuda_pool_free_vmm(device, ptr, size);
+        ggml_cuda_pool_free_vmm(device, stream, ptr, size);
     } else {
-        ggml_cuda_pool_free_leg(device, ptr, size);
+        ggml_cuda_pool_free_leg(device, stream, ptr, size);
     }
 }
 
 #else
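
Reviewer note, not part of the patch: splitting the pool per (device, stream) means a block freed on one stream can only be recycled by later work enqueued on that same stream, so buffer reuse never creates an implicit ordering hazard between streams. Illustrative use of the stream-aware API introduced above (device, is, and nbytes are assumed to be in scope):

    size_t actual = 0;
    void * buf = ggml_cuda_pool_malloc(device, is, nbytes, &actual);
    // ... launch kernels on g_cudaStreams[device][is] that read/write buf ...
    ggml_cuda_pool_free(device, is, buf, actual);
    // buf now sits in the [device][is] free list and can only be handed out
    // again to work on that same stream.
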
@@ -7457,24 +7462,26 @@ static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
 template<typename T>
 struct cuda_pool_alloc {
     int device = -1;
+    int stream = -1;
     T * ptr = nullptr;
     size_t actual_size = 0;
 
     // size is in number of elements
-    T * alloc(size_t size) {
+    T * alloc(int stream, size_t size) {
         GGML_ASSERT(ptr == nullptr);
         CUDA_CHECK(cudaGetDevice(&device));
-        ptr = (T *) ggml_cuda_pool_malloc(device, size * sizeof(T), &this->actual_size);
+        this->stream = stream;
+        ptr = (T *) ggml_cuda_pool_malloc(device, stream, size * sizeof(T), &this->actual_size);
         return ptr;
     }
 
-    cuda_pool_alloc(size_t size) {
-        alloc(size);
+    cuda_pool_alloc(int stream, size_t size) {
+        alloc(stream, size);
     }
 
     ~cuda_pool_alloc() {
         if (ptr != nullptr) {
-            ggml_cuda_pool_free(device, ptr, actual_size);
+            ggml_cuda_pool_free(device, stream, ptr, actual_size);
         }
     }
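
Reviewer sketch of how the RAII wrapper is used after this change (mirrors the call sites further down; is and the sizes are assumed to be in scope): the stream index is captured at allocation time so that the destructor returns the block to the matching per-stream free list.

    {
        cuda_pool_alloc<half> tmp(is, row_diff*src1_ncols); // from pool [device][is]
        // ... enqueue kernels on g_cudaStreams[device][is] that use tmp.get() ...
    }   // destructor frees back into the same [device][is] free list
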
@@ -7998,7 +8005,7 @@ static void ggml_cuda_op_rms_norm(
 static void ggml_cuda_op_mul_mat_q(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    const int64_t src1_padded_row_size, const int64_t is, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
@@ -8125,7 +8132,7 @@ static int64_t get_row_rounding(ggml_type type) {
 static void ggml_cuda_op_mul_mat_vec_q(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    const int64_t src1_padded_row_size, const int64_t is, cudaStream_t stream) {
 
     GGML_ASSERT(ggml_nrows(src1) == 1);
@@ -8184,7 +8191,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
 static void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    const int64_t src1_padded_row_size, const int64_t is, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
@@ -8202,7 +8209,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
 
     if (src1_convert_f16) {
-        src1_dfloat = src1_dfloat_a.alloc(ne00);
+        src1_dfloat = src1_dfloat_a.alloc(is, ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                               ne00, 1, sizeof(float), 0, 0,
                               ne00, 1, sizeof(half), 0, 0, stream);
@@ -8260,7 +8267,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
 static void ggml_cuda_op_mul_mat_cublas(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    const int64_t src1_padded_row_size, const int64_t is, cudaStream_t stream) {
 
     GGML_ASSERT(src0_dd_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
@@ -8290,7 +8297,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
         GGML_ASSERT(to_fp16_cuda != nullptr);
         size_t ne = row_diff*ne00;
-        src0_as_f16.alloc(ne);
+        src0_as_f16.alloc(is, ne);
         to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
     }
     const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
@@ -8300,11 +8307,11 @@ static void ggml_cuda_op_mul_mat_cublas(
         const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
         GGML_ASSERT(to_fp16_cuda != nullptr);
         size_t ne = src1_ncols*ne10;
-        src1_as_f16.alloc(ne);
+        src1_as_f16.alloc(is, ne);
         to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
     }
     const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-    cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
+    cuda_pool_alloc<half> dst_f16(is, row_diff*src1_ncols);
 
     const half alpha_f16 = 1.0f;
     const half beta_f16 = 0.0f;
@@ -8328,13 +8335,13 @@ static void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32.alloc(row_diff*ne00);
+            src0_ddq_as_f32.alloc(is, row_diff*ne00);
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
         }
         if (src1->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src1_ddq_as_f32.alloc(src1_ncols*ne10);
+            src1_ddq_as_f32.alloc(is, src1_ncols*ne10);
             to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
         }
@@ -8650,6 +8657,15 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
     const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
 
+    int is = 0;
+    if (dst_extra != nullptr) {
+        is = dst_extra->is;
+    } else if (src0_extra != nullptr) {
+        is = src0_extra->is;
+    } else if (src1_extra != nullptr) {
+        is = src1_extra->is;
+    }
+
     // dd = data device
     float * src0_ddf = nullptr;
     float * src1_ddf = nullptr;
@@ -8660,12 +8676,12 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     cuda_pool_alloc<float> dst_f;
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
 
     if (src0_on_device) {
         src0_ddf = (float *) src0_extra->data_device[g_main_device];
     } else {
-        src0_ddf = src0_f.alloc(ggml_nelements(src0));
+        src0_ddf = src0_f.alloc(is, ggml_nelements(src0));
         CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
     }
@@ -8673,14 +8689,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         if (src1_on_device) {
             src1_ddf = (float *) src1_extra->data_device[g_main_device];
         } else {
-            src1_ddf = src1_f.alloc(ggml_nelements(src1));
+            src1_ddf = src1_f.alloc(is, ggml_nelements(src1));
             CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
         }
     }
     if (dst_on_device) {
         dst_ddf = (float *) dst_extra->data_device[g_main_device];
     } else {
-        dst_ddf = dst_f.alloc(ggml_nelements(dst));
+        dst_ddf = dst_f.alloc(is, ggml_nelements(dst));
     }
 
     // do the computation
@@ -8779,6 +8795,15 @@ static void ggml_cuda_op_mul_mat(
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
+    int64_t is0 = 0;
+    if (dst_extra != nullptr) {
+        is0 = dst_extra->is;
+    } else if (src0_extra != nullptr) {
+        is0 = src0_extra->is;
+    } else if (src1_extra != nullptr) {
+        is0 = src1_extra->is;
+    }
+
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
@@ -8846,22 +8871,22 @@ static void ggml_cuda_op_mul_mat(
         const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
         ggml_cuda_set_device(id);
-        cudaStream_t stream = g_cudaStreams[id][0];
+        cudaStream_t stream = g_cudaStreams[id][is0];
 
         if (src0_on_device && src0_is_contiguous) {
             dev[id].src0_dd = (char *) src0_extra->data_device[id];
         } else {
-            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0));
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(is0, ggml_nbytes(src0));
         }
 
         if (src1_on_device && src1_is_contiguous) {
             dev[id].src1_ddf = (float *) src1_extra->data_device[id];
         } else {
-            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ggml_nelements(src1));
+            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(is0, ggml_nelements(src1));
         }
 
         if (convert_src1_to_q8_1) {
-            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(is0, nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
@@ -8873,7 +8898,7 @@ static void ggml_cuda_op_mul_mat(
             dev[id].dst_dd = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
-            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(size_dst_ddf);
+            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(is0, size_dst_ddf);
         }
     }
@@ -8881,12 +8906,12 @@ static void ggml_cuda_op_mul_mat(
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
         ggml_cuda_set_device(g_main_device);
-        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][is0], g_cudaStreams[g_main_device][is0]));
    }
 
     const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
-        const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
+        const int64_t is = split ? (is0 + src1_col_0/src1_col_stride) % MAX_STREAMS : is0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
 
         for (int id = 0; id < g_device_count; ++id) {
@@ -8902,8 +8927,8 @@ static void ggml_cuda_op_mul_mat(
             cudaStream_t stream = g_cudaStreams[id][is];
 
             // wait for main GPU data if necessary
-            if (split && (id != g_main_device || is != 0)) {
-                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
+            if (split && (id != g_main_device || is != is0)) {
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][is0], 0));
             }
 
             for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
@@ -8956,7 +8981,7 @@ static void ggml_cuda_op_mul_mat(
                 // do the computation
                 op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
+                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, is, stream);
                 CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
@@ -9008,7 +9033,7 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 // add event for the main device to wait on until other device is done
-                if (split && (id != g_main_device || is != 0)) {
+                if (split && (id != g_main_device || is != is0)) {
                     CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
                 }
             }
@@ -9026,7 +9051,7 @@ static void ggml_cuda_op_mul_mat(
                 continue;
             }
             for (int64_t is = 0; is < is_max; ++is) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][is0], src0_extra->events[id][(is0 + is) % MAX_STREAMS], 0));
             }
         }
     }
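
Reviewer note, not part of the patch: the multi-GPU changes above all instantiate the same CUDA idiom, now parameterized by is0: a consumer stream waits on a producer stream via an event, without blocking the host. A self-contained sketch of that idiom (producer_stream and consumer_stream are placeholder variables):

    cudaEvent_t done;
    CUDA_CHECK(cudaEventCreateWithFlags(&done, cudaEventDisableTiming));
    // after the producer's kernels have been enqueued:
    CUDA_CHECK(cudaEventRecord(done, producer_stream));
    // before the consumer's kernels are enqueued; the host is not blocked:
    CUDA_CHECK(cudaStreamWaitEvent(consumer_stream, done, 0));
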
@@ -9143,7 +9168,6 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
     const int64_t ne12 = src1->ne[2];
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -9154,6 +9178,9 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
+    const int64_t is = dst_extra->is;
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
+
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
@@ -9175,7 +9202,6 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t ne12 = src1->ne[2];
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -9189,6 +9215,9 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t row_stride_x = nb01 / sizeof(half);
     const int64_t channel_stride_x = nb02 / sizeof(half);
 
+    const int64_t is = dst_extra->is;
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
+
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
@@ -9228,9 +9257,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t ne_dst = ggml_nelements(dst);
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -9242,12 +9268,16 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
+    const int64_t is = dst_extra->is;
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
+
     // convert src1 to fp16
     cuda_pool_alloc<half> src1_f16_alloc;
     if (src1->type != GGML_TYPE_F16) {
         const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
+        src1_f16_alloc.alloc(is, ne_src1);
         GGML_ASSERT(to_fp16_cuda != nullptr);
         to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
     }
@@ -9273,7 +9303,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const void * beta = &beta_f16;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
+        dst_t = (char *) dst_f16.alloc(is, ne_dst);
 
         nbd2 /= sizeof(float) / sizeof(half);
         nbd3 /= sizeof(float) / sizeof(half);
@@ -9330,8 +9360,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
 
-        cuda_pool_alloc<const void *> ptrs_src(2*ne23);
-        cuda_pool_alloc<      void *> ptrs_dst(1*ne23);
+        cuda_pool_alloc<const void *> ptrs_src(is, 2*ne23);
+        cuda_pool_alloc<      void *> ptrs_dst(is, 1*ne23);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -9525,7 +9555,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(dst);
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
 
     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
@@ -9642,7 +9672,19 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     std::vector<char> ids_host(ggml_nbytes(ids));
 
-    cudaStream_t stream = g_cudaStreams[g_main_device][0];
+    const ggml_tensor_extra_gpu * src0_extra = (const ggml_tensor_extra_gpu *) src0->extra;
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+    int is = 0;
+    if (dst_extra != nullptr) {
+        is = dst_extra->is;
+    } else if (src0_extra != nullptr) {
+        is = src0_extra->is;
+    } else if (src1_extra != nullptr) {
+        is = src1_extra->is;
+    }
+
+    cudaStream_t stream = g_cudaStreams[g_main_device][is];
 
     if (ids->backend == GGML_BACKEND_GPU) {
         const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
@@ -9652,12 +9694,12 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
     }
 
-    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
-    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
-
     ggml_tensor_extra_gpu src1_row_extra;
     ggml_tensor_extra_gpu dst_row_extra;
 
+    src1_row_extra.is = src1_extra->is;
+    dst_row_extra.is = dst_extra->is;
+
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
@@ -9696,8 +9738,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
             ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
         }
     } else {
-        cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
-        cuda_pool_alloc<char>  dst_contiguous(sizeof(float)*ggml_nelements(dst));
+        cuda_pool_alloc<char> src1_contiguous(is, sizeof(float)*ggml_nelements(src1));
+        cuda_pool_alloc<char>  dst_contiguous(is, sizeof(float)*ggml_nelements(dst));
 
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
@@ -9799,10 +9841,20 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     const int64_t nb12 = src1->nb[2];
 
     ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
     const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = dst == nullptr ? nullptr : (ggml_tensor_extra_gpu *) dst->extra;
+
+    int64_t is = 0;
+    if (dst_extra != nullptr) {
+        is = dst_extra->is;
+    } else if (src0_extra != nullptr) {
+        is = src0_extra->is;
+    } else if (src1_extra != nullptr) {
+        is = src1_extra->is;
+    }
+
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][is];
 
     char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
@@ -9887,6 +9939,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
     ggml_backend_type backend = tensor->backend;
     ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
+    extra->data_constant = true;
 
     for (int id = 0; id < g_device_count; ++id) {
         if (backend == GGML_BACKEND_GPU && id != g_main_device) {
@@ -9981,12 +10034,20 @@ static size_t g_temp_tensor_extra_index = 0;
 static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     if (g_temp_tensor_extras == nullptr) {
         g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
+        for (int64_t i = 0; i < GGML_CUDA_MAX_NODES; ++i) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&g_temp_tensor_extras[i].src0_done, cudaEventDisableTiming));
+            CUDA_CHECK(cudaEventCreateWithFlags(&g_temp_tensor_extras[i].src1_done, cudaEventDisableTiming));
+        }
     }
 
     size_t alloc_index = g_temp_tensor_extra_index;
     g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
     ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
-    memset(extra, 0, sizeof(*extra));
+    cudaEvent_t src0_done = extra->src0_done;
+    cudaEvent_t src1_done = extra->src1_done;
+    memset(extra, 0, sizeof(ggml_tensor_extra_gpu));
+    extra->src0_done = src0_done;
+    extra->src1_done = src1_done;
 
     return extra;
 }
@@ -10304,6 +10365,91 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return true;
     }
 
+    if (tensor->src[0] != nullptr) {
+        ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) tensor->src[0]->extra;
+        ggml_tensor_extra_gpu * extra_src1 = tensor->src[1] == nullptr ? nullptr : (ggml_tensor_extra_gpu *) tensor->src[1]->extra;
+        ggml_tensor_extra_gpu * extra_dst = (ggml_tensor_extra_gpu *) tensor->extra;
+
+        if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
+            extra_src1->is = extra_src0->is_branch;
+            extra_src1->is_branch = extra_src0->is_branch;
+            if (tensor->src[1]->op == GGML_OP_VIEW) {
+                ggml_tensor_extra_gpu * extra_src10 = (ggml_tensor_extra_gpu *) tensor->src[1]->src[0]->extra;
+                extra_src10->is = extra_src0->is_branch;
+                extra_src10->is_branch = extra_src0->is_branch;
+            }
+            extra_src0->is_branch = (extra_src0->is_branch + 1) % MAX_STREAMS;
+        } else {
+            // fprintf(stderr, "cpy/dup %s -> %s, is=%ld,%ld \n", tensor->src[0]->name, tensor->src[1]->name, extra_src0->is, extra_src1->is);
+            // fprintf(stderr, "src0: %s\n", tensor->src[0]->name);
+            // fprintf(stderr, "dst: %s\n", tensor->name);
+            bool is_set = false;
+            if (extra_src0 != nullptr && !extra_src0->data_constant) {
+                if (!is_set && extra_dst != nullptr) {
+                    // fprintf(stderr, "is set from src0: %ld -> %ld \n", extra_dst->is_branch, extra_src0->is_branch);
+                    extra_dst->is = extra_src0->is_branch;
+                    extra_dst->is_branch = extra_src0->is_branch;
+                    is_set = true;
+                }
+                extra_src0->is_branch = (extra_src0->is_branch + 1) % MAX_STREAMS;
+            }
+            if (tensor->src[1] != nullptr && tensor->src[1]->extra != nullptr && !((ggml_tensor_extra_gpu *) tensor->src[1]->extra)->data_constant) {
+                if (!is_set && extra_dst != nullptr) {
+                    // fprintf(stderr, "is set from src1: %ld -> %ld \n", extra_dst->is_branch, extra_src1->is_branch);
+                    extra_dst->is = extra_src1->is_branch;
+                    extra_dst->is_branch = extra_src1->is_branch;
+                    is_set = true;
+                }
+                extra_src1->is_branch = (extra_src1->is_branch + 1) % MAX_STREAMS;
+            }
+        }
+
+        int64_t is = 0;
+        if (extra_dst != nullptr) {
+            is = extra_dst->is;
+        } else if (extra_src0 != nullptr) {
+            is = extra_src0->is;
+        } else if (extra_src1 != nullptr) {
+            is = extra_src1->is;
+        }
+
+        if (extra_src0 != nullptr && extra_src0->is != is) {
+            CUDA_CHECK(cudaEventRecord(extra_dst->src0_done, g_cudaStreams[g_main_device][extra_src0->is]));
+            CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][is], extra_dst->src0_done, 0));
+        }
+        if (tensor->src[1] != nullptr && extra_src1 != nullptr && extra_src1->is != is) {
+            CUDA_CHECK(cudaEventRecord(extra_dst->src1_done, g_cudaStreams[g_main_device][extra_src1->is]));
+            CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][extra_dst->is], extra_dst->src1_done, 0));
+        }
+
+        // if (is != 0) {
+        //     fprintf(stderr, "%s + %s -> %s is=%ld\n", tensor->src[0]->name, tensor->src[1] == nullptr ? "null" : tensor->src[1]->name, tensor->name, is);
+        // }
+        // fprintf(stderr, "%s: is=%ld\n", tensor->name, extra_dst->is);
+    }
+
+    // fprintf(stderr, "\n");
+    // if (tensor->src[1] != nullptr) {
+    //     CUDA_CHECK(cudaDeviceSynchronize());
+    // }
+    // char * p0 = (char *) ((ggml_tensor_extra_gpu *) tensor->extra)->data_device;
+    // fprintf(stderr, "%s\n %p - %p\n", tensor->name, p0, p0 + ggml_nbytes(tensor));
+    // if (tensor->src[0] != nullptr && tensor->src[0]->extra != nullptr && ((ggml_tensor_extra_gpu *) tensor->src[0]->extra)->is != 0) {
+    //     fprintf(stderr, "tensor=%s: src0=%s\n", tensor->name, tensor->src[0]->name);
+    // }
+    // if (tensor->src[1] != nullptr && tensor->src[1]->extra != nullptr && ((ggml_tensor_extra_gpu *) tensor->src[1]->extra)->is != 0) {
+    //     fprintf(stderr, "tensor=%s: src1=%s\n", tensor->name, tensor->src[1]->name);
+    // }
+    // if (tensor->extra != nullptr && ((ggml_tensor_extra_gpu *) tensor->extra)->is != 0) {
+    //     fprintf(stderr, "tensor=%s: dst=%s\n", tensor->name, tensor->name);
+    // }
+
     func(tensor->src[0], tensor->src[1], tensor);
 
     return true;
 }
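
Reviewer summary, not part of the patch: the scheduling rule implemented above is that every non-constant tensor hands its current is_branch to its first consumer and then advances the counter, so independent consumers of the same tensor land on different streams; tensors marked data_constant (weights uploaded in ggml_cuda_transform_tensor) never advance it; and when an op's inputs live on different streams, the src0_done/src1_done events join them back together. Roughly:

    //            t (is_branch = 0, not data_constant)
    //           /                                    \
    //   consumer A: is = 0                    consumer B: is = 1
    //   (t.is_branch -> 1 after A)            (t.is_branch -> 2 after B)
    //           \                                    /
    //   combine C: is = 0; src1_done is recorded on stream 1 and
    //   stream 0 waits on it before C's kernels are launched
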