better solution for p0 computation
This commit is contained in:
parent
eec6b66ac9
commit
fb92acdd6b
2 changed files with 25 additions and 9 deletions
28
ggml-cuda.cu
28
ggml-cuda.cu
|
@ -5,7 +5,6 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#include <hip/hip_runtime.h>
|
#include <hip/hip_runtime.h>
|
||||||
|
@ -440,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
|
||||||
struct ggml_tensor_extra_gpu {
|
struct ggml_tensor_extra_gpu {
|
||||||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
||||||
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
|
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
|
||||||
|
bool copied;
|
||||||
};
|
};
|
||||||
|
|
||||||
// this is faster on Windows
|
// this is faster on Windows
|
||||||
|
@ -4356,6 +4356,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||||
}
|
}
|
||||||
|
|
||||||
// rope == RoPE == rotary positional embedding
|
// rope == RoPE == rotary positional embedding
|
||||||
|
static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) {
|
||||||
|
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (i < n) {
|
||||||
|
int p = pos[i];
|
||||||
|
p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
|
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||||
const float p_delta, const int p_delta_rows, const float theta_scale) {
|
const float p_delta, const int p_delta_rows, const float theta_scale) {
|
||||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
@ -6091,18 +6099,20 @@ inline void ggml_cuda_op_rope(
|
||||||
|
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(src1->ne[0] == ne2);
|
GGML_ASSERT(src1->ne[0] == ne2);
|
||||||
|
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
|
||||||
|
|
||||||
std::vector<float> p0s(ne2);
|
int id;
|
||||||
for (int64_t i = 0; i < ne2; ++i) {
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
int n_past = ((int32_t *) src1->data)[i];
|
|
||||||
p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||||
|
if (!src1_extra->copied) {
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
|
||||||
|
src1_extra->copied = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t p0d_as = 0;
|
size_t p0d_as = 0;
|
||||||
float * p0d;
|
float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
|
||||||
|
compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale);
|
||||||
p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream));
|
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
const bool is_glm = mode & 4;
|
||||||
|
|
|
@ -2705,6 +2705,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
|
|
||||||
// KQ_pos - contains the positions
|
// KQ_pos - contains the positions
|
||||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
offload_func_kq(KQ_pos);
|
||||||
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) KQ_pos->data;
|
int * data = (int *) KQ_pos->data;
|
||||||
|
@ -2715,6 +2716,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
|
|
||||||
// K_shift
|
// K_shift
|
||||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||||
|
offload_func_kq(K_shift);
|
||||||
ggml_allocr_alloc(lctx.alloc, K_shift);
|
ggml_allocr_alloc(lctx.alloc, K_shift);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) K_shift->data;
|
int * data = (int *) K_shift->data;
|
||||||
|
@ -3087,6 +3089,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||||
|
|
||||||
// KQ_pos - contains the positions
|
// KQ_pos - contains the positions
|
||||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
offload_func_kq(KQ_pos);
|
||||||
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) KQ_pos->data;
|
int * data = (int *) KQ_pos->data;
|
||||||
|
@ -3097,6 +3100,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||||
|
|
||||||
// K_shift
|
// K_shift
|
||||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||||
|
offload_func_kq(K_shift);
|
||||||
ggml_allocr_alloc(lctx.alloc, K_shift);
|
ggml_allocr_alloc(lctx.alloc, K_shift);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) K_shift->data;
|
int * data = (int *) K_shift->data;
|
||||||
|
@ -3486,6 +3490,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||||
|
|
||||||
// KQ_pos - contains the positions
|
// KQ_pos - contains the positions
|
||||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
offload_func_kq(KQ_pos);
|
||||||
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
ggml_allocr_alloc(lctx.alloc, KQ_pos);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) KQ_pos->data;
|
int * data = (int *) KQ_pos->data;
|
||||||
|
@ -3496,6 +3501,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||||
|
|
||||||
// K_shift
|
// K_shift
|
||||||
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
|
||||||
|
offload_func_kq(K_shift);
|
||||||
ggml_allocr_alloc(lctx.alloc, K_shift);
|
ggml_allocr_alloc(lctx.alloc, K_shift);
|
||||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||||
int * data = (int *) K_shift->data;
|
int * data = (int *) K_shift->data;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue