common : use common_ prefix for common library functions (#9805)
* common : use common_ prefix for common library functions --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
0e9f760eb1
commit
7eee341bee
45 changed files with 1284 additions and 1284 deletions
|
@ -15,16 +15,16 @@ static void print_usage(int, char ** argv) {
|
|||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
common_params params;
|
||||
|
||||
params.prompt = "Hello my name is";
|
||||
params.n_predict = 32;
|
||||
|
||||
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
gpt_init();
|
||||
common_init();
|
||||
|
||||
// number of parallel batches
|
||||
int n_parallel = params.n_parallel;
|
||||
|
@ -39,7 +39,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// initialize the model
|
||||
|
||||
llama_model_params model_params = llama_model_params_from_gpt_params(params);
|
||||
llama_model_params model_params = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
||||
|
||||
|
@ -51,13 +51,13 @@ int main(int argc, char ** argv) {
|
|||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(model, params.prompt, true);
|
||||
tokens_list = common_tokenize(model, params.prompt, true);
|
||||
|
||||
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
|
||||
|
||||
// initialize the context
|
||||
|
||||
llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
|
||||
llama_context_params ctx_params = common_context_params_to_llama(params);
|
||||
|
||||
ctx_params.n_ctx = n_kv_req;
|
||||
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
||||
|
@ -94,7 +94,7 @@ int main(int argc, char ** argv) {
|
|||
LOG("\n");
|
||||
|
||||
for (auto id : tokens_list) {
|
||||
LOG("%s", llama_token_to_piece(ctx, id).c_str());
|
||||
LOG("%s", common_token_to_piece(ctx, id).c_str());
|
||||
}
|
||||
|
||||
// create a llama_batch
|
||||
|
@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// evaluate the initial prompt
|
||||
for (size_t i = 0; i < tokens_list.size(); ++i) {
|
||||
llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
|
||||
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
|
||||
|
||||
|
@ -123,8 +123,8 @@ int main(int argc, char ** argv) {
|
|||
decoder_start_token_id = llama_token_bos(model);
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
|
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
while (n_cur <= n_predict) {
|
||||
// prepare the next batch
|
||||
llama_batch_clear(batch);
|
||||
common_batch_clear(batch);
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
|
@ -185,15 +185,15 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// if there is only one stream, we print immediately to stdout
|
||||
if (n_parallel == 1) {
|
||||
LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
|
||||
}
|
||||
|
||||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
streams[i] += common_token_to_piece(ctx, new_token_id);
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
// push this new token for next evaluation
|
||||
llama_batch_add(batch, new_token_id, n_cur, { i }, true);
|
||||
common_batch_add(batch, new_token_id, n_cur, { i }, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue