From 27d53cb4ee92fe96dde9528c84738e3232810584 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 13 Oct 2024 16:11:38 +0300
Subject: [PATCH] llama.vim : logic to evict old chunks that are similar to
 the new one

---
 examples/llama.vim | 73 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 56 insertions(+), 17 deletions(-)

diff --git a/examples/llama.vim b/examples/llama.vim
index 8d85fb862..6e1840a54 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -98,7 +98,8 @@ function! llama#init()
     let s:line_cur_prefix = ''
     let s:line_cur_suffix = ''
 
-    let s:ring_n_chunks = []
+    let s:ring_chunks = []
+    let s:ring_n_evict = 0
 
     let s:pos_y_pick = -9999 " last y where we picked a chunk
     let s:pos_dx = 0
@@ -128,6 +129,25 @@ function! llama#init()
     silent! call llama#fim_cancel()
 endfunction
 
+" TODO: figure out something better
+function! s:chunk_sim(c0, c1)
+    let l:lines0 = len(a:c0)
+    let l:lines1 = len(a:c1)
+
+    let l:common = 0
+
+    for l:line0 in a:c0
+        for l:line1 in a:c1
+            if l:line0 == l:line1
+                let l:common += 1
+                break
+            endif
+        endfor
+    endfor
+
+    return 2.0 * l:common / (l:lines0 + l:lines1)
+endfunction
+
 function! s:pick_chunk(text, no_mod)
     " do not pick chunks from buffers with pending changes or buffers that are not files
     if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%')))
@@ -138,20 +158,25 @@ function! s:pick_chunk(text, no_mod)
         return
     endif
 
-    if len(a:text) + 1 < g:llama_config.ring_chunk_size
-        let l:chunk = join(a:text, "\n")
-    else
-        let l:l0 = s:rand(0, len(a:text) - g:llama_config.ring_chunk_size)
-        let l:l1 = l:l0 + g:llama_config.ring_chunk_size
-
-        let l:chunk = join(a:text[l:l0:l:l1], "\n")
+    if len(a:text) < 3
+        return
     endif
 
+    if len(a:text) + 1 < g:llama_config.ring_chunk_size
+        let l:chunk = a:text
+    else
+        let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size]))
+        let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size, len(a:text)])
+
+        let l:chunk = a:text[l:l0:l:l1]
+    endif
+
+    let l:chunk_str = join(l:chunk, "\n")
+
     " check if this chunk is already added
-    " TODO: smarter check for string similarity to evict old chunks that are very similart to the new one
     let l:exist = v:false
-    for i in range(len(s:ring_n_chunks))
-        if s:ring_n_chunks[i] == l:chunk
+    for i in range(len(s:ring_chunks))
+        if s:ring_chunks[i].data == l:chunk
             let l:exist = v:true
             break
         endif
@@ -161,11 +186,19 @@ function! s:pick_chunk(text, no_mod)
         return
     endif
 
-    if len(s:ring_n_chunks) == g:llama_config.ring_n_chunks
-        call remove(s:ring_n_chunks, 0)
+    " evict chunks that are very similar to the new one
+    for i in range(len(s:ring_chunks) - 1, 0, -1)
+        if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9
+            call remove(s:ring_chunks, i)
+            let s:ring_n_evict += 1
+        endif
+    endfor
+
+    if len(s:ring_chunks) == g:llama_config.ring_n_chunks
+        call remove(s:ring_chunks, 0)
     endif
 
-    call add(s:ring_n_chunks, l:chunk)
+    call add(s:ring_chunks, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime()})
 endfunction
 
 function! llama#fim(is_auto) abort
@@ -213,6 +246,12 @@ function! llama#fim(is_auto) abort
         let s:pos_y_pick = s:pos_y
     endif
 
+    " array of strings
+    let l:extra_context = []
+    for l:chunk in s:ring_chunks
+        call add(l:extra_context, l:chunk.str)
+    endfor
+
     let l:request = json_encode({
         \ 'prompt': "",
         \ 'input_prefix': l:prefix,
@@ -223,7 +262,7 @@ function! llama#fim(is_auto) abort
         \ 'stream': v:false,
         \ 'samplers': ["top_k", "infill"],
         \ 'cache_prompt': v:true,
-        \ 'extra_context': s:ring_n_chunks,
+        \ 'extra_context': l:extra_context,
         \ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
         \ 't_max_predict_ms': g:llama_config.t_max_predict_ms
        \ })
@@ -418,9 +457,9 @@ function! s:fim_on_stdout(job_id, data, event) dict
         " prefix the info string with whitespace in order to offset it to the right of the fim overlay
         let l:prefix = repeat(' ', len(s:content[0]) - len(s:line_cur_suffix) + 3)
 
-        let l:info = printf("%s | context: %d / %d | prompt: %d (%.2f ms, %.2f t/s) | predict: %d (%.2f ms, %.2f t/s) | total: %.2f ms",
+        let l:info = printf("%s | context: %d / %d / %d / %d | prompt: %d (%.2f ms, %.2f t/s) | predict: %d (%.2f ms, %.2f t/s) | total: %.2f ms",
             \ g:llama_config.show_info == 2 ? l:prefix : '',
-            \ l:n_cached, l:n_ctx,
+            \ l:n_cached, l:n_ctx, len(s:ring_chunks), s:ring_n_evict,
             \ l:n_prompt, l:t_prompt_ms, l:s_prompt,
             \ l:n_predict, l:t_predict_ms, l:s_predict,
             \ 1000.0 * reltimefloat(reltime(s:t_fim_start))
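
Note on the eviction metric: s:chunk_sim above is a Dice-style line-overlap score,
2 * common_lines / (len(c0) + len(c1)), so identical chunks score 1.0, disjoint
chunks score 0.0, and anything above 0.9 evicts the older chunk. A minimal
standalone sketch follows, assuming it is sourced as its own script; ChunkSim is
a hypothetical global restatement, since the script-local s:chunk_sim cannot be
called from outside the plugin:

    " standalone restatement of the patch's s:chunk_sim (hypothetical name)
    function! ChunkSim(c0, c1)
        let l:common = 0
        for l:line0 in a:c0
            " count each line of c0 at most once, even if it repeats in c1
            for l:line1 in a:c1
                if l:line0 == l:line1
                    let l:common += 1
                    break
                endif
            endfor
        endfor
        " Dice-style score: 1.0 for identical chunks, 0.0 for disjoint ones
        return 2.0 * l:common / (len(a:c0) + len(a:c1))
    endfunction

    " two of three lines shared -> prints 0.666667, below the 0.9 threshold
    echo ChunkSim(['a', 'b', 'c'], ['a', 'b', 'x'])

The comparison is O(len(c0) * len(c1)), which stays cheap here because chunks
are capped at g:llama_config.ring_chunk_size lines before they enter the ring.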