Merge branch 'master' into gg/server-check-ctx

commit bd67115981

15 changed files with 1010 additions and 462 deletions
@@ -53,7 +53,7 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then
exit 1
fi

CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON"
fi

if [ ! -z ${GG_BUILD_VULKAN} ]; then
@@ -2,7 +2,7 @@
"
" requires:
"
" - neovim
" - neovim or vim
" - curl
" - llama.cpp server instance
" - FIM-compatible model
@@ -10,7 +10,7 @@
" sample config:
"
" - Tab - accept the current suggestion
" - Shift+Tab - accept just the first line of the segguestion
" - Shift+Tab - accept just the first line of the suggestion
" - Ctrl+F - toggle FIM completion manually
"
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
@@ -43,8 +43,8 @@
"

" colors (adjust to your liking)
highlight llama_hl_hint guifg=#ff772f
highlight llama_hl_info guifg=#77ff2f
highlight llama_hl_hint guifg=#ff772f ctermfg=202
highlight llama_hl_info guifg=#77ff2f ctermfg=119

" general parameters:
"
@@ -81,7 +81,7 @@ let s:default_config = {
\ 'n_suffix': 64,
\ 'n_predict': 128,
\ 't_max_prompt_ms': 500,
\ 't_max_predict_ms': 1000,
\ 't_max_predict_ms': 3000,
\ 'show_info': 2,
\ 'auto_fim': v:true,
\ 'max_line_suffix': 8,
@@ -93,6 +93,18 @@ let s:default_config = {

let g:llama_config = get(g:, 'llama_config', s:default_config)

function! s:get_indent(str)
let l:count = 0
for i in range(len(a:str))
if a:str[i] == "\t"
let l:count += &tabstop - 1
else
break
endif
endfor
return l:count
endfunction

function! s:rand(i0, i1) abort
return a:i0 + rand() % (a:i1 - a:i0 + 1)
endfunction
@@ -129,6 +141,21 @@ function! llama#init()

let s:current_job = v:null

let s:ghost_text_nvim = exists('*nvim_buf_get_mark')
let s:ghost_text_vim = has('textprop')

if s:ghost_text_vim
let s:hlgroup_hint = 'llama_hl_hint'
let s:hlgroup_info = 'llama_hl_info'

if empty(prop_type_get(s:hlgroup_hint))
call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint})
endif
if empty(prop_type_get(s:hlgroup_info))
call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info})
endif
endif

augroup llama
autocmd!
autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
@@ -317,13 +344,22 @@ function! s:ring_update()
\ 't_max_predict_ms': 1
\ })

let l:curl_command = printf(
\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
\ g:llama_config.endpoint, shellescape(l:request)
\ )
let l:curl_command = [
\ "curl",
\ "--silent",
\ "--no-buffer",
\ "--request", "POST",
\ "--url", g:llama_config.endpoint,
\ "--header", "Content-Type: application/json",
\ "--data", l:request
\ ]

" no callbacks because we don't need to process the response
call jobstart(l:curl_command, {})
if s:ghost_text_nvim
call jobstart(l:curl_command, {})
elseif s:ghost_text_vim
call job_start(l:curl_command, {})
endif
endfunction

" necessary for 'inoremap <expr>'
@@ -418,24 +454,37 @@ function! llama#fim(is_auto) abort
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms
\ })

let l:curl_command = printf(
\ "curl --silent --no-buffer --request POST --url %s --header \"Content-Type: application/json\" --data %s",
\ g:llama_config.endpoint, shellescape(l:request)
\ )
let l:curl_command = [
\ "curl",
\ "--silent",
\ "--no-buffer",
\ "--request", "POST",
\ "--url", g:llama_config.endpoint,
\ "--header", "Content-Type: application/json",
\ "--data", l:request
\ ]

if s:current_job != v:null
call jobstop(s:current_job)
if s:ghost_text_nvim
call jobstop(s:current_job)
elseif s:ghost_text_vim
call job_stop(s:current_job)
endif
endif

" send the request asynchronously
let s:current_job = jobstart(l:curl_command, {
\ 'on_stdout': function('s:fim_on_stdout'),
\ 'on_exit': function('s:fim_on_exit'),
\ 'stdout_buffered': v:true,
\ 'pos_x': s:pos_x,
\ 'pos_y': s:pos_y,
\ 'is_auto': a:is_auto
\ })
if s:ghost_text_nvim
let s:current_job = jobstart(l:curl_command, {
\ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
\ 'on_exit': function('s:fim_on_exit'),
\ 'stdout_buffered': v:true
\ })
elseif s:ghost_text_vim
let s:current_job = job_start(l:curl_command, {
\ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
\ 'exit_cb': function('s:fim_on_exit')
\ })
endif

" TODO: per-file location
let l:delta_y = abs(s:pos_y - s:pos_y_pick)
@@ -482,9 +531,13 @@ function! llama#fim_cancel()
" clear the virtual text
let l:bufnr = bufnr('%')

let l:id_vt_fim = nvim_create_namespace('vt_fim')

call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
if s:ghost_text_nvim
let l:id_vt_fim = nvim_create_namespace('vt_fim')
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
elseif s:ghost_text_vim
call prop_remove({'type': s:hlgroup_hint, 'all': v:true})
call prop_remove({'type': s:hlgroup_info, 'all': v:true})
endif

" remove the mappings
silent! iunmap <buffer> <Tab>
@@ -499,13 +552,18 @@ function! s:on_move()
endfunction

" callback that processes the FIM result from the server and displays the suggestion
function! s:fim_on_stdout(job_id, data, event) dict
let l:raw = join(a:data, "\n")
function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null)
if s:ghost_text_nvim
let l:raw = join(a:data, "\n")
elseif s:ghost_text_vim
let l:raw = a:data
endif

if len(l:raw) == 0
return
endif

if self.pos_x != col('.') - 1 || self.pos_y != line('.')
if a:pos_x != col('.') - 1 || a:pos_y != line('.')
return
endif
@@ -514,14 +572,14 @@ function! s:fim_on_stdout(job_id, data, event) dict
return
endif

let s:pos_x = self.pos_x
let s:pos_y = self.pos_y
let s:pos_x = a:pos_x
let s:pos_y = a:pos_y

let s:can_accept = v:true
let l:has_info = v:false

if s:can_accept && v:shell_error
if !self.is_auto
if !a:is_auto
call add(s:content, "<| curl error: is the server on? |>")
endif
let s:can_accept = v:false
@@ -642,7 +700,9 @@ function! s:fim_on_stdout(job_id, data, event) dict
" display virtual text with the suggestion
let l:bufnr = bufnr('%')

let l:id_vt_fim = nvim_create_namespace('vt_fim')
if s:ghost_text_nvim
let l:id_vt_fim = nvim_create_namespace('vt_fim')
endif

" construct the info message
if g:llama_config.show_info > 0 && l:has_info

@@ -671,15 +731,41 @@ function! s:fim_on_stdout(job_id, data, event) dict
endif

" display the suggestion and append the info to the end of the first line
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
\ 'virt_text_win_col': virtcol('.') - 1
\ })
if s:ghost_text_nvim
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
\ 'virt_text_win_col': virtcol('.') - 1
\ })

call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
\ 'virt_text_win_col': virtcol('.')
\ })
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
\ 'virt_text_win_col': virtcol('.')
\ })
elseif s:ghost_text_vim
let l:new_suffix = s:content[0]
if !empty(l:new_suffix)
call prop_add(s:pos_y, s:pos_x + 1, {
\ 'type': s:hlgroup_hint,
\ 'text': l:new_suffix
\ })
endif
for line in s:content[1:]
call prop_add(s:pos_y, 0, {
\ 'type': s:hlgroup_hint,
\ 'text': line,
\ 'text_padding_left': s:get_indent(line),
\ 'text_align': 'below'
\ })
endfor
if !empty(l:info)
call prop_add(s:pos_y, 0, {
\ 'type': s:hlgroup_info,
\ 'text': l:info,
\ 'text_padding_left': col('$'),
\ 'text_wrap': 'truncate'
\ })
endif
endif

" setup accept shortcuts
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
@@ -688,7 +774,7 @@ function! s:fim_on_stdout(job_id, data, event) dict
let s:hint_shown = v:true
endfunction

function! s:fim_on_exit(job_id, exit_code, event) dict
function! s:fim_on_exit(job_id, exit_code, event = v:null)
if a:exit_code != 0
echom "Job failed with exit code: " . a:exit_code
endif
@@ -319,6 +319,18 @@ node index.js
- The prompt is a string or an array with the first element given as a string
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`

These input shapes and data types are allowed for `prompt`:

- Single string: `"string"`
- Single sequence of tokens: `[12, 34, 56]`
- Mixed tokens and strings: `[12, 34, "string", 56, 78]`

Multiple prompts are also supported. In this case, the completion result will be an array (see the example request below).

- Only strings: `["string1", "string2"]`
- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
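To make the accepted shapes concrete, here is a minimal example request against the `/completion` endpoint; the host `localhost:8080` and the `n_predict` value are illustrative placeholders rather than part of this change:

```sh
# Hypothetical multi-prompt request; adjust host/port to your server.
curl --request POST --url http://localhost:8080/completion \
    --header "Content-Type: application/json" \
    --data '{
        "prompt": ["string1", [12, 34, 56], [12, 34, "string", 56, 78]],
        "n_predict": 16
    }'
# The response is an array with one completion result per prompt.
```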

`temperature`: Adjust the randomness of the generated text. Default: `0.8`

`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]`. Default: `0.0`, which is disabled.
@ -43,21 +43,6 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum stop_type {
|
||||
|
@ -68,6 +53,7 @@ enum stop_type {
|
|||
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||
SLOT_STATE_PROCESSING_PROMPT,
|
||||
SLOT_STATE_DONE_PROMPT,
|
||||
SLOT_STATE_GENERATING,
|
||||
|
@ -79,7 +65,7 @@ enum server_state {
|
|||
};
|
||||
|
||||
enum server_task_type {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
SERVER_TASK_TYPE_INFERENCE,
|
||||
SERVER_TASK_TYPE_CANCEL,
|
||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||
SERVER_TASK_TYPE_METRICS,
|
||||
|
@ -89,21 +75,22 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
enum server_task_cmpl_type {
|
||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
||||
SERVER_TASK_CMPL_TYPE_RERANK,
|
||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
||||
enum server_task_inf_type {
|
||||
SERVER_TASK_INF_TYPE_COMPLETION,
|
||||
SERVER_TASK_INF_TYPE_EMBEDDING,
|
||||
SERVER_TASK_INF_TYPE_RERANK,
|
||||
SERVER_TASK_INF_TYPE_INFILL,
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
||||
|
||||
llama_tokens prompt_tokens;
|
||||
server_task_type type;
|
||||
json data;
|
||||
|
||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
// utility function
|
||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||
|
@ -161,26 +148,20 @@ struct server_slot {
|
|||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
|
||||
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
json input_extra;
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
std::vector<llama_token> extra_tokens;
|
||||
// input prompt tokens
|
||||
llama_tokens prompt_tokens;
|
||||
|
||||
size_t last_nl_pos = 0;
|
||||
|
||||
std::string generated_text;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
llama_tokens cache_tokens;
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
bool has_next_token = true;
|
||||
bool has_new_line = false;
|
||||
|
@ -229,7 +210,7 @@ struct server_slot {
|
|||
n_past = 0;
|
||||
n_sent_text = 0;
|
||||
n_sent_token_probs = 0;
|
||||
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
generated_token_probs.clear();
|
||||
}
|
||||
|
@ -734,42 +715,6 @@ struct server_context {
|
|||
metrics.init();
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
|
||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||
// or the first element of the json_prompt array is a string.
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
|
||||
if (json_prompt.is_array()) {
|
||||
bool first = true;
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string()) {
|
||||
auto s = p.template get<std::string>();
|
||||
|
||||
std::vector<llama_token> p;
|
||||
if (first) {
|
||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||
first = false;
|
||||
} else {
|
||||
p = common_tokenize(ctx, s, false, parse_special);
|
||||
}
|
||||
|
||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||
} else {
|
||||
if (first) {
|
||||
first = false;
|
||||
}
|
||||
|
||||
prompt_tokens.push_back(p.template get<llama_token>());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto s = json_prompt.template get<std::string>();
|
||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||
}
|
||||
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
server_slot * get_slot_by_id(int id) {
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.id == id) {
|
||||
|
@ -794,22 +739,16 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
// skip the slot if it does not contains prompt
|
||||
if (!slot.prompt.is_string()) {
|
||||
// skip the slot if it does not contains cached tokens
|
||||
if (slot.prompt_tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// current slot's prompt
|
||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||
|
||||
// length of the current slot's prompt
|
||||
int slot_prompt_len = slot_prompt.size();
|
||||
|
||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||
int lcp_len = longest_common_prefix(slot_prompt, prompt);
|
||||
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
|
||||
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||
|
@ -914,57 +853,6 @@ struct server_context {
|
|||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||
}
|
||||
|
||||
// infill
|
||||
slot.input_prefix = json_value(data, "input_prefix", json());
|
||||
slot.input_suffix = json_value(data, "input_suffix", json());
|
||||
slot.input_extra = json_value(data, "input_extra", json());
|
||||
|
||||
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk["text"].is_string()) {
|
||||
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
|
||||
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
|
||||
}
|
||||
|
||||
// get prompt
|
||||
{
|
||||
const auto & prompt = data.find("prompt");
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((prompt->is_string()) ||
|
||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
||||
slot.prompt = *prompt;
|
||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||
slot.prompt = prompt->at(0);
|
||||
} else if (prompt->is_array() && prompt->size() > 1) {
|
||||
// array of strings
|
||||
for (const auto & el : *prompt) {
|
||||
if (!el.is_string()) {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
slot.prompt = *prompt;
|
||||
} else {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
slot.sparams.logit_bias.clear();
|
||||
|
||||
|
@ -1044,8 +932,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
slot.prompt_tokens.clear();
|
||||
slot.state = SLOT_STATE_STARTED;
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
|
||||
|
@ -1297,7 +1184,7 @@ struct server_context {
|
|||
};
|
||||
|
||||
if (slot.sparams.n_probs > 0) {
|
||||
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
||||
|
||||
|
@ -1333,7 +1220,7 @@ struct server_context {
|
|||
{"tokens_predicted", slot.n_decoded},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
{"generation_settings", get_formated_generation(slot)},
|
||||
{"prompt", slot.prompt},
|
||||
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
{"truncated", slot.truncated},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
|
@ -1348,7 +1235,7 @@ struct server_context {
|
|||
if (slot.sparams.n_probs > 0) {
|
||||
std::vector<completion_token_output> probs;
|
||||
if (!slot.params.stream && slot.stopped_word) {
|
||||
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||
|
||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||
probs = std::vector<completion_token_output>(
|
||||
|
@ -1457,19 +1344,17 @@ struct server_context {
|
|||
// Functions to create new task(s) and receive result(s)
|
||||
//
|
||||
|
||||
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
||||
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
|
||||
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
|
||||
std::vector<server_task> tasks;
|
||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
|
||||
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
|
||||
server_task task;
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.cmpl_type = cmpl_type;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
if (replace_prompt) {
|
||||
task.data = task_data;
|
||||
task.data["prompt"] = std::move(prompt);
|
||||
} else {
|
||||
task.data = std::move(task_data);
|
||||
}
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.inf_type = inf_type;
|
||||
task.type = SERVER_TASK_TYPE_INFERENCE;
|
||||
task.data = task_data;
|
||||
task.prompt_tokens = std::move(prompt_tokens);
|
||||
tasks.push_back(std::move(task));
|
||||
};
|
||||
|
||||
|
@ -1478,41 +1363,49 @@ struct server_context {
|
|||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
json prompt = data.at("prompt");
|
||||
|
||||
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
|
||||
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
|
||||
data["index"] = 0;
|
||||
create_task(data, false, nullptr);
|
||||
} else if (prompt.is_array()) {
|
||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||
std::vector<json> prompts = prompt;
|
||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||
for (size_t i = 1; i < prompts.size(); i++) {
|
||||
json qd;
|
||||
qd.push_back(prompts[0]);
|
||||
qd.push_back(prompts[i]);
|
||||
data["index"] = i - 1;
|
||||
create_task(data, true, qd);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||
for (size_t i = 0; i < prompts.size(); i++) {
|
||||
const auto & e = prompts[i];
|
||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
|
||||
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
|
||||
switch (inf_type) {
|
||||
case SERVER_TASK_INF_TYPE_RERANK:
|
||||
{
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
GGML_ASSERT(tokenized_prompts.size() > 1);
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
|
||||
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i - 1;
|
||||
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
case SERVER_TASK_INF_TYPE_INFILL:
|
||||
{
|
||||
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, true, e);
|
||||
} else {
|
||||
throw std::runtime_error(error_msg);
|
||||
auto tokens = format_infill(
|
||||
ctx,
|
||||
data.at("input_prefix"),
|
||||
data.at("input_suffix"),
|
||||
data.at("input_extra"),
|
||||
params.n_batch,
|
||||
params.n_predict,
|
||||
slots[0].n_ctx, // TODO: there should be a better way
|
||||
params.spm_infill,
|
||||
tokenized_prompts[i]
|
||||
);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, tokenized_prompts[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// invalid case
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
|
@ -1534,7 +1427,7 @@ struct server_context {
|
|||
queue_tasks.post(cancel_tasks, true);
|
||||
}
|
||||
|
||||
// receive the results from task(s) created by create_tasks_cmpl
|
||||
// receive the results from task(s) created by create_tasks_inference
|
||||
void receive_cmpl_results(
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
||||
|
@ -1558,7 +1451,7 @@ struct server_context {
|
|||
result_handler(results);
|
||||
}
|
||||
|
||||
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
||||
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||
void receive_cmpl_results_stream(
|
||||
const std::unordered_set<int> & id_tasks, const
|
||||
std::function<bool(server_task_result&)> & result_handler, const
|
||||
|
@ -1591,7 +1484,7 @@ struct server_context {
|
|||
|
||||
void process_single_task(const server_task & task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFERENCE:
|
||||
{
|
||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||
|
||||
|
@ -1623,9 +1516,10 @@ struct server_context {
|
|||
|
||||
slot->reset();
|
||||
|
||||
slot->id_task = task.id;
|
||||
slot->cmpl_type = task.cmpl_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->id_task = task.id;
|
||||
slot->inf_type = task.inf_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!launch_slot_with_task(*slot, task)) {
|
||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||
|
@ -1658,7 +1552,7 @@ struct server_context {
|
|||
slot_data["id"] = slot.id;
|
||||
slot_data["id_task"] = slot.id_task;
|
||||
slot_data["state"] = slot.state;
|
||||
slot_data["prompt"] = slot.prompt;
|
||||
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
||||
slot_data["next_token"] = {
|
||||
{"has_next_token", slot.has_next_token},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
|
@ -1785,9 +1679,6 @@ struct server_context {
|
|||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
|
||||
// TODO: maybe detokenize the slot->cache_tokens instead?
|
||||
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
|
@ -1954,142 +1845,18 @@ struct server_context {
|
|||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
auto & prompt_tokens = slot.prompt_tokens;
|
||||
|
||||
// we haven't tokenized the prompt yet - do it now:
|
||||
if (prompt_tokens.empty()) {
|
||||
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
|
||||
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
switch (slot.cmpl_type) {
|
||||
case SERVER_TASK_CMPL_TYPE_NORMAL:
|
||||
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
|
||||
{
|
||||
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
||||
{
|
||||
// require slot.prompt to be array of 2 strings
|
||||
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
||||
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
||||
slot.release();
|
||||
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
||||
continue;
|
||||
}
|
||||
|
||||
// prompt: [BOS]query[EOS][SEP]doc[EOS]
|
||||
prompt_tokens.clear();
|
||||
prompt_tokens.push_back(llama_token_bos(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[0], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
prompt_tokens.push_back(llama_token_sep(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[1], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
||||
{
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
|
||||
auto tokens_prompt = tokenize(slot.prompt, false, false);
|
||||
|
||||
slot.extra_tokens.clear();
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
static const auto k_fim_repo = tokenize("myproject\n", false, false);
|
||||
|
||||
slot.extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = tokenize(filename + "\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = tokenize(text, false, false);
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = tokenize("filename\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
|
||||
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
|
||||
prompt_tokens = std::move(embd_inp);
|
||||
} break;
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
if (1) {
|
||||
|
@ -2114,7 +1881,7 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
slot.release();
|
||||
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
||||
|
@ -2149,7 +1916,7 @@ struct server_context {
|
|||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
std::vector<llama_token> new_tokens(
|
||||
llama_tokens new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
|
@ -2168,17 +1935,10 @@ struct server_context {
|
|||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
|
@ -2211,8 +1971,6 @@ struct server_context {
|
|||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
|
||||
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
|
@ -2239,7 +1997,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
continue;
|
||||
|
@ -2248,8 +2006,8 @@ struct server_context {
|
|||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
const bool slot_type =
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
|
||||
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
|
||||
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
|
@ -2264,8 +2022,6 @@ struct server_context {
|
|||
|
||||
// there is no common part left
|
||||
slot.n_past = 0;
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
|
@ -2293,6 +2049,13 @@ struct server_context {
|
|||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
||||
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
|
@ -2362,7 +2125,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
|
||||
// prompt evaluated for embedding
|
||||
send_embedding(slot, batch_view);
|
||||
slot.release();
|
||||
|
@ -2370,7 +2133,7 @@ struct server_context {
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
|
@ -2924,13 +2687,13 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -2976,10 +2739,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
// check model compatibility
|
||||
std::string err;
|
||||
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||
err += "prefix token is missing. ";
|
||||
|
@ -2990,14 +2754,42 @@ int main(int argc, char ** argv) {
|
|||
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||
err += "middle token is missing. ";
|
||||
}
|
||||
|
||||
if (!err.empty()) {
|
||||
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
|
||||
|
||||
// validate input
|
||||
if (!data.contains("input_prefix")) {
|
||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_suffix")) {
|
||||
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
json input_extra = json_value(data, "input_extra", json::array());
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
|
||||
};
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
|
@ -3009,7 +2801,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -3082,7 +2874,7 @@ int main(int argc, char ** argv) {
|
|||
const bool add_special = json_value(body, "add_special", false);
|
||||
const bool with_pieces = json_value(body, "with_pieces", false);
|
||||
|
||||
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
|
||||
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
|
||||
|
||||
if (with_pieces) {
|
||||
for (const auto& token : tokens) {
|
||||
|
@ -3119,7 +2911,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::string content;
|
||||
if (body.count("tokens") != 0) {
|
||||
const std::vector<llama_token> tokens = body.at("tokens");
|
||||
const llama_tokens tokens = body.at("tokens");
|
||||
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
||||
}
|
||||
|
||||
|
@ -3153,7 +2945,7 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -3230,7 +3022,7 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
examples/server/tests/features/infill.feature (new file, 36 lines)
@@ -0,0 +1,36 @@
@llama.cpp
@infill
Feature: llama.cpp server

# The current model is made by adding FIM tokens to the existing stories260K
# We may want to use a better model in the future, maybe something like SmolLM 360M

Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
And a model file test-model-infill.gguf
And a model alias tinyllama-infill
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy

Scenario: Infill without input_extra
Given a prompt "Complete this"
And an infill input extra none none
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird

Scenario: Infill with input_extra
Given a prompt "Complete this"
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"
@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||
context.lora_file = None
|
||||
context.disable_ctx_shift = False
|
||||
|
||||
# infill
|
||||
context.infill_input_extra = None
|
||||
context.infill_input_suffix = ''
|
||||
context.infill_input_prefix = ''
|
||||
|
||||
context.tasks_result = []
|
||||
context.concurrent_tasks = []
|
||||
context.prompts = []
|
||||
|
@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
|
|||
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
||||
|
||||
|
||||
@step('an infill request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||
if api_error != 'no':
|
||||
raise ValueError(f'api_error={api_error} is not yet implemented')
|
||||
payload = {
|
||||
"prompt": context.prompts[0],
|
||||
"input_suffix": context.infill_input_suffix,
|
||||
"input_prefix": context.infill_input_prefix,
|
||||
"n_predict": context.n_predict,
|
||||
"seed": context.seed,
|
||||
"temperature": context.temperature,
|
||||
}
|
||||
if context.infill_input_extra is not None:
|
||||
payload['input_extra'] = context.infill_input_extra
|
||||
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||
async with session.post(f'{context.base_url}/infill',
|
||||
json=payload) as response:
|
||||
assert response.status == 200
|
||||
context.tasks_result = [await response.json()]
|
||||
|
||||
|
||||
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
||||
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||
context.completion = context.tasks_result.pop()
|
||||
|
@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
|
|||
context.n_prompts = len(context.prompts)
|
||||
|
||||
|
||||
# TODO: allow this to be repeated
|
||||
@step('an infill input extra {filename} {text}')
|
||||
def step_infill_input_extra(context, filename, text):
|
||||
if filename == 'none':
|
||||
context.infill_input_extra = None
|
||||
else:
|
||||
context.infill_input_extra = [{'filename': filename, 'text': text}]
|
||||
|
||||
|
||||
@step('an infill input suffix {text}')
|
||||
def step_infill_input_suffix(context, text):
|
||||
context.infill_input_suffix = text
|
||||
|
||||
|
||||
@step('an infill input prefix {text}')
|
||||
def step_infill_input_prefix(context, text):
|
||||
context.infill_input_prefix = text
|
||||
|
||||
|
||||
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
|
||||
def step_many_prompts(context, num_prompts, prompt, seed):
|
||||
if context.seed is None:
|
||||
|
|
|
@ -24,6 +24,22 @@
|
|||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
|
@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul
|
|||
}
|
||||
|
||||
//
|
||||
// chat template utils
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
||||
static bool json_is_array_of_numbers(const json & data) {
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
if (!e.is_number_integer()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// is array having BOTH numbers & strings?
|
||||
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
||||
bool seen_string = false;
|
||||
bool seen_number = false;
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
seen_string |= e.is_string();
|
||||
seen_number |= e.is_number_integer();
|
||||
if (seen_number && seen_string) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||
*/
|
||||
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||
// or the first element of the json_prompt array is a string.
|
||||
llama_tokens prompt_tokens;
|
||||
|
||||
if (json_prompt.is_array()) {
|
||||
bool first = true;
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string()) {
|
||||
auto s = p.template get<std::string>();
|
||||
|
||||
llama_tokens p;
|
||||
if (first) {
|
||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||
first = false;
|
||||
} else {
|
||||
p = common_tokenize(ctx, s, false, parse_special);
|
||||
}
|
||||
|
||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||
} else {
|
||||
if (first) {
|
||||
first = false;
|
||||
}
|
||||
|
||||
prompt_tokens.push_back(p.template get<llama_token>());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto s = json_prompt.template get<std::string>();
|
||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||
}
|
||||
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
|
||||
*/
|
||||
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
std::vector<llama_tokens> result;
|
||||
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
|
||||
// string or mixed
|
||||
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(json_prompt)) {
|
||||
// array of tokens
|
||||
result.push_back(json_prompt.get<llama_tokens>());
|
||||
} else if (json_prompt.is_array()) {
|
||||
// array of prompts
|
||||
result.reserve(json_prompt.size());
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
|
||||
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(p)) {
|
||||
// array of tokens
|
||||
result.push_back(p.get<llama_tokens>());
|
||||
} else {
|
||||
throw std::runtime_error("element of \"prompt\" must be a string, a list of tokens, or a list of mixed strings & tokens");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("\"prompt\" must be a string, a list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
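A minimal sketch of the multi-prompt path (hypothetical request body, assuming a loaded llama_context * ctx):

    // illustrative only, not part of the patch
    json prompt = json::parse(R"(["string1", [12, 34, 56]])");
    std::vector<llama_tokens> prompts = tokenize_input_prompts(ctx, prompt, /*add_special=*/true, /*parse_special=*/true);
    // prompts.size() == 2: prompts[0] is "string1" tokenized (with BOS if the model uses it),
    // prompts[1] is {12, 34, 56} copied verbatim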
|
||||
//
|
||||
// template utils
|
||||
//
|
||||
|
||||
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
|
||||
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
|
||||
llama_tokens result;
|
||||
result.reserve(doc.size() + query.size() + 4);
|
||||
result.push_back(llama_token_bos(model));
|
||||
result.insert(result.end(), query.begin(), query.end());
|
||||
result.push_back(llama_token_eos(model));
|
||||
result.push_back(llama_token_sep(model));
|
||||
result.insert(result.end(), doc.begin(), doc.end());
|
||||
result.push_back(llama_token_eos(model));
|
||||
return result;
|
||||
}
|
||||
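A usage sketch matching the layout in the comment above (hypothetical strings, assuming ctx and its model are already loaded):

    // illustrative only, not part of the patch
    const llama_model * model = llama_get_model(ctx);
    llama_tokens query = common_tokenize(ctx, "what is panda?", false, false);
    llama_tokens doc   = common_tokenize(ctx, "The giant panda is a bear endemic to China.", false, false);
    llama_tokens rerank_input = format_rerank(model, query, doc);
    // rerank_input = [BOS] query [EOS] [SEP] doc [EOS], i.e. query.size() + doc.size() + 4 tokens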
|
||||
// format infill task
|
||||
static llama_tokens format_infill(
|
||||
const llama_context * ctx,
|
||||
const json & input_prefix,
|
||||
const json & input_suffix,
|
||||
const json & input_extra,
|
||||
const int n_batch,
|
||||
const int n_predict,
|
||||
const int n_ctx,
|
||||
const bool spm_infill,
|
||||
const llama_tokens & tokens_prompt
|
||||
) {
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
llama_tokens extra_tokens;
|
||||
extra_tokens.reserve(n_ctx);
|
||||
|
||||
auto model = llama_get_model(ctx);
|
||||
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
|
||||
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: make project name an input
|
||||
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
|
||||
|
||||
extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = json_value(chunk, "text", std::string());
|
||||
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
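// (the byte array above spells out "\n\n--- snippet ---\n\n")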
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
|
||||
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
|
||||
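// worked example (hypothetical numbers): with n_batch = 2048, n_ctx = 8192 and n_predict = 128,
// the suffix gets at most 2048/4 = 512 tokens, the prefix at most 3*(2048/4) - 3 = 1533 tokens,
// and up to 8192 - 2048 - 2*128 = 5888 tokens of the extra chunks fill the remaining context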
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
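// resulting order for non-SPM models: [extra chunks][BOS][FIM_PRE]prefix + prompt[FIM_SUF]suffix[FIM_MID]
// (with spm_infill the prefix and suffix halves above are swapped)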
|
||||
return embd_inp;
|
||||
}
|
||||
|
||||
// Format the given chat. If tmpl is empty, we take the template from the model metadata
|
||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
||||
std::vector<common_chat_msg> chat;
|
||||
|
@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
|
|||
return std::string::npos;
|
||||
}
|
||||
|
||||
static bool json_is_array_of_numbers(const json & data) {
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
if (!e.is_number()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: reuse llama_detokenize
|
||||
template <class Iter>
|
||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||
|
|
|
@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
||||
char * src_ptr = (char *) src->data;
|
||||
char * dst_ptr = (char *) dst;
|
||||
const char * src_ptr = (const char *) src->data;
|
||||
char * dst_ptr = (char *) dst;
|
||||
|
||||
const int64_t ne0 = src->ne[0];
|
||||
const int64_t nb0 = src->nb[0];
|
||||
|
@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
const enum ggml_type type = src->type;
|
||||
const int64_t ts = ggml_type_size(type);
|
||||
const int64_t bs = ggml_blck_size(type);
|
||||
int64_t i1_diff = i1_high - i1_low;
|
||||
const int64_t i1_diff = i1_high - i1_low;
|
||||
|
||||
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
|
||||
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
||||
|
@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat(
|
|||
if (src0_is_contiguous) {
|
||||
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
|
||||
} else {
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
|
||||
// If src0 is not contiguous it will be copied to a temporary buffer.
|
||||
// This buffer needs to be cleared entirely because multiple regions will function as padding.
|
||||
const size_t nbytes_data = ggml_nbytes(src0);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
|
||||
}
|
||||
|
||||
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
|
||||
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
|
||||
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
|
||||
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
||||
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
|
||||
}
|
||||
|
||||
|
@ -3141,7 +3146,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
|
|
@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const int64_t OW = dst->ne[1];
|
||||
|
||||
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
const int64_t batch = src1->ne[3];
|
||||
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
||||
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
const int64_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
||||
|
|
|
@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
|
|||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
|
|||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
|
||||
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
|
|
|
@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
||||
|
@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
|
@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
||||
|
@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
}
|
||||
|
||||
[metal_library release];
|
||||
|
@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
return false;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
|
@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
|
|||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
|
|||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
||||
|
||||
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
||||
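// worked example (hypothetical, assuming maxTotalThreadsPerThreadgroup == 1024): for a 3x3 kernel
// and a batch of N = 2048, N*KH*KW = 18432 > 1024, so the _ext kernels are used and N is split
// across ceil(2048/1024) = 2 threadgroups per output position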
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: {
|
||||
pipeline = (is_gt_mttpt ?
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
|
||||
:
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
pipeline = (is_gt_mttpt ?
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
|
||||
:
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
if (is_gt_mttpt) {
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
||||
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
|
@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
|
|||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
||||
|
||||
const int32_t * opts = dst->op_params;
|
||||
enum ggml_op_pool op = opts[0];
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32: {
|
||||
switch(op) {
|
||||
case GGML_OP_POOL_AVG:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
||||
case GGML_OP_POOL_MAX:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
} break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
const int32_t k0 = opts[1];
|
||||
const int32_t k1 = opts[2];
|
||||
const int32_t s0 = opts[3];
|
||||
const int32_t s1 = opts[4];
|
||||
const int32_t p0 = opts[5];
|
||||
const int32_t p1 = opts[6];
|
||||
|
||||
const int64_t IH = src0->ne[1];
|
||||
const int64_t IW = src0->ne[0];
|
||||
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t OC = dst->ne[2];
|
||||
const int64_t OH = dst->ne[1];
|
||||
const int64_t OW = dst->ne[0];
|
||||
|
||||
const int64_t parallel_elements = N * OC * OH * OW;
|
||||
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
||||
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
||||
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
||||
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
||||
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
||||
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
|
|
@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
|
|||
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||
|
||||
typedef void (im2col_ext_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col_ext(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
||||
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
|
||||
|
||||
const int32_t d = tgpig[0] / CHW;
|
||||
const int32_t chw = tgpig[0] % CHW;
|
||||
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
||||
const int32_t HW = tgpig[0] % KHW;
|
||||
|
||||
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
||||
if (tpitg_0 >= N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t tpitg_1 = HW / KW;
|
||||
const int32_t tpitg_2 = HW % KW;
|
||||
|
||||
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
||||
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
||||
|
||||
const int32_t offset_dst =
|
||||
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
||||
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
|
@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
|
|||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
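// e.g. (illustrative) with OH = OW = 4 and gid = 21: nc = 21/16 = 1, cur_oh = (21%16)/4 = 1, cur_ow = (21%16)%4 = 1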
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (k0 * k1);
|
||||
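// dividing by the full window size k0*k1 counts padded (out-of-bounds) positions as zeros;
// the commented-out variant above would instead average only over the in-bounds elements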
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
|
|
@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
|
|||
|
||||
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||
size_t nbytes;
|
||||
size_t blck_size = ggml_blck_size(tensor->type);
|
||||
const size_t blck_size = ggml_blck_size(tensor->type);
|
||||
if (blck_size == 1) {
|
||||
nbytes = ggml_type_size(tensor->type);
|
||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
|
@ -3852,10 +3852,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
},
|
||||
};
|
||||
|
||||
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
|
||||
g_state.contexts[i].used = false;
|
||||
}
|
||||
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
||||
|
|
|
@ -1 +1 @@
|
|||
2327bda7a55ac6b72614ac5ebd5c5a5e02553b9b
|
||||
6dccc647264f5429df2624f36138f601e7ce23e5
|
||||
|
|
|
@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
|
|||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const std::array<int64_t, 2> bs; // dims 3 and 4
|
||||
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
|
||||
const std::array<int64_t, 2> bs; // dims 3 and 4
|
||||
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
|
||||
const std::array<int64_t, 4> per; // permutation of dimensions
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
|
||||
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
|
@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
|
|||
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32,
|
||||
std::array<int64_t, 2> bs = {10, 10},
|
||||
std::array<int64_t, 2> nr = {2, 2})
|
||||
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
|
||||
std::array<int64_t, 2> nr = {2, 2},
|
||||
std::array<int64_t, 4> per = {0, 1, 2, 3})
|
||||
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
ggml_tensor * a;
|
||||
ggml_tensor * b;
|
||||
|
||||
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
|
||||
if (npermuted > 0) {
|
||||
GGML_ASSERT(npermuted == 2);
|
||||
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
|
||||
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
|
||||
|
||||
// Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
|
||||
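// e.g. (illustrative) per = {0, 2, 1, 3} with bs = {2, 3}: a is allocated as (k, bs[0], m, bs[1]) and
// the ggml_permute() below views it back as (k, m, bs[0], bs[1]), so the multiplication runs on
// logically identical but non-contiguous operands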
const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
|
||||
const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
|
||||
|
||||
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
|
||||
b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
|
||||
ggml_set_name(a, "a_permuted");
|
||||
ggml_set_name(b, "b_permuted");
|
||||
} else {
|
||||
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_mul_mat(ctx, a, b);
|
||||
ggml_set_name(out, "out");
|
||||
|
@ -3308,13 +3336,49 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
|
||||
// test cases for 1D im2col
|
||||
// im2col 1D
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
|
||||
for (int s0 : {1, 3}) {
|
||||
for (int p0 : {0, 3}) {
|
||||
for (int d0 : {1, 3}) {
|
||||
test_cases.emplace_back(new test_im2col(
|
||||
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
|
||||
s0, 0, p0, 0, d0, 0, false));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// im2col 2D
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
|
||||
for (int s0 : {1, 3}) {
|
||||
for (int s1 : {1, 3}) {
|
||||
for (int p0 : {0, 3}) {
|
||||
for (int p1 : {0, 3}) {
|
||||
for (int d0 : {1, 3}) {
|
||||
for (int d1 : {1, 3}) {
|
||||
test_cases.emplace_back(new test_im2col(
|
||||
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
|
||||
s0, s1, p0, p1, d0, d1, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extra tests for im2col 2D
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||
|
||||
// the SYCL backend limits the task global_range to < MAX_INT
|
||||
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
|
||||
|
@ -3442,13 +3506,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
|
||||
// test cases without permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
|
||||
|
@ -3457,6 +3522,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
||||
|
||||
// test cases with permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
|
||||
}
|
||||
}
|
||||
for (ggml_type type_a : other_types) {
|
||||
|
|