Merge branch 'master' into concedo_experimental

# Conflicts:
#	tests/test-grad0.c

Commit f5247be0d7
10 changed files with 5442 additions and 265 deletions
examples/CMakeLists.txt

@@ -37,6 +37,7 @@ else()
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
+    add_subdirectory(train-text-from-scratch)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
examples/baby-llama/baby-llama.cpp

@@ -79,34 +79,39 @@ struct ggml_tensor * randomize_tensor_normal(
         int ndims,
         const int64_t ne[],
         struct random_normal_distribution * rnd) {
+    float scale = 1.0; // xavier
     switch (ndims) {
         case 1:
+            scale /= sqrtf(ne[0]);
             for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)tensor->data)[i0] = frand_normal(rnd);
+                ((float *)tensor->data)[i0] = scale * frand_normal(rnd);
             }
             break;
         case 2:
+            scale /= sqrtf(ne[0]+ne[1]);
             for (int i1 = 0; i1 < ne[1]; i1++) {
                 for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd);
+                    ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd);
                 }
             }
             break;
         case 3:
+            scale /= sqrtf(ne[0]+ne[1]);
             for (int i2 = 0; i2 < ne[2]; i2++) {
                 for (int i1 = 0; i1 < ne[1]; i1++) {
                     for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+                        ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
                     }
                 }
             }
             break;
         case 4:
+            scale /= sqrtf(ne[0]+ne[1]);
             for (int i3 = 0; i3 < ne[3]; i3++) {
                 for (int i2 = 0; i2 < ne[2]; i2++) {
                     for (int i1 = 0; i1 < ne[1]; i1++) {
                         for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+                            ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
                         }
                     }
                 }
             }
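The new `scale` factor is Xavier (Glorot) initialization: weights are drawn from a normal distribution and shrunk by the square root of the fan size, so activation variance stays roughly constant across layers (note the 3D and 4D cases reuse `ne[0]+ne[1]`, treating the higher dimensions as batch). A minimal standalone sketch of the same rule; the function name and parameters are illustrative, not part of the patch:

```cpp
#include <cmath>
#include <random>

// Xavier-scaled normal init: stddev = 1/sqrt(fan_in + fan_out).
// For the 2D case in the diff, fan_in + fan_out == ne[0] + ne[1].
void xavier_normal(float * w, int n, int fan_in, int fan_out, std::mt19937 & rng) {
    const float scale = 1.0f / std::sqrt((float)(fan_in + fan_out));
    std::normal_distribution<float> dist(0.0f, 1.0f);
    for (int i = 0; i < n; ++i) {
        w[i] = scale * dist(rng);
    }
}
```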
@@ -148,8 +153,8 @@ struct llama_hparams_lora {
     uint32_t n_rot = 64;
     uint32_t n_lora = 64;

-    bool operator!=(const llama_hparams & other) const {
-        return memcmp(this, &other, sizeof(llama_hparams));
+    bool operator!=(const llama_hparams_lora & other) const {
+        return memcmp(this, &other, sizeof(llama_hparams_lora)) != 0;
     }
 };
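The original comparison had two bugs: it took a `llama_hparams` argument and measured `sizeof(llama_hparams)`, so the extra `n_lora` field was never compared, and it returned the raw `memcmp` result instead of an explicit boolean. A minimal illustration of the corrected idiom; the `my_params` type is a hypothetical stand-in:

```cpp
#include <cstring>
#include <cstdint>

struct my_params { // hypothetical stand-in for llama_hparams_lora
    uint32_t n_rot  = 64;
    uint32_t n_lora = 64;

    bool operator!=(const my_params & other) const {
        // memcmp returns 0 on equality; compare against 0 explicitly,
        // and use sizeof the full struct so every field (n_lora too) is checked.
        return std::memcmp(this, &other, sizeof(my_params)) != 0;
    }
};
```

The byte-wise comparison is safe here because the struct holds only uint32_t fields, so there are no padding bytes with indeterminate values.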
examples/main/main.cpp

@@ -331,6 +331,13 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd;

+    // do one empty run to warm up the model
+    {
+        const std::vector<llama_token> tmp = { llama_token_bos(), };
+        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+        llama_reset_timings(ctx);
+    }
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (embd.size() > 0) {
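The idea behind the warm-up: the first `llama_eval` call pays one-time costs (first-touch page faults on the mmapped weights, lazy buffer allocation), so evaluating a single BOS token up front and then calling `llama_reset_timings` keeps those costs out of the reported tokens-per-second statistics.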
examples/train-text-from-scratch/CMakeLists.txt (new file, +4)
@@ -0,0 +1,4 @@
+set(TARGET train-text-from-scratch)
+add_executable(${TARGET} train-text-from-scratch.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/train-text-from-scratch/README.md (new file, +22)
@@ -0,0 +1,22 @@
+# train-text-from-scratch
+
+Basic usage instructions:
+
+```bash
+# get training data
+wget https://github.com/brunoklein99/deep-learning-notes/blob/master/shakespeare.txt
+
+# train
+./bin/train-text-from-scratch \
+        --vocab-model ../models/ggml-vocab.bin \
+        --ctx 64 --embd 256 --head 8 --layer 16 \
+        --checkpoint-in chk-shakespeare-256x16.bin \
+        --checkpoint-out chk-shakespeare-256x16.bin \
+        --model-out ggml-shakespeare-256x16-f32.bin \
+        --train-data "shakespeare.txt" \
+        -t 6 -b 16 -n 32 --seed 1 --adam-iter 16 \
+        --print-details-interval 0 --predict 16 --use-flash
+
+# predict
+./bin/main -m ggml-shakespeare-256x16-f32.bin
+```
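One caveat on the README as committed: the `wget` URL points at the GitHub HTML page for shakespeare.txt rather than the raw file; substituting `raw.githubusercontent.com` for `github.com/.../blob` fetches the plain text.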
examples/train-text-from-scratch/train-text-from-scratch.cpp (new file, +3399)
File diff suppressed because it is too large.
ggml.h (127 lines changed)
@@ -296,6 +296,7 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_REPEAT,
+        GGML_OP_REPEAT_BACK,
         GGML_OP_ABS,
         GGML_OP_SGN,
         GGML_OP_NEG,
@@ -309,6 +310,7 @@ extern "C" {
         GGML_OP_RMS_NORM_BACK,

         GGML_OP_MUL_MAT,
+        GGML_OP_OUT_PROD,

         GGML_OP_SCALE,
         GGML_OP_SET,
@@ -324,6 +326,7 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
@@ -333,10 +336,14 @@ extern "C" {

         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
+        GGML_OP_FLASH_ATTN_BACK,

         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,

+        GGML_OP_CROSS_ENTROPY_LOSS,
+        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
         GGML_OP_COUNT,
     };

@@ -574,6 +581,11 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_add1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -645,6 +657,11 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_repeat_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -698,14 +715,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

-    // A: m rows, n columns
-    // B: p rows, n columns (i.e. we transpose it internally)
+    // A: n columns, m rows
+    // B: n columns, p rows (i.e. we transpose it internally)
     // result is m columns, p rows
     GGML_API struct ggml_tensor * ggml_mul_mat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    // A: m columns, n rows,
+    // B: p columns, n rows,
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_out_prod(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     //
     // operations on tensors without backpropagation
     //
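Since ggml stores `ne[0]` as the number of columns (the contiguous dimension), the corrected comments encode different shared dimensions: `ggml_mul_mat` contracts over `ne[0]`, while the new `ggml_out_prod` contracts over `ne[1]`, which is the transposed product needed by mul_mat's backward pass. A shape sketch under those assumptions (allocation boilerplate omitted, tensor values left unset):

```cpp
#include "ggml.h"

// Sketch: shape conventions of ggml_mul_mat vs ggml_out_prod.
void shapes_demo(struct ggml_context * ctx) {
    // ggml_mul_mat: A has n columns / m rows, B has n columns / p rows;
    // the shared dimension is ne[0] of both operands.
    struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // n=4 cols, m=3 rows
    struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 5); // n=4 cols, p=5 rows
    struct ggml_tensor * C = ggml_mul_mat(ctx, A, B);                      // m=3 cols, p=5 rows

    // ggml_out_prod: the shared dimension is ne[1] (rows) instead.
    struct ggml_tensor * D = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 4); // m=3 cols, n=4 rows
    struct ggml_tensor * E = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 5, 4); // p=5 cols, n=4 rows
    struct ggml_tensor * F = ggml_out_prod(ctx, D, E);                     // m=3 cols, p=5 rows
    (void) C; (void) F;
}
```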
@@ -916,6 +941,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
@@ -982,6 +1018,14 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);

+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * d,
+            bool                  masked);
+
     GGML_API struct ggml_tensor * ggml_flash_ff(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1005,6 +1049,19 @@ extern "C" {
             struct ggml_tensor  * b,
             ggml_binary_op_f32_t  fun);

+    // loss function
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c);
+
     //
     // automatic differentiation
     //
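A hedged sketch of how the new loss op slots into a training graph with the autodiff API of this era: build a scalar loss, derive the backward graph, then run both. Tensor names are illustrative; it assumes model weights were registered with `ggml_set_param` and that `logits`/`targets` have matching shapes:

```cpp
#include "ggml.h"

// Sketch: wiring ggml_cross_entropy_loss into an autodiff graph.
void train_step_sketch(struct ggml_context * ctx,
                       struct ggml_tensor  * logits,
                       struct ggml_tensor  * targets) {
    struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, targets);

    struct ggml_cgraph gf = ggml_build_forward(loss);
    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, /*keep=*/true);

    ggml_graph_compute(ctx, &gf); // forward pass: evaluates the scalar loss
    ggml_graph_reset  (&gf);      // zero the gradients
    ggml_graph_compute(ctx, &gb); // backward pass: fills parameter grads
}
```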
@@ -1099,6 +1156,8 @@ extern "C" {
         struct {
             int n_iter;

+            float sched; // schedule multiplier (fixed, decay or warmup)
+            float decay; // weight decay for AdamW, use 0.0f to disable
             float alpha; // learning rate
             float beta1;
             float beta2;
@@ -1123,6 +1182,49 @@ extern "C" {
         } lbfgs;
     };

+    struct ggml_opt_context {
+        struct ggml_context * ctx;
+        struct ggml_opt_params params;
+
+        int iter;
+        int64_t nx; // number of parameter elements
+
+        bool just_initialized;
+
+        struct {
+            struct ggml_tensor * x;  // view of the parameters
+            struct ggml_tensor * g1; // gradient
+            struct ggml_tensor * g2; // gradient squared
+            struct ggml_tensor * m;  // first moment
+            struct ggml_tensor * v;  // second moment
+            struct ggml_tensor * mh; // first moment hat
+            struct ggml_tensor * vh; // second moment hat
+            struct ggml_tensor * pf; // past function values
+            float fx_best;
+            float fx_prev;
+            int n_no_improvement;
+        } adam;
+
+        struct {
+            struct ggml_tensor * x;    // current parameters
+            struct ggml_tensor * xp;   // previous parameters
+            struct ggml_tensor * g;    // current gradient
+            struct ggml_tensor * gp;   // previous gradient
+            struct ggml_tensor * d;    // search direction
+            struct ggml_tensor * pf;   // past function values
+            struct ggml_tensor * lmal; // the L-BFGS memory alpha
+            struct ggml_tensor * lmys; // the L-BFGS memory ys
+            struct ggml_tensor * lms;  // the L-BFGS memory s
+            struct ggml_tensor * lmy;  // the L-BFGS memory y
+            float fx_best;
+            float step;
+            int j;
+            int k;
+            int end;
+            int n_no_improvement;
+        } lbfgs;
+    };
+
     GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);

     // optimize the function defined by the tensor f
@@ -1131,6 +1233,27 @@ extern "C" {
             struct ggml_opt_params params,
             struct ggml_tensor * f);

+    // initialize optimizer context
+    GGML_API void ggml_opt_init(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_opt_params params,
+            int64_t nx);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f);
+
+    // continue optimizing the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt_resume_g(
+            struct ggml_context * ctx,
+            struct ggml_opt_context * opt,
+            struct ggml_tensor * f,
+            struct ggml_cgraph * gf,
+            struct ggml_cgraph * gb);
+
     //
     // quantization
     //
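A hedged sketch of the intended call sequence for the new resumable optimizer API, assuming a loss tensor `f` built as in the earlier sketch; the iteration counts are illustrative. The point of `ggml_opt_context` is that Adam's moment buffers survive between calls, which is what checkpointed training needs:

```cpp
#include "ggml.h"

// Sketch: resumable Adam optimization with ggml_opt_init/ggml_opt_resume.
void optimize_sketch(struct ggml_context * ctx, struct ggml_tensor * f, int64_t n_params) {
    struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
    params.adam.n_iter = 16;

    // one-time init; `opt` keeps optimizer state across resume calls
    struct ggml_opt_context opt;
    ggml_opt_init(ctx, &opt, params, n_params);

    // each call continues from the previous state instead of restarting
    for (int epoch = 0; epoch < 4; ++epoch) {
        enum ggml_opt_result res = ggml_opt_resume(ctx, &opt, f);
        if (res != GGML_OPT_OK) {
            break;
        }
    }
}
```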
llama.cpp (25 lines changed)
@@ -1036,6 +1036,12 @@ static void llama_model_load_internal(
             case 40: model.type = e_model::MODEL_13B; break;
             case 60: model.type = e_model::MODEL_30B; break;
             case 80: model.type = e_model::MODEL_65B; break;
+            default:
+                {
+                    if (hparams.n_layer < 32) {
+                        model.type = e_model::MODEL_7B;
+                    }
+                } break;
         }

         hparams.n_ctx = n_ctx;
@@ -1200,6 +1206,7 @@ static void llama_model_load_internal(
             mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);

     (void) vram_scratch;
+    (void) n_batch;
 #ifdef GGML_USE_CUBLAS
     vram_scratch = n_batch * MB;
     ggml_cuda_set_scratch_size(vram_scratch);
@@ -1227,6 +1234,7 @@ static void llama_model_load_internal(
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
     }

+    (void) tensor_split;
 #if defined(GGML_USE_CUBLAS)
     {
         ggml_cuda_set_tensor_split(tensor_split);
@@ -2161,6 +2169,10 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
         return -log2f(candidate.p) > *mu;
     }));

+    if (candidates->size == 0) {
+        candidates->size = 1;
+    }
+
     // Normalize the probabilities of the remaining words
     llama_sample_softmax(ctx, candidates);
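Without this guard, the filter above can discard every candidate when `*mu` is small, leaving `llama_sample_softmax` to run on an empty set; clamping `size` to at least 1 keeps one token available to sample.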
@@ -3287,6 +3299,19 @@ int llama_n_embd(const struct llama_context * ctx) {
     return ctx->model.hparams.n_embd;
 }

+int llama_get_vocab(
+        const struct llama_context * ctx,
+        const char * * strings,
+        float * scores,
+        int capacity) {
+    int n = std::min(capacity, (int) ctx->vocab.id_to_token.size());
+    for (int i = 0; i<n; ++i) {
+        strings[i] = ctx->vocab.id_to_token[i].tok.c_str();
+        scores[i]  = ctx->vocab.id_to_token[i].score;
+    }
+    return n;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
llama.h (8 lines changed)
@@ -220,6 +220,14 @@ extern "C" {
     LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);

+    // Get the vocabulary as output parameters.
+    // Returns number of results.
+    LLAMA_API int llama_get_vocab(
+            const struct llama_context * ctx,
+            const char * * strings,
+            float * scores,
+            int capacity);
+
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
     // Can be mutated in order to change the probabilities of the next token
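A hedged usage sketch for the new accessor: the caller supplies output arrays sized to the vocabulary, and the returned count is capped at `capacity`. It assumes a context already created via the usual init path; error handling omitted:

```cpp
#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch: dump the first few vocab entries with llama_get_vocab.
void dump_vocab(struct llama_context * ctx) {
    const int n_vocab = llama_n_vocab(ctx);
    std::vector<const char *> strings(n_vocab);
    std::vector<float>        scores(n_vocab);

    const int n = llama_get_vocab(ctx, strings.data(), scores.data(), n_vocab);
    for (int i = 0; i < n && i < 10; ++i) {
        printf("%5d: '%s' (score %.3f)\n", i, strings[i], scores[i]);
    }
}
```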