Merge branch 'server_branch' of https://github.com/pudepiedj/llama.cpp into server_branch
commit e8c37fd893

9 changed files with 126 additions and 86 deletions

Makefile (9 changes)
@@ -381,8 +381,13 @@ ifdef LLAMA_BLIS
 endif # LLAMA_BLIS

 ifdef LLAMA_CUBLAS
-	MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
-	MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
+	ifneq ('', '$(wildcard /opt/cuda)')
+		CUDA_PATH ?= /opt/cuda
+	else
+		CUDA_PATH ?= /usr/local/cuda
+	endif
+	MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
+	MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
 	OBJS += ggml-cuda.o
 	MK_NVCCFLAGS += -use_fast_math
 ifdef LLAMA_FATAL_WARNINGS
@@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             break;
         }
         params.yarn_beta_slow = std::stof(argv[i]);
+    } else if (arg == "--defrag-thold" || arg == "-dt") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.defrag_thold = std::stof(argv[i]);
     } else if (arg == "--samplers") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf(" -dt N, --defrag-thold N\n");
+    printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf(" --no-penalize-nl do not penalize newline token\n");
     printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx  = params.yarn_orig_ctx;
+    cparams.defrag_thold   = params.defrag_thold;
     cparams.offload_kqv    = !params.no_kv_offload;

     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
@@ -75,6 +75,7 @@ struct gpt_params {
     float yarn_beta_fast = 32.0f;  // YaRN low correction dim
     float yarn_beta_slow = 1.0f;   // YaRN high correction dim
     int32_t yarn_orig_ctx = 0;     // YaRN original context length
+    float defrag_thold = -1.0f;    // KV cache defragmentation threshold
     int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

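Taken together, the four hunks above thread a new --defrag-thold / -dt option through the common layer: the value is parsed into gpt_params.defrag_thold, documented in the usage text, and copied into the context parameters. Below is a minimal sketch of that flow; the structs are hypothetical stand-ins for gpt_params and llama_context_params so the example compiles on its own.

```cpp
#include <cstdio>
#include <string>

// Stand-in for gpt_params (hypothetical, for illustration only).
struct gpt_params_sketch {
    float defrag_thold = -1.0f;   // < 0 disables KV cache defragmentation
};

// Stand-in for llama_context_params (hypothetical).
struct cparams_sketch {
    float defrag_thold;
};

int main(int argc, char ** argv) {
    gpt_params_sketch params;

    // same parsing pattern as the "--defrag-thold" branch added above
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if (arg == "--defrag-thold" || arg == "-dt") {
            if (++i >= argc) { break; }
            params.defrag_thold = std::stof(argv[i]);
        }
    }

    // mirrored into the context params, as llama_context_params_from_gpt_params does
    cparams_sketch cparams;
    cparams.defrag_thold = params.defrag_thold;

    std::printf("defrag threshold: %.2f\n", cparams.defrag_thold);
    return 0;
}
```

Running the sketch with "-dt 0.1" would report a threshold of 0.10; without the flag it reports -1.00, i.e. defragmentation disabled.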
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {

         llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-        llama_kv_cache_defrag (ctx);
+        //llama_kv_cache_defrag (ctx);
         llama_kv_cache_update (ctx);

         n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {

             llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-            llama_kv_cache_defrag (ctx);
+            //llama_kv_cache_defrag (ctx);
             llama_kv_cache_update (ctx);

             n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
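The two hunks above stop forcing a defrag on every context shift in this example; defragmentation is now queued by the threshold check inside the decoder (shown further down) and applied by llama_kv_cache_update. A hedged sketch of the shift step as a standalone helper, assuming llama.h from this revision and caller-tracked n_keep / n_discard / n_ctx:

```cpp
#include "llama.h"

// Sketch of the context-shift step used in the example above: drop n_discard
// tokens after the first n_keep, slide the rest back, and let
// llama_kv_cache_update() apply any pending defragmentation.
static void shift_context(llama_context * ctx, int n_keep, int n_discard, int n_ctx) {
    llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
    // defragmentation is no longer forced here; it is queued by the decoder
    // when fragmentation exceeds --defrag-thold and applied by the update call
    llama_kv_cache_update (ctx);
}
```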
@@ -1428,6 +1428,10 @@ struct llama_server_context
             split_multiprompt_task(task_id, task);
         }
     } else {
+        // an empty prompt can make slot become buggy
+        if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
+            task.data["prompt"] = " "; // add a space so that we have one token
+        }
         queue_tasks.post(task);
     }
 }
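The server hunk above guards against empty prompts by substituting a single space, so tokenization yields at least one token for the slot. The same check in isolation, as a sketch assuming nlohmann::json for task.data (which is what the server code uses):

```cpp
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

// Returns the task data with an empty string prompt replaced by " " so that
// the tokenizer produces at least one token (mirrors the hunk above).
static json sanitize_prompt(json data) {
    if (data.contains("prompt") && data["prompt"].is_string() && data["prompt"].get<std::string>().empty()) {
        data["prompt"] = " ";
    }
    return data;
}
```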
ggml-cuda.cu (73 changes)
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }

-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//   for (int mask = 16; mask > 0; mask >>= 1) {
-//       a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//   }
-//   return a;
-//#else
-//   (void) a;
-//   NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    (void) a;
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16

 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
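The reductions above rely on a warp-level butterfly: each lane repeatedly adds the value held by the lane whose index differs by one bit (__shfl_xor_sync with masks 16, 8, 4, 2, 1), so after five rounds every lane of the 32-lane warp holds the full sum. A plain C++ stand-in for that exchange pattern over a 32-element array (illustration only, not device code):

```cpp
#include <array>
#include <cstdio>

// CPU illustration of the warp_reduce_sum butterfly: "lane" i exchanges with
// lane (i ^ mask) for mask = 16, 8, 4, 2, 1. After 5 rounds every lane holds
// the sum of all 32 inputs, which is what __shfl_xor_sync achieves within a
// CUDA warp without touching shared memory.
static void butterfly_reduce(std::array<float, 32> & lane) {
    for (int mask = 16; mask > 0; mask >>= 1) {
        std::array<float, 32> next = lane;
        for (int i = 0; i < 32; ++i) {
            next[i] = lane[i] + lane[i ^ mask];   // "tmp += __shfl_xor_sync(...)"
        }
        lane = next;
    }
}

int main() {
    std::array<float, 32> lane{};
    for (int i = 0; i < 32; ++i) lane[i] = 1.0f;  // each lane contributes 1
    butterfly_reduce(lane);
    std::printf("lane 0 sum = %.1f\n", lane[0]);  // prints 32.0
    return 0;
}
```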
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (threadIdx.x == 0) {
         dst[row] = tmp;

@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (threadIdx.x == 0) {
         dst[row] = tmp;

@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (tid == 0) {
         dst[row] = tmp;

@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (threadIdx.x == 0) {
         dst[row] = tmp;

@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;

-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);

     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
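quantize_q8_1 now delegates the warp-wide max and sum to the helpers, then derives the block scale d = amax / 127 and the quantized value round(x / d). A hedged CPU sketch of that absmax scheme on a single block (layout and sums simplified; this is not the GPU kernel):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified absmax quantization over one block of floats, mirroring the
// d = amax / 127, q = round(x / d) arithmetic in quantize_q8_1 above.
static std::vector<int8_t> quantize_block(const std::vector<float> & x, float & d_out) {
    float amax = 0.0f;
    for (float v : x) amax = std::fmax(amax, std::fabs(v));

    const float d = amax / 127.0f;
    d_out = d;

    std::vector<int8_t> q(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        q[i] = amax == 0.0f ? 0 : (int8_t) std::roundf(x[i] / d);
    }
    return q;
}

int main() {
    std::vector<float> x = {0.5f, -1.0f, 0.25f, 2.0f};
    float d = 0.0f;
    std::vector<int8_t> q = quantize_block(x, d);
    std::printf("d = %f, q[3] = %d\n", d, q[3]);  // q[3] == 127 (the absmax element)
    return 0;
}
```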
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (tid == 0) {
 #ifdef GGML_CUDA_F16

@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (threadIdx.x == 0) {
         dst[idst] = tmp;

@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }

     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);

     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -10248,8 +10248,12 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const

     uint64_t aux64;

-    __m256i v_gindex;
-    const uint16_t * gindex = (const uint16_t *)&v_gindex;
+    typedef union m256i_uint16 {
+        __m256i reg;
+        uint16_t s[16];
+    } m256i_uint16_t;
+
+    m256i_uint16_t v_gindex;

     __m256 accum = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {

@@ -10264,13 +10268,13 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
         memcpy(&aux64, sc, 8); sc += 8;
         const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
         const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
-        v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+        v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
         const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);

         for (int i32 = 0; i32 < 4; ++i32) {
             const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-            const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
-                                                  iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+            const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
+                                                  iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
             const __m256i dot = mul_add_epi8(q1b, q8b);
             const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
             const __m256i p = _mm256_madd_epi16(s16, dot);
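The union replaces a uint16_t pointer alias into an __m256i, so the sixteen 16-bit grid indices can be read through v_gindex.s without casting the address of a vector register. A standalone sketch of the pattern (requires AVX; illustration only, not the ggml kernel):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Union view of a 256-bit vector as sixteen uint16_t lanes, the same pattern
// as m256i_uint16_t in the hunk above. Writing through .reg and reading
// through .s avoids aliasing an __m256i through a uint16_t pointer.
typedef union m256i_uint16 {
    __m256i  reg;
    uint16_t s[16];
} m256i_uint16_t;

int main() {
    m256i_uint16_t v;
    v.reg = _mm256_set1_epi16(42);                     // fill all 16 lanes with 42
    std::printf("lane 3 = %u\n", (unsigned) v.s[3]);   // prints 42
    return 0;
}
```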
llama.cpp (97 changes)
@@ -1641,6 +1641,7 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
+    float defrag_thold;

     bool mul_mat_q;
     bool offload_kqv;
@@ -5117,16 +5118,16 @@ struct llm_build_context {
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        for (int i = 0; i < n_kv; ++i) {
-            const int id = ids[i];
+        for (uint32_t i = 0; i < ids.size(); ++i) {
+            const uint32_t id = ids[i];

-            if (i == id || id == n_kv) {
+            if (i == id || id == ids.size()) {
                 continue;
             }

-            int nm = 1;
+            uint32_t nm = 1;

-            while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
+            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
                 nm++;
             }

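build_defrag walks the ids move table produced by the defrag planner: ids[i] is the destination cell for cell i, with ids.size() meaning "not moved", and consecutive sources mapping to consecutive destinations are coalesced into one block move of length nm. A hedged sketch of just that coalescing loop, with the graph construction replaced by a printout (the plan_moves helper is hypothetical):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the block-coalescing loop in build_defrag above: ids[i] is the
// destination of cell i, ids.size() marks "not moved". Adjacent cells whose
// destinations are also adjacent are merged into a single move of length nm.
// The real code emits ggml view/copy nodes per move; here we just print them.
static void plan_moves(const std::vector<uint32_t> & ids) {
    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t id = ids[i];

        if (i == id || id == ids.size()) {
            continue;                       // cell stays in place or is unused
        }

        uint32_t nm = 1;
        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
            nm++;                           // extend the contiguous block
        }

        std::printf("move cells [%u, %u) -> [%u, %u)\n", i, i + nm, id, id + nm);
        i += nm - 1;                        // skip the cells we just handled
    }
}

int main() {
    // cells 4..6 move to 0..2; everything else stays (8 == ids.size())
    std::vector<uint32_t> ids = {8, 8, 8, 8, 0, 1, 2, 8};
    plan_moves(ids);                        // prints: move cells [4, 7) -> [0, 3)
    return 0;
}
```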
@@ -5158,6 +5159,8 @@ struct llm_build_context {
             i += nm - 1;
         }

+        //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+
         return gf;
     }

@@ -7938,6 +7941,8 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }

+    llama_kv_cache_update(&lctx);
+
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
     if (kv_self.head > kv_self.used + 2*n_tokens) {
@@ -7956,8 +7961,6 @@ static int llama_decode_internal(
     // line above and below originally commented out
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

-    llama_kv_cache_update(&lctx);
-
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

@@ -8007,6 +8010,18 @@ static int llama_decode_internal(
         }
     }

+    // decide if we need to defrag the kv cache
+    if (cparams.defrag_thold >= 0.0f) {
+        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+            llama_kv_cache_defrag(kv_self);
+        }
+    }
+
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
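The check above queues a defrag pass once the fraction of spanned-but-unused KV cells exceeds --defrag-thold, and only when the cache spans at least 128 cells. A small worked example of the same arithmetic, with made-up cache numbers:

```cpp
#include <cstdio>

// Worked example of the fragmentation check added above. The numbers for n
// (cells spanned by the cache), used and n_tokens are invented; the formula
// is the one from llama_decode_internal.
int main() {
    const int   n        = 512;   // kv_self.n: cells currently spanned
    const int   used     = 300;   // kv_self.used: cells actually occupied
    const int   n_tokens = 16;    // tokens in the batch being decoded
    const float thold    = 0.10f; // --defrag-thold

    const float fragmentation = n >= 128 ? 1.0f - float(used + n_tokens)/float(n) : 0.0f;

    std::printf("fragmentation = %.3f -> %s\n", fragmentation,
                fragmentation > thold ? "queue defrag" : "leave cache as is");
    // 1 - 316/512 ~= 0.383 > 0.10, so a defrag would be queued here
    return 0;
}
```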
@@ -8098,12 +8113,16 @@ static int llama_decode_internal(
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;

+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+
     const uint32_t n_kv   = llama_kv_cache_cell_max(kv_self);
     const uint32_t n_used = kv_self.used;

     assert(n_used <= n_kv);

-    const int64_t t_start = ggml_time_us();
+    //const int64_t t_start = ggml_time_us();

     // number of cells moved
     uint32_t n_moves = 0;
@@ -8127,15 +8146,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
         // found a hole - fill it with data from the end of the cache

-        // determine the size of the hole
         uint32_t nh = 1;

+        // determine the size of the hole
         while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
             nh++;
         }

-        // starting from the end, find nh non-empty cells
+        // each move requires 6*n_layer tensors (see build_defrag)
+        //   - source view, destination view, copy operation
+        //   - x2 for keys and values
+        //
+        if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
+            // the graph is too big, we cannot move more cells
+            break;
+        }
+
         uint32_t nf = 0;
         uint32_t is = n_kv - 1;

+        // starting from the end, find nh non-empty cells
         for (; is > i0; --is) {
             const auto & cell1 = kv_self.cells[is];

@@ -8156,11 +8186,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

             nf = 0;

+            uint32_t i1 = is;
+
+            // are we moving a continuous block of memory?
+            bool cont = false;
+
             // go back and move the nf cells to the hole
-            for (uint32_t i1 = is; i1 < n_kv; ++i1) {
-                const auto & cell1 = kv_self.cells[i1];
+            for (; i1 < n_kv; ++i1) {
+                auto & cell1 = kv_self.cells[i1];

                 if (cell1.is_empty() || ids[i1] != n_kv) {
+                    cont = false;
                     continue;
                 }

@@ -8170,11 +8206,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
                 // move the cell meta data
                 kv_self.cells[i0 + nf] = cell1;

-                n_moves++;
+                // clear the old cell and move the head there
+                cell1 = llama_kv_cell();
+                kv_self.head = n_used;
+
+                if (!cont) {
+                    n_moves++;
+                    cont = true;
+                }
+
                 nf++;

                 if (nf == nh) {
                     break;
                 }
             }

-            LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
+            //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

             i0 += nh - 1;
         }
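Putting the defrag loop together: find a hole of nh empty cells inside the used region, then walk back from the end of the cache to collect nh occupied cells to fill it, counting one move per contiguous block so the resulting graph stays under the node limit. A simplified CPU sketch of that planning pass over a boolean occupancy map (the plan_defrag helper and its inputs are hypothetical; the real code also emits the ggml copy graph and enforces the LLAMA_MAX_NODES cap):

```cpp
#include <cstdio>
#include <vector>

// Simplified sketch of the defrag planning above: cells [0, n_used) should end
// up occupied, so each hole of nh empty cells inside that range is filled with
// nh occupied cells taken from the end of the cache. Only the planning is
// shown; "occupied" stands in for kv_self.cells.
static void plan_defrag(std::vector<bool> occupied, int n_used) {
    const int n_kv = (int) occupied.size();

    for (int i0 = 0; i0 < n_used; ++i0) {
        if (occupied[i0]) continue;

        // determine the size of the hole
        int nh = 1;
        while (i0 + nh < n_used && !occupied[i0 + nh]) nh++;

        // starting from the end, find nh non-empty cells and move them in
        int nf = 0;
        for (int is = n_kv - 1; is > i0 && nf < nh; --is) {
            if (!occupied[is]) continue;
            occupied[is]      = false;
            occupied[i0 + nf] = true;
            std::printf("move cell %d -> %d\n", is, i0 + nf);
            nf++;
        }

        i0 += nh - 1; // skip past the hole we just filled
    }
}

int main() {
    //                       0     1      2     3      4     5     6      7
    std::vector<bool> occ = {true, false, true, false, true, true, false, true};
    plan_defrag(occ, 5);   // cells 0..4 should end up occupied
    return 0;
}
```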
@@ -8183,15 +8231,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
         return;
     }

-    LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);

-    kv_self.head = n_used;
-    kv_self.used = n_used;
-
-    // zero the rest of the cells
-    for (uint32_t i = n_used; i < n_kv; ++i) {
-        kv_self.cells[i] = llama_kv_cell();
-    }
+    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);

 #if 0
     // CPU defrag
@@ -8203,9 +8245,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // likely not worth the effort, as we have ggml_graph based defrag
     //

-    const auto & hparams = lctx.model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();

@@ -8274,9 +8313,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
 #endif

-    const int64_t t_end = ggml_time_us();
+    //const int64_t t_end = ggml_time_us();

-    LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }

 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@@ -11670,6 +11709,7 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast    =*/ 32.0f,
         /*.yarn_beta_slow    =*/ 1.0f,
         /*.yarn_orig_ctx     =*/ 0,
+        /*.defrag_thold      =*/ -1.0f,
         /*.cb_eval           =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k            =*/ GGML_TYPE_F16,
@@ -11834,6 +11874,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.defrag_thold     = params.defrag_thold;
     cparams.mul_mat_q        = params.mul_mat_q;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.do_pooling       = params.do_pooling;
@@ -12035,7 +12076,7 @@ struct llama_context * llama_new_context_with_model(
         }

         // buffer used to store the computation graph and the tensor meta data
-        ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
+        ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));

         ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);

llama.h (1 change)
@@ -245,6 +245,7 @@ extern "C" {
         float yarn_beta_fast;   // YaRN low correction dim
         float yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx; // YaRN original context size
+        float defrag_thold;     // defragment the KV cache if holes/size > thold, < 0 disabled (default)

         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
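With the new defrag_thold field in llama_context_params, the threshold can also be set programmatically instead of via -dt. A hedged sketch assuming llama.h from this revision and a model already loaded by the caller:

```cpp
#include "llama.h"

// Sketch only: create a context with KV-cache defragmentation enabled at a
// 10% fragmentation threshold. Model loading and error handling are elided.
llama_context * make_ctx_with_defrag(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.defrag_thold = 0.1f;   // defrag once >10% of spanned cells are holes
    return llama_new_context_with_model(model, cparams);
}
```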