Simplify and improve CUDA graphs through use of indirect copy pointers
Previously there was complexity in the CUDA graphs implementation due frequently changing parameters to copy kernels associated with K and V cache pointers. This patch simplifies by using indirection to avoid such parameters frequently changing, avoiding the need for frequent graph updates.
This commit is contained in:
parent
0fd93cdef5
commit
a3d48e448a
5 changed files with 118 additions and 135 deletions
|
@ -232,6 +232,9 @@ extern "C" {
|
|||
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
||||
GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
|
||||
|
||||
// Copy K and V cache pointers to backend
|
||||
GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
|
||||
GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -2479,9 +2479,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
|
||||
bool use_cuda_graph = true;
|
||||
bool cuda_graph_update_required = false;
|
||||
// vector of pointers to CUDA cpy kernels, which are required to identify
|
||||
// kernel parameters which need updated in the graph for each token
|
||||
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
|
||||
|
||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
|
||||
|
@ -2527,7 +2524,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
}
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
@ -2554,16 +2550,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_CPY) {
|
||||
// store the copy op parameter which changes with each token.
|
||||
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
|
||||
// store a pointer to each copy op CUDA kernel to identify it later
|
||||
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
|
||||
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
|
||||
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
break;
|
||||
}
|
||||
|
@ -2653,64 +2639,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
|
||||
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
||||
|
||||
if (cuda_graph_update_required) {
|
||||
// Extract nodes from graph
|
||||
// First call with null argument gets number of nodes in graph
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
||||
// Subsequent call with non-null argument gets nodes
|
||||
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||
if (cuda_ctx->cuda_graph->num_nodes > 0) {
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
|
||||
|
||||
// Loop over nodes, and extract kernel parameters from each node
|
||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
cudaGraphNodeType node_type;
|
||||
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
||||
if (node_type == cudaGraphNodeTypeKernel) {
|
||||
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
||||
if (stat == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
cudaGetLastError();
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// One of the arguments to the copy kernel is updated for each token, hence we need to
|
||||
// replace that argument with the updated value in the CUDA graph
|
||||
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
|
||||
int k = 0;
|
||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
|
||||
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
||||
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update graph executable
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
// Update graph executable
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
#ifndef NDEBUG
|
||||
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
|
||||
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
|
||||
#endif
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
|
|
|
@ -583,15 +583,11 @@ struct ggml_cuda_graph {
|
|||
}
|
||||
cudaGraph_t graph = nullptr;
|
||||
cudaGraphExec_t instance = nullptr;
|
||||
size_t num_nodes = 0;
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
std::vector<cudaKernelNodeParams> params;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_due_to_too_many_updates = false;
|
||||
bool disable_due_to_failed_graph_capture = false;
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
std::vector<char **> updated_kernel_arg;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
@ -31,16 +31,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int layer_index) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[layer_index]: cdst_direct;
|
||||
|
||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
|
@ -263,16 +265,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13) {
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int layer_index) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[layer_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
|
@ -288,110 +292,128 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
|||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
static char ** k_cache_ptrs;
|
||||
static char ** v_cache_ptrs;
|
||||
|
||||
static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char ** host_cache_ptrs, size_t size) {
|
||||
if(backend_cache_ptrs == nullptr) {
|
||||
cudaMalloc(&backend_cache_ptrs, size*sizeof(char *));
|
||||
}
|
||||
cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
|
||||
ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size);
|
||||
}
|
||||
|
||||
void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
|
||||
ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q8_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int layer_index) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, layer_index);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
|
@ -428,26 +450,41 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
// If this copy is associated with K or V caches, then use indirection of destination pointers to avoid
|
||||
// the kernel params changing for each token and hence the need for frequent compute graph updates.
|
||||
char ** dest_indirect = nullptr;
|
||||
int layer_index=-1;
|
||||
const char* k_prefix = "k_cache_view-";
|
||||
if (strncmp(src1->name, k_prefix, strlen(k_prefix)) == 0) {
|
||||
dest_indirect = k_cache_ptrs;
|
||||
layer_index = atoi(src1->name + strlen(k_prefix));
|
||||
}
|
||||
const char* v_prefix = "v_cache_view-";
|
||||
if (strncmp(src1->name, v_prefix, strlen(v_prefix)) == 0) {
|
||||
dest_indirect = v_cache_ptrs;
|
||||
layer_index = atoi(src1->name + strlen(v_prefix));
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_indirect, layer_index);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
@ -459,31 +496,3 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
const ggml_tensor * src0 = dst->src[0];
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14734,6 +14734,36 @@ static int llama_decode_internal(
|
|||
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
// Copy K and V cache pointers to backend
|
||||
|
||||
// Stage pointers for each layer in host vectors
|
||||
std::vector<const char *> k_cache_ptrs;
|
||||
std::vector<const char *> v_cache_ptrs;
|
||||
const int64_t n_layer = model.hparams.n_layer;
|
||||
const int64_t kv_head = kv_self.head;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// K cache pointer for this layer
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
ggml_tensor * tmp_tensor = kv_self.k_l[il];
|
||||
size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
|
||||
k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
|
||||
// V cache pointer for this layer
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
tmp_tensor = kv_self.v_l[il];
|
||||
if (cparams.flash_attn) {
|
||||
tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
||||
} else {
|
||||
tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
|
||||
}
|
||||
v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
|
||||
}
|
||||
|
||||
// copy host vector data to backend
|
||||
ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size());
|
||||
ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size());
|
||||
#endif
|
||||
|
||||
llama_set_inputs(lctx, u_batch);
|
||||
|
||||
llama_graph_compute(lctx, gf, n_threads);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue