activate threading in baby-llama-text

This commit is contained in:
xaedes 2023-05-14 20:58:43 +02:00
parent d9b5268728
commit a703d7a85f
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1199,10 +1199,12 @@ int main(int argc, char ** argv) {
struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
printf("%s: tokenize training data\n", __func__);
std::vector<llama_token> train_tokens;
if (tokenize_file(lctx, fn_train, train_tokens) < 0) {
fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, fn_train);
}
printf("%s: number of training tokens: %d\n", __func__, train_tokens.size());
struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx);
@ -1225,7 +1227,7 @@ int main(int argc, char ** argv) {
model.ctx = ggml_init(lcparams);
kv_self.ctx = model.ctx;
printf("init model\n");
printf("%s: init model\n", __func__);
init_model(&model);
set_param_model(&model);
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
@ -1238,6 +1240,8 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
printf("%s: begin training\n", __func__);
for (int ex=0; ex<n_examples; ++ex) {
struct ggml_init_params params = {
/*.mem_size =*/ compute_size,
@ -1254,7 +1258,7 @@ int main(int argc, char ** argv) {
int n_past = 0;
ggml_cgraph gf = {};
gf.n_threads = 1;
gf.n_threads = 4;
get_example_targets_batch(ctx0, train_tokens.data(), train_tokens.size(), ex, tokens_input, targets);
@ -1271,9 +1275,12 @@ int main(int argc, char ** argv) {
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = gf.n_threads;
opt_params_adam.adam.n_iter = 16;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_adam.adam.n_iter = 16;
opt_params_lbfgs.n_threads = gf.n_threads;
opt_params_lbfgs.lbfgs.n_iter = 16;
ggml_opt(ctx0, opt_params_adam, e);
// ggml_opt(ctx0, opt_params_lbfgs, e);