ggml_cuda_mul_mat_vec_p021
commit b87178b558
parent 8c6bd319db
3 changed files with 277 additions and 26 deletions

291	ggml-cuda.cu
@@ -151,6 +151,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_

#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

@@ -737,6 +738,73 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
    }
}

static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y) {
    const half * x = (half *) vx;
    // const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
    // const int row_x = blockDim.y*blockIdx.y + threadIdx.y;

    const int row_x   = blockDim.y*blockIdx.y + threadIdx.y;
    const int channel = blockDim.z*blockIdx.z + threadIdx.z;

    const int nrows_y   = ncols_x;
    const int ncols_dst = ncols_y;
    const int nrows_dst = nrows_x;
    const int row_dst   = row_x;

    float tmp = 0.0f;

    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
        const int col_x = col_x0 + threadIdx.x;

        if (col_x >= ncols_x) {
            break; // stop accumulating but keep the lane active for the warp reduction below
        }

        // x is transposed and permuted
        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
        const float xi = __half2float(x[ix]);

        const int row_y = col_x;

        // y is not transposed but permuted
        const int iy = channel*nrows_y + row_y;

        // dst is not transposed and not permuted

        tmp += xi * y[iy];
    }
    const int idst = channel*ncols_dst*nrows_dst + row_dst;

    // sum up partial sums and write back result
    __syncthreads();
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        dst[idst] = tmp;
    }
}

static __global__ void cpy_f32_f16(const float * x, void * vdst, const int ne0, const int ne1, const int stride_1, const int stride_2) {
    const int i0 = blockDim.x*blockIdx.x + threadIdx.x;

    if (i0 >= ne0) {
        return;
    }

    const int i1 = blockDim.y*blockIdx.y + threadIdx.y;
    const int i2 = blockDim.z*blockIdx.z + threadIdx.z;

    half * dst = (half *) vdst;

    const int ix = i0 + i1*stride_1 + i2*stride_2;
    const int idst = i0 + i1*ne0 + i2*ne1*ne0;
    dst[idst] = __float2half(x[ix]);
}

static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

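To make the indexing above concrete, here is a minimal CPU sketch of what mul_mat_p021_f16_f32 computes for a single src1 column. This is an editor-added illustration, not part of the commit: the function name is made up and x is modeled as float instead of half.

```cpp
// Sketch: src0 is stored permuted as [row][channel][col] (hence "p021"),
// the src1 column as [channel][col]; one output per (channel, row) pair.
static void mul_mat_p021_reference(const float * x, const float * y, float * dst,
                                   int ncols_x, int nrows_x, int nchannels_x) {
    for (int channel = 0; channel < nchannels_x; ++channel) {
        for (int row = 0; row < nrows_x; ++row) {
            float sum = 0.0f;
            for (int col = 0; col < ncols_x; ++col) {
                const int ix = row*nchannels_x*ncols_x + channel*ncols_x + col; // x: transposed and permuted
                const int iy = channel*ncols_x + col;                           // y: permuted only
                sum += x[ix] * y[iy];
            }
            // The CUDA kernel instead writes to channel*ncols_dst*nrows_dst + row
            // and the host adds a per-column offset.
            dst[channel*nrows_x + row] = sum;
        }
    }
}
```

In the kernel, the inner loop over col is distributed across one warp and the partial sums are combined with __shfl_xor_sync before lane 0 writes the result.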
@@ -942,6 +1010,21 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    }
}

static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) {
    const int block_num_x = (ncols_x + WARP_SIZE - 1) / WARP_SIZE;
    const dim3 block_nums(block_num_x, nrows_x, nchannels_x);
    const dim3 block_dims(WARP_SIZE, 1, 1);
    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y);
}

static void ggml_cpy_f32_f16_cuda(const float * x, void * vdst, const int ne0, const int ne1, const int ne2,
                                  const int stride_1, const int stride_2, cudaStream_t stream) {
    const int block_num_x = (ne0 + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    const dim3 block_nums(block_num_x, ne1, ne2);
    const dim3 block_dims(CUDA_CPY_BLOCK_SIZE, 1, 1);
    cpy_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, vdst, ne0, ne1, stride_1, stride_2);
}

static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(nrows % 2 == 0);
    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);

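The p021 wrapper launches one warp (WARP_SIZE threads, 32 as defined elsewhere in this file) per (row, channel) pair, with the x dimension of the grid coming from a ceiling division over the columns. A small host-side sanity check of that geometry, with example shapes chosen for illustration only:

```cpp
#include <cstdio>

int main() {
    const int WARP_SIZE   = 32;   // as defined in ggml-cuda.cu
    const int ncols_x     = 128;  // example: head dimension
    const int nrows_x     = 64;   // example: sequence length
    const int nchannels_x = 32;   // example: number of heads

    // same ceiling division as ggml_mul_mat_p021_f16_f32_cuda
    const int block_num_x = (ncols_x + WARP_SIZE - 1) / WARP_SIZE;

    printf("grid = (%d, %d, %d), block = (%d, 1, 1)\n",
           block_num_x, nrows_x, nchannels_x, WARP_SIZE);
    // -> grid = (4, 64, 32), block = (32, 1, 1)
    return 0;
}
```

Note that the kernel iterates over columns with a blockDim.x stride and never reads blockIdx.x, so the y and z grid dimensions alone select which output element a block computes.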
@@ -1663,18 +1746,17 @@ void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml
}

bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU);
    const int64_t ne10 = src1->ne[0];

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];

    // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) {
    // if (strcmp(dst->name, "KQ") == 0) {
    //     fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n",
    //         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
    //         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
    //         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
    //     return false;
    //     return true;
    // }

    // TODO: find the optimal values for these

@@ -1688,8 +1770,101 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
    return false;
}

void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
    GGML_ASSERT(!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1));
    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);
    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];
    GGML_ASSERT(ne00 % 8 == 0);
    GGML_ASSERT(ne03 == 1);
    const size_t src0_size = ggml_nbytes(src0);

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];
    const int64_t ne13 = src1->ne[3];
    GGML_ASSERT(ne10 % 4 == 0);
    GGML_ASSERT(ne13 == 1);
    GGML_ASSERT(ne12 == ne02);
    const size_t src1_size = ggml_nbytes(src1);

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];
    GGML_ASSERT(ne2 == ne02);
    const size_t dst_size = ggml_nbytes(dst);

    const int64_t nb1 = dst->nb[1];
    const int64_t nb2 = dst->nb[2];
    const int64_t nb3 = dst->nb[3];

    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];

    void * src0_ddq;
    float * src1_ddf;
    float * dst_ddf;
    if (src0->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size));
        CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main));
    } else {
        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
        src0_ddq = src0_extra->data_device[g_main_device];
        // CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0)));
    }
    if (src1->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size));
        CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main));
    } else {
        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
        src1_ddf = (float *) src1_extra->data_device[g_main_device];
        // CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1)));
    }
    CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size));

    for (int64_t i11 = 0; i11 < ne11; ++i11) {
        float * src1_ddf_i = src1_ddf + i11 * ne10*ne12;
        float * dst_ddf_i = dst_ddf + i11 * ne0;

        ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf_i, dst_ddf_i, ne00, ne01, ne02, ne11, cudaStream_main);
    }

    CUDA_CHECK(cudaDeviceSynchronize());

    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            for (int64_t i1 = 0; i1 < ne1; i1++) {
                const int64_t i = i3*ne2*ne1 + i2*ne1 + i1;
                float * dst_ddf_i = dst_ddf + i*ne0;
                float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
                CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
            }
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    if (src0->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaFree(src0_ddq));
    }
    if (src1->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaFree(src1_ddf));
    }
    CUDA_CHECK(cudaFree(dst_ddf));
}

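ggml_cuda_mul_mat_p021_f16_f32 stages everything through temporary device buffers: copy inputs to VRAM if they live on the CPU, launch the kernel once per src1 column, then copy the dense result back row by row into the strided dst tensor. A rough sketch of the two offset computations used by that copy-back loop (editor illustration; the helper names are made up):

```cpp
#include <cstddef>
#include <cstdint>

// Flat device-buffer index of row block (i1, i2, i3), in floats.
// Matches the loop above: i = i3*ne2*ne1 + i2*ne1 + i1, one cudaMemcpyAsync of ne0 floats per row.
static size_t dst_device_offset(int64_t i1, int64_t i2, int64_t i3,
                                int64_t ne0, int64_t ne1, int64_t ne2) {
    return (size_t) ((i3*ne2*ne1 + i2*ne1 + i1) * ne0);
}

// Byte offset of the same row inside the (possibly strided) ggml dst tensor.
static size_t dst_host_offset(int64_t i1, int64_t i2, int64_t i3,
                              size_t nb1, size_t nb2, size_t nb3) {
    return i3*nb3 + i2*nb2 + i1*nb1;
}
```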
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    if (src0->type == GGML_TYPE_F32) {
    if (!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1)) {
        ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst);
        // ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
    } else if (src0->type == GGML_TYPE_F32) {
        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
        if (src1->ne[1] == 1) {

@@ -1702,6 +1877,43 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
    }
}

void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    GGML_ASSERT(src0->ne[3] == src1->ne[3]);

    const int64_t nb01 = src0->nb[1];
    const int64_t nb02 = src0->nb[2];

    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];

    const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
    const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;

    float * src0_ddf = (float *) src0_extra->data_device[g_main_device];
    void * src1_ddv = src1_extra->data_device[g_main_device];

    const int64_t stride_1 = nb01 / sizeof(float);
    GGML_ASSERT(nb01 % sizeof(float) == 0);
    const int64_t stride_2 = nb02 / sizeof(float);
    GGML_ASSERT(nb02 % sizeof(float) == 0);

    ggml_cpy_f32_f16_cuda(src0_ddf, src1_ddv, ne00, ne01, ne02, stride_1, stride_2, cudaStream_main);
    // test<<<ggml_nelements(src0), 1, 0, cudaStream_main>>>(src0_ddf, src1_ddv);

    CUDA_CHECK(cudaDeviceSynchronize());

    (void) dst;
}

void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);

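ggml_cuda_cpy drives cpy_f32_f16 with strides expressed in float elements rather than bytes, which is what the nb % sizeof(float) asserts guarantee. A small CPU sketch of the same index computation, as an editor illustration (the function name is made up; ggml's ggml_fp16_t and ggml_fp32_to_fp16 from ggml.h are used for the f32 to f16 conversion):

```cpp
#include "ggml.h"

// Mirrors cpy_f32_f16 on the CPU: source element (i0, i1, i2) is read at
// i0 + i1*stride_1 + i2*stride_2 (strides in float elements), and the
// destination is written densely as [i2][i1][i0] in f16.
static void cpy_f32_f16_reference(const float * x, ggml_fp16_t * dst,
                                  int ne0, int ne1, int ne2,
                                  int64_t stride_1, int64_t stride_2) {
    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i0 = 0; i0 < ne0; ++i0) {
                const int64_t ix   = i0 + i1*stride_1 + i2*stride_2;
                const int64_t idst = i0 + i1*(int64_t) ne0 + i2*(int64_t) ne1*ne0;
                dst[idst] = ggml_fp32_to_fp16(x[ix]);
            }
        }
    }
}
```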
@@ -1718,10 +1930,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
    const size_t nb1 = tensor->nb[1];
    ggml_backend backend = tensor->backend;
    struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
    memset(extra, 0, sizeof(*extra));

    for (int id = 0; id < g_device_count; ++id) {
        extra->data_device[id] = nullptr;

        if (backend == GGML_BACKEND_GPU && id != g_main_device) {
            continue;
        }

@@ -1781,44 +1992,67 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
    delete extra;
}

void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
    if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) {
        ggml_cuda_assign_buffers(tensor->src0);
    }

    const size_t size = ggml_nbytes(tensor);
    GGML_ASSERT(size <= g_scratch_size);
    if (g_scratch_offset + size > g_scratch_size) {
        g_scratch_offset = 0;
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
    if (tensor->src0 != nullptr) {
        const ggml_op src0_op = tensor->src0->op;
        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
        }
    }

    tensor->backend = GGML_BACKEND_GPU;
    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;

    bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data;
    const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
        tensor->op == GGML_OP_VIEW;
    const size_t size = ggml_nbytes(tensor);

    CUDA_CHECK(cudaSetDevice(g_main_device));
    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
        extra->data_device[g_main_device] = src0_extra->data_device[g_main_device];
    } else {
        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
        size_t offset = 0;
        if (tensor->op == GGML_OP_VIEW) {
            memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
        }
        extra->data_device[g_main_device] = src0_ddc + offset;
    } else if (scratch) {
        GGML_ASSERT(size <= g_scratch_size);
        if (g_scratch_offset + size > g_scratch_size) {
            g_scratch_offset = 0;
        }

        char * data = (char *) g_scratch_buffer;
        if (data == nullptr) {
            CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
            g_scratch_buffer = data;
        }
        extra->data_device[g_main_device] = data + g_scratch_offset;

        // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
        g_scratch_offset += size;
        // fprintf(stderr, "%s: scratch %d, %p - %p\n",
        //     tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);

        GGML_ASSERT(g_scratch_offset <= g_scratch_size);
    } else {
        void * data;
        CUDA_CHECK(cudaMalloc(&data, size));
        CUDA_CHECK(cudaMemset(data, 0, size));
        extra->data_device[g_main_device] = data;
    }

    // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
    g_scratch_offset += size;
    // fprintf(stderr, "%s: scratch %d, %p - %p\n",
    //     tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);

    GGML_ASSERT(g_scratch_offset <= g_scratch_size);
    tensor->extra = extra;
}

void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
    ggml_cuda_assign_buffers_impl(tensor, true);
}

void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
    ggml_cuda_assign_buffers_impl(tensor, false);
}

void ggml_cuda_set_main_device(int main_device) {
    if (main_device > g_device_count) {
        fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",

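ggml_cuda_assign_buffers_impl has three paths: alias the parent's buffer for in-place ops and views, carve the tensor out of a single reused scratch buffer, or give it a dedicated cudaMalloc allocation. The scratch path is a bump allocator that wraps back to offset 0 when the next tensor would not fit. A stripped-down sketch of just that policy (editor illustration; the struct and names are local to the sketch):

```cpp
#include <cassert>
#include <cstddef>

// Minimal model of the scratch path: a bump allocator over one fixed-size buffer
// that resets to the start instead of growing. This only works because scratch
// tensors are short-lived intermediates that are consumed before the offset wraps.
struct ScratchRing {
    char * base = nullptr;
    size_t size = 0;
    size_t offset = 0;

    char * assign(size_t nbytes) {
        assert(nbytes <= size);
        if (offset + nbytes > size) {
            offset = 0;               // wrap, reusing memory of older intermediates
        }
        char * ptr = base + offset;
        offset += nbytes;
        return ptr;
    }
};
```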
@@ -1874,7 +2108,16 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
            }
            func = ggml_cuda_mul_mat;
            break;
        case GGML_OP_CPY:
            if (!any_on_device) {
                return false;
            }
            func = ggml_cuda_cpy;
            break;
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            if (!any_on_device) {
                return false;
            }

1	ggml-cuda.h
@@ -28,6 +28,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);

void ggml_cuda_free_data(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device);
void ggml_cuda_set_scratch_size(size_t scratch_size);
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);

11	llama.cpp

@@ -893,6 +893,10 @@ static bool kv_cache_init(
    ggml_set_name(cache.k, "cache_k");
    ggml_set_name(cache.v, "cache_v");

#ifdef GGML_USE_CUBLAS
    ggml_cuda_assign_buffers_no_scratch(cache.k);
#endif // GGML_USE_CUBLAS

    return true;
}

@@ -1375,12 +1379,12 @@ static bool llama_eval_internal(

        struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
        offload_func(Kcur);
        Kcur->backend = GGML_BACKEND_CPU;
        // Kcur->backend = GGML_BACKEND_CPU;
        ggml_set_name(Kcur, "Kcur");

        struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
        offload_func(Qcur);
        Qcur->backend = GGML_BACKEND_CPU;
        // Qcur->backend = GGML_BACKEND_CPU;
        ggml_set_name(Qcur, "Qcur");

        // store key and value to memory

@@ -1390,6 +1394,7 @@ static bool llama_eval_internal(
            ggml_set_name(Vcur, "Vcur");

            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
            offload_func(k);
            ggml_set_name(k, "k");
            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                ( n_ctx)*ggml_element_size(kv_self.v),

@@ -1405,6 +1410,7 @@ static bool llama_eval_internal(
            ggml_permute(ctx0,
                Qcur,
                0, 2, 1, 3);
        offload_func(Q);
        ggml_set_name(Q, "Q");

        struct ggml_tensor * K =

@@ -1413,6 +1419,7 @@ static bool llama_eval_internal(
                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
                n_embd/n_head, n_head, n_past + N),
            0, 2, 1, 3);
        offload_func(K);
        ggml_set_name(K, "K");

        // K * Q

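For context on why the p021 kernel exists: K is built in llama_eval_internal as ggml_permute over a reshaped view of the (typically f16) KV cache with axes (0, 2, 1, 3), so its element strides over (col, row, channel) are (1, head_dim*n_head, head_dim). That is exactly the ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x pattern assumed by mul_mat_p021_f16_f32, and Q gets the matching permuted-only layout. A small check of that stride arithmetic, using hypothetical sizes chosen for illustration:

```cpp
#include <cstdio>

int main() {
    // Hypothetical sizes: head_dim = ncols_x, n_head = nchannels_x.
    const int head_dim = 128, n_head = 32;

    // A contiguous cache view reshaped to (head_dim, n_head, seq_len) has element
    // strides (1, head_dim, head_dim*n_head). ggml_permute(..., 0, 2, 1, 3) swaps
    // dims 1 and 2, giving ne = (head_dim, seq_len, n_head) with strides
    // (1, head_dim*n_head, head_dim).
    const long stride_col     = 1;
    const long stride_row     = (long) head_dim*n_head;
    const long stride_channel = head_dim;

    // Element offset of (col, row, channel) -- identical to ix in mul_mat_p021_f16_f32:
    const int col = 3, row = 5, channel = 7;
    const long offset = col*stride_col + row*stride_row + channel*stride_channel;
    const long ix     = (long) row*n_head*head_dim + (long) channel*head_dim + col;
    printf("offset = %ld, ix = %ld\n", offset, ix); // both print the same value
    return 0;
}
```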