ggml : add ggml_mul_mat_set_prec

ggml-ci

parent a8d2a6f3ef
commit 18c67bdd84
4 changed files with 108 additions and 45 deletions
ggml-cuda.cu (104 changed lines)

@@ -7376,6 +7376,8 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t row_diff = row_high - row_low;
 
+    // TODO: handle dst->op_params[0] != GGML_PREC_DEFAULT
+
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 
@@ -8309,27 +8311,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 }
 
 static __global__ void k_compute_batched_ptrs(
-        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3) {
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t nb02, size_t nb03,
+        size_t nb12, size_t nb13,
+        size_t nbd2, size_t nbd3,
+        int64_t r2, int64_t r3) {
+    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
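Note on the type changes in k_compute_batched_ptrs above: the indices and strides are widened from int to int64_t/size_t because byte offsets such as i13*nb13 can exceed 2^31 once batched tensors grow past 2 GiB. A minimal host-side sketch of the failure mode (the sizes are illustrative, not taken from this commit):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t i13  = 3;
    const size_t  nb13 = size_t(1) << 30; // assumed ~1 GiB stride per batch step
    const size_t  off  = i13 * nb13;      // 3 GiB, exact in 64-bit arithmetic
    printf("64-bit offset: %zu\n", off);
    printf("truncated to 32 bits: %d\n", (int) off); // wraps to a negative value
    return 0;
}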
@@ -8385,7 +8387,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    half * dst_f16 = nullptr;
+    char * dst_t   = nullptr;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32  = 0.0f;
+
+    const char * alpha = (const char *) &alpha_f16;
+    const char * beta  = (const char *) &beta_f16;
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+        dst_t   = (char *) dst_f16;
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
+    } else {
+        dst_t = (char *) dst_ddf;
+
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+
+        alpha = (const char *) &alpha_f32;
+        beta  = (const char *) &beta_f32;
+    }
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
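Note on the paired scalars in the hunk above: cublasGemmEx and its batched variants expect alpha/beta to be passed in the scalar type implied by computeType, which is why the code keeps both f16 and f32 constants and swaps a type-erased pointer together with cu_compute_type. A condensed sketch of that rule (the helper is illustrative, not part of this commit):

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <utility>

// returns (alpha, beta) whose pointee type matches the requested computeType,
// as cublasGemmEx and cublasGemm*BatchedEx require
static std::pair<const void *, const void *> gemm_scalars(cublasComputeType_t t) {
    static const half  alpha_f16 = 1.0f, beta_f16 = 0.0f;
    static const float alpha_f32 = 1.0f, beta_f32 = 0.0f;
    if (t == CUBLAS_COMPUTE_16F) {
        return {&alpha_f16, &beta_f16};
    }
    return {&alpha_f32, &beta_f32}; // e.g. CUBLAS_COMPUTE_32F
}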
@@ -8394,9 +8430,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
 #if 0
     // use cublasGemmEx
     {
@@ -8406,12 +8439,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
             int i02 = i12 / r2;
 
             CUBLAS_CHECK(
-                    cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
-                        &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3], CUDA_R_16F, nb01/sizeof(half),
+                        alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3], CUDA_R_16F, nb01/sizeof(half),
                                (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                        &beta_f16,  (      char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                        CUBLAS_COMPUTE_16F,
+                        beta,  (      char *)   dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
+                        cu_compute_type,
                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         }
     }
@@ -8423,11 +8456,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half),  // strideA
+                alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half),  // strideA
                        (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
+                beta,  (      char *)   dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -8444,24 +8477,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
+                src0_as_f16, src1_as_f16, dst_t,
                 ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
                 nb12, nb13,
-                dst->nb[2], dst->nb[3],
+                nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+                beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
@@ -8473,11 +8506,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+    }
 
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml.c (8 changed lines)

@@ -4098,6 +4098,14 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+void ggml_mul_mat_set_prec(
+        struct ggml_tensor * a,
+        enum   ggml_prec     prec) {
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
 // ggml_mul_mat_id
 
 struct ggml_tensor * ggml_mul_mat_id(
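The setter just stashes the precision in the tensor's op_params; a backend reads it straight back out, as the ggml-cuda.cu hunks above do via dst->op_params[0]. A minimal sketch of the read side (the accessor name is hypothetical, not added by this commit):

static enum ggml_prec ggml_mul_mat_get_prec(const struct ggml_tensor * a) {
    return (enum ggml_prec) a->op_params[0];
}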
ggml.h (12 changed lines)

@@ -343,6 +343,12 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    // precision
+    enum ggml_prec {
+        GGML_PREC_DEFAULT,
+        GGML_PREC_F32,
+    };
+
     enum ggml_backend_type {
         GGML_BACKEND_CPU = 0,
         GGML_BACKEND_GPU = 10,
@@ -1057,6 +1063,12 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);
 
+    // change the precision of a matrix multiplication
+    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    GGML_API void ggml_mul_mat_set_prec(
+            struct ggml_tensor * a,
+            enum   ggml_prec     prec);
+
     // indirect matrix multiplication
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
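Putting the API together: a caller marks the product tensor right after building it, which is exactly what the llama.cpp hunk below does for the KQ matmul. A minimal usage sketch, assuming ctx, k and q already exist:

struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // request f32 accumulation for this matmul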
llama.cpp (29 changed lines)

@@ -4082,6 +4082,7 @@ static struct ggml_tensor * llm_build_ffn(
 // if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
          struct ggml_context * ctx,
+         const llama_model & model,
          const llama_hparams & hparams,
          const llama_kv_cache & kv,
          struct ggml_tensor * wo,
@@ -4116,6 +4117,12 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);
 
+    if (model.arch == LLM_ARCH_PHI2) {
+        // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+        // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+    }
+
     if (max_alibi_bias > 0.0f) {
         // temporary branch until we figure out how to handle ggml_alibi through ggml_add
         kq = ggml_scale(ctx, kq, kq_scale);
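Why f16 accumulation produces NaNs here: half saturates at 65504, so a large QK dot product overflows to inf, and the max-subtraction step of the following softmax computes inf - inf = NaN. A worked illustration with assumed magnitudes (not values measured on phi-2):

#include <cstdio>

int main() {
    const float f16_max = 65504.0f; // largest finite half value
    float dot = 0.0f;
    for (int i = 0; i < 128; ++i) {
        dot += 32.0f * 32.0f;       // assumed per-element QK products
    }
    // dot == 131072: fine in f32, but already +inf in f16, and the
    // softmax max-subtraction then turns inf - inf into NaN
    printf("dot = %g, fits in f16: %s\n", dot, dot <= f16_max ? "yes" : "no");
    return 0;
}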
@@ -4342,7 +4349,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -4525,7 +4532,7 @@ struct llm_build_context {
         // apply ALiBi for 13B model
         const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -4649,7 +4656,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -4749,7 +4756,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -4958,7 +4965,7 @@ struct llm_build_context {
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
         // TODO: not tested, could be broken
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, model.layers[il].bo,
                 Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5049,7 +5056,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5146,7 +5153,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5240,7 +5247,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5353,7 +5360,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5470,7 +5477,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, NULL,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         cb(cur, "kqv_out", il);
@@ -5591,7 +5598,7 @@ struct llm_build_context {
 
         llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        cur = llm_build_kqv(ctx0, hparams, kv_self,
+        cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
         cb(cur, "kqv_out", il);