llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding
* better solution for p0 computation
* fix rope
* simpler rope implementation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable parameters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now (ggml-ci)
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) (ggml-ci)
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
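The key idea running through the ggml-cuda.cu changes shown further down: RoPE no longer derives a token's position from a single n_past scalar shared by the whole batch; each row now looks up its own position from an int32_t vector. This is what lets several independent sequences (parallel decoding) and shifted KV-cache entries be rotated correctly within one call. A minimal sketch of the per-row angle computation that mirrors the new kernels; the helper name rope_theta is illustrative only and not part of the commit:

    // theta for one rotated pair (col, col+1) of a given row
    // old: theta = (p0 + p_delta*(row/p_delta_rows)) * powf(theta_scale, col/2)  -- position implied by n_past
    // new: the position is read per row, so rows from different sequences can differ
    __device__ float rope_theta(const int32_t * pos, const int row, const int p_delta_rows,
                                const float freq_scale, const float theta_scale, const int col) {
        const int   p  = pos ? pos[row/p_delta_rows] : 0; // explicit per-token position
        const float p0 = p*freq_scale;
        return p0*powf(theta_scale, col/2);
    }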
This commit is contained in:
parent 45855b3f1c
commit ec893798b7
35 changed files with 2700 additions and 673 deletions
ggml-cuda.cu (147 changed lines)
@@ -4369,8 +4369,10 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
-                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4379,8 +4381,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
+    const int i2 = row/p_delta_rows;
 
-    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
+    const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4391,8 +4396,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
-                                     const float p_delta, const int p_delta_rows, const float theta_scale) {
+template<typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4401,8 +4407,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col/2;
+    const int i2 = row/p_delta_rows;
 
-    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
+    const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4413,8 +4422,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
-                                    const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                    const int p_delta_rows, const float theta_scale, const int n_ctx) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
@@ -4424,11 +4433,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
 
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
+    const int i2 = row/p_delta_rows;
 
     const float col_theta_scale = powf(theta_scale, col);
-    const float p = p0 + p_delta*(row/p_delta_rows);
+    // FIXME: this is likely wrong
+    const int p = pos != nullptr ? pos[i2] : 0;
 
-    const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
+    const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -4438,7 +4449,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + 0] = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
 
-    const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
+    const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);
 
@@ -5389,31 +5400,41 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+template<typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                      const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }
 
-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                               const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+template<typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
-                              const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
     GGML_ASSERT(ncols % 4 == 0);
     const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
     const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -6136,14 +6157,16 @@ inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
+    const int64_t ne2 = dst->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     const int n_ctx = ((int32_t *) dst->op_params)[3];
|
@ -6154,19 +6177,38 @@ inline void ggml_cuda_op_rope(
|
|||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
|
||||
const int32_t * pos = nullptr;
|
||||
if ((mode & 1) == 0) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(src1->ne[0] == ne2);
|
||||
pos = (const int32_t *) src1_dd;
|
||||
}
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
} else if (is_neox) {
|
||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} else {
|
||||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
|
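The hunk above fixes the device-side contract: for non-GLM modes, src1 must be a GGML_TYPE_I32 tensor whose first dimension matches dst->ne[2], and its data is read as the per-token positions. On the graph-building side this corresponds to the commit-message item "ggml_rope now takes a vector with positions instead of n_past". A hedged sketch of what a caller might do; the names KQ_pos, n_tokens, Qcur, n_rot, rope_mode and the exact post-change ggml_rope argument order are assumptions for illustration, not taken from this diff:

    // build one position per token in the batch and pass it to the RoPE op as src1
    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    {
        int32_t * data = (int32_t *) KQ_pos->data;
        for (int i = 0; i < n_tokens; ++i) {
            data[i] = n_past + i; // with parallel sequences these need not be contiguous
        }
    }
    // assumed post-change signature: positions tensor replaces the old n_past argument
    struct ggml_tensor * Qr = ggml_rope(ctx0, Qcur, KQ_pos, n_rot, rope_mode, n_ctx);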
@@ -6337,7 +6379,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }
 
-void ggml_cuda_set_peer_access(const int n_tokens) {
+static void ggml_cuda_set_peer_access(const int n_tokens) {
     static bool peer_access_enabled = false;
 
     const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6665,27 +6707,27 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
 
-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
 
@@ -6706,7 +6748,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }
 
-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6735,7 +6777,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6769,7 +6811,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
 
@@ -6813,11 +6855,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }
 
-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
|
@ -6865,29 +6907,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
|||
(void) dst;
|
||||
}
|
||||
|
||||
void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_cpy(src0, dst, nullptr);
|
||||
(void) src1;
|
||||
}
|
||||
|
||||
void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
|
||||
}
|
||||
|
||||
void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
|
||||
}
|
||||
|
||||
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
|
||||
}
|
||||
|
||||
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
||||
}
|
||||
|
||||
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
(void) src0;
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
|
@ -7010,11 +7052,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
|||
return extra;
|
||||
}
|
||||
|
||||
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
|
||||
static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
|
||||
if (scratch && g_scratch_size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
tensor->backend = GGML_BACKEND_GPU;
|
||||
|
||||
// recursively assign CUDA buffers until a compute tensor is found
|
||||
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
|
||||
const ggml_op src0_op = tensor->src[0]->op;
|
||||
|
@ -7026,8 +7070,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
|
|||
ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
|
||||
}
|
||||
|
||||
tensor->backend = GGML_BACKEND_GPU;
|
||||
|
||||
if (scratch && no_alloc) {
|
||||
return;
|
||||
}
|
||||
|
@@ -7112,6 +7154,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }
 
+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }