activate threading in baby-llama-text
This commit is contained in:
parent
d9b5268728
commit
a703d7a85f
1 changed files with 10 additions and 3 deletions
|
@ -1199,10 +1199,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
|
||||
|
||||
printf("%s: tokenize training data\n", __func__);
|
||||
std::vector<llama_token> train_tokens;
|
||||
if (tokenize_file(lctx, fn_train, train_tokens) < 0) {
|
||||
fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, fn_train);
|
||||
}
|
||||
printf("%s: number of training tokens: %d\n", __func__, train_tokens.size());
|
||||
|
||||
struct my_llama_model model;
|
||||
model.hparams.n_vocab = llama_n_vocab(lctx);
|
||||
|
@ -1225,7 +1227,7 @@ int main(int argc, char ** argv) {
|
|||
model.ctx = ggml_init(lcparams);
|
||||
kv_self.ctx = model.ctx;
|
||||
|
||||
printf("init model\n");
|
||||
printf("%s: init model\n", __func__);
|
||||
init_model(&model);
|
||||
set_param_model(&model);
|
||||
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
|
||||
|
@ -1238,6 +1240,8 @@ int main(int argc, char ** argv) {
|
|||
int n_tokens = model.hparams.n_ctx;
|
||||
int n_vocab = model.hparams.n_vocab;
|
||||
|
||||
printf("%s: begin training\n", __func__);
|
||||
|
||||
for (int ex=0; ex<n_examples; ++ex) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ compute_size,
|
||||
|
@ -1254,7 +1258,7 @@ int main(int argc, char ** argv) {
|
|||
int n_past = 0;
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
gf.n_threads = 1;
|
||||
gf.n_threads = 4;
|
||||
|
||||
get_example_targets_batch(ctx0, train_tokens.data(), train_tokens.size(), ex, tokens_input, targets);
|
||||
|
||||
|
@ -1271,9 +1275,12 @@ int main(int argc, char ** argv) {
|
|||
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
|
||||
opt_params_adam.print_forward_graph = false;
|
||||
opt_params_adam.print_backward_graph = false;
|
||||
opt_params_adam.n_threads = gf.n_threads;
|
||||
opt_params_adam.adam.n_iter = 16;
|
||||
|
||||
opt_params_lbfgs.print_forward_graph = false;
|
||||
opt_params_lbfgs.print_backward_graph = false;
|
||||
opt_params_adam.adam.n_iter = 16;
|
||||
opt_params_lbfgs.n_threads = gf.n_threads;
|
||||
opt_params_lbfgs.lbfgs.n_iter = 16;
|
||||
ggml_opt(ctx0, opt_params_adam, e);
|
||||
// ggml_opt(ctx0, opt_params_lbfgs, e);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue