Merge branch 'ggerganov:master' into server_branch

This commit is contained in:
pudepiedj 2024-02-27 13:39:30 +00:00 committed by GitHub
commit ad4f567d8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 126 additions and 86 deletions

View file

@ -381,8 +381,13 @@ ifdef LLAMA_BLIS
endif # LLAMA_BLIS
ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
else
CUDA_PATH ?= /usr/local/cuda
endif
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS += -use_fast_math
ifdef LLAMA_FATAL_WARNINGS

View file

@ -335,6 +335,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
@ -1004,6 +1010,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
@ -1285,6 +1293,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);

View file

@ -75,6 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = -1.0f; // KV cache defragmentation threshold
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

View file

@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx);
//llama_kv_cache_defrag (ctx);
llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

View file

@ -1428,6 +1428,10 @@ struct llama_server_context
split_multiprompt_task(task_id, task);
}
} else {
// an empty prompt can make slot become buggy
if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
task.data["prompt"] = " "; // add a space so that we have one token
}
queue_tasks.post(task);
}
}

View file

@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
return a;
}
//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
//#pragma unroll
// for (int mask = 16; mask > 0; mask >>= 1) {
// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
// }
// return a;
//#else
// (void) a;
// NO_DEVICE_CODE;
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
//}
#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#else
(void) a;
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
float amax = fabsf(xi);
float sum = xi;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
}
amax = warp_reduce_max(amax);
sum = warp_reduce_sum(sum);
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
#ifdef GGML_CUDA_F16
@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
const int idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[idst] = tmp;
@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[idst] = tmp;

View file

@ -10248,8 +10248,12 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
uint64_t aux64;
__m256i v_gindex;
const uint16_t * gindex = (const uint16_t *)&v_gindex;
typedef union m256i_uint16 {
__m256i reg;
uint16_t s[16];
} m256i_uint16_t;
m256i_uint16_t v_gindex;
__m256 accum = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
@ -10264,13 +10268,13 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
memcpy(&aux64, sc, 8); sc += 8;
const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
for (int i32 = 0; i32 < 4; ++i32) {
const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
const __m256i dot = mul_add_epi8(q1b, q8b);
const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
const __m256i p = _mm256_madd_epi16(s16, dot);

View file

@ -1641,6 +1641,7 @@ struct llama_cparams {
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
float defrag_thold;
bool mul_mat_q;
bool offload_kqv;
@ -5117,16 +5118,16 @@ struct llm_build_context {
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
for (int i = 0; i < n_kv; ++i) {
const int id = ids[i];
for (uint32_t i = 0; i < ids.size(); ++i) {
const uint32_t id = ids[i];
if (i == id || id == n_kv) {
if (i == id || id == ids.size()) {
continue;
}
int nm = 1;
uint32_t nm = 1;
while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}
@ -5158,6 +5159,8 @@ struct llm_build_context {
i += nm - 1;
}
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
return gf;
}
@ -7938,6 +7941,8 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data();
}
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
@ -7956,8 +7961,6 @@ static int llama_decode_internal(
// line above and below originally commented out
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
llama_kv_cache_update(&lctx);
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@ -8007,6 +8010,18 @@ static int llama_decode_internal(
}
}
// decide if we need to defrag the kv cache
if (cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens)/float(kv_self.n) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
llama_kv_cache_defrag(kv_self);
}
}
#ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes)
// requires GGML_PERF to be defined
@ -8098,12 +8113,16 @@ static int llama_decode_internal(
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
const uint32_t n_used = kv_self.used;
assert(n_used <= n_kv);
const int64_t t_start = ggml_time_us();
//const int64_t t_start = ggml_time_us();
// number of cells moved
uint32_t n_moves = 0;
@ -8127,15 +8146,26 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// found a hole - fill it with data from the end of the cache
// determine the size of the hole
uint32_t nh = 1;
// determine the size of the hole
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
nh++;
}
// starting from the end, find nh non-empty cells
// each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation
// - x2 for keys and values
//
if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
// the graph is too big, we cannot move more cells
break;
}
uint32_t nf = 0;
uint32_t is = n_kv - 1;
// starting from the end, find nh non-empty cells
for (; is > i0; --is) {
const auto & cell1 = kv_self.cells[is];
@ -8156,11 +8186,17 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
nf = 0;
uint32_t i1 = is;
// are we moving a continuous block of memory?
bool cont = false;
// go back and move the nf cells to the hole
for (uint32_t i1 = is; i1 < n_kv; ++i1) {
const auto & cell1 = kv_self.cells[i1];
for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) {
cont = false;
continue;
}
@ -8170,11 +8206,23 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// move the cell meta data
kv_self.cells[i0 + nf] = cell1;
n_moves++;
// clear the old cell and move the head there
cell1 = llama_kv_cell();
kv_self.head = n_used;
if (!cont) {
n_moves++;
cont = true;
}
nf++;
if (nf == nh) {
break;
}
}
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1;
}
@ -8183,15 +8231,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
return;
}
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
kv_self.head = n_used;
kv_self.used = n_used;
// zero the rest of the cells
for (uint32_t i = n_used; i < n_kv; ++i) {
kv_self.cells[i] = llama_kv_cell();
}
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
#if 0
// CPU defrag
@ -8203,9 +8245,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// likely not worth the effort, as we have ggml_graph based defrag
//
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@ -8274,9 +8313,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
#endif
const int64_t t_end = ggml_time_us();
//const int64_t t_end = ggml_time_us();
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
}
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
@ -11670,6 +11709,7 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.defrag_thold =*/ -1.0f,
/*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16,
@ -11834,6 +11874,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_attn_factor = params.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv;
cparams.do_pooling = params.do_pooling;
@ -12035,7 +12076,7 @@ struct llama_context * llama_new_context_with_model(
}
// buffer used to store the computation graph and the tensor meta data
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);

View file

@ -245,6 +245,7 @@ extern "C" {
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;