diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 70fcdc5de..76e6ace64 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "ggml-alloc.h"
 #include "llama.h"
 #include
 #include
@@ -1342,6 +1343,291 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     return inpL;
 }
 
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static size_t hash_find(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // visited all hash table entries -> not found
+            return GGML_GRAPH_HASHTABLE_SIZE;
+        }
+    }
+    return i;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+    size_t i = hash_find(hash_table, p);
+
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    GGML_ASSERT(hash_table[i] == NULL);
+    hash_table[i] = p;
+    return false;
+}
+
+static bool hash_contains(void * hash_table[], void * p) {
+    size_t i = hash_find(hash_table, p);
+    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
+
+struct hash_map {
+    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
+
+struct hash_map * new_hash_map(struct ggml_context * ctx, struct ggml_tensor ** out_buf) {
+    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, HASH_MAP_SIZE);
+    if (out_buf) {
+        *out_buf = buf;
+    }
+    struct hash_map * result = (struct hash_map *) ((char *) buf->data);
+    *result = (struct hash_map) {
+        /*.keys =*/ { NULL },
+        /*.vals =*/ { NULL },
+    };
+    for (int i = 0; i < GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+        result->keys[i] = NULL;
+        result->vals[i] = NULL;
+    }
+    return result;
+};
+
+struct ggml_tensor * ggml_recompute_graph_node(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * graph,
+        struct hash_map     * replacements,
+        struct ggml_tensor  * node) {
+
+    if (node == NULL) {
+        return NULL;
+    }
+
+    if (node->is_param) {
+        return node;
+    }
+
+    if (!hash_contains(graph->visited_hash_table, node)) {
+        return node;
+    }
+
+    size_t i = hash_find(replacements->keys, node);
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+    if (replacements->keys[i] == node) {
+        return replacements->vals[i];
+    }
+
+    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+
+    // insert clone into replacements
+    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+    replacements->keys[i] = node;
+    replacements->vals[i] = clone;
+
+    clone->op       = node->op;
+    clone->grad     = node->grad;
+    clone->is_param = node->is_param;
+    clone->extra    = node->extra;
+    for (int k = 0; k < GGML_MAX_SRC; ++k) {
+        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+    }
+
+    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+    GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+    memcpy(clone->name, node->name, sizeof(node->name));
+
+    return clone;
+};
+
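+// Build a gradient-checkpointed backward graph:
+// gb_tmp receives the plain backward expansion of gf; its backward-only nodes are then
+// rewritten so that every forward tensor that is not a parameter and not listed in
+// checkpoints[] is recomputed via ggml_recompute_graph_node instead of being kept,
+// and the rewritten nodes are appended to gb.
+// Rough call sketch (the `loss` name is illustrative):
+//
+//   ggml_build_forward_expand(gf, loss);
+//   ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp,
+//       checkpoints.data(), (int) checkpoints.size());
+//   // gb now contains forward, recomputation and backward nodes
+//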
+void ggml_build_backward_gradient_checkpointing(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * gf,
+        struct ggml_cgraph  * gb,
+        struct ggml_cgraph  * gb_tmp,
+        struct ggml_tensor ** checkpoints,
+        int                   n_checkpoints) {
+    *gb_tmp = *gf;
+    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+    if (n_checkpoints <= 0) {
+        *gb = *gb_tmp;
+        return;
+    }
+
+    struct hash_map * replacements = new_hash_map(ctx, NULL);
+
+    // insert checkpoints in replacements
+    for (int i = 0; i < n_checkpoints; ++i) {
+        size_t k = hash_find(replacements->keys, checkpoints[i]);
+        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+        replacements->keys[k] = checkpoints[i];
+        replacements->vals[k] = checkpoints[i];
+    }
+
+    *gb = *gf;
+    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+    // by recomputing them from checkpoints
+    for (int i = gf->n_nodes; i < gb_tmp->n_nodes; ++i) {
+        struct ggml_tensor * node = gb_tmp->nodes[i];
+        for (int k = 0; k < GGML_MAX_SRC; ++k) {
+            // insert new tensors recomputing src, reusing already made replacements,
+            // remember replacements: remember new tensors with mapping from corresponding gf nodes
+            // recurse for input tensors,
+            // unless (i.e. terminating when) input tensors are checkpoints
+            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+        }
+        // insert rewritten backward node with replacements made into resulting backward graph gb
+        ggml_build_forward_expand(gb, node);
+    }
+}
+
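+// Build the forward graph gf for one training batch and the matching backward graph gb
+// (checkpointed through gb_tmp when enable_checkpointing is set). Stores the logits
+// tensor in *logits and returns the scalar cross entropy loss tensor.
+// Illustrative use (caller-side variable names are examples only):
+//
+//   struct ggml_tensor * logits = NULL;
+//   struct ggml_tensor * loss   = llama_build_train_graphs(
+//       &model, alloc, ctx0, gf, gb, gb_tmp, &logits, tokens_input, targets,
+//       n_tokens, n_batch, use_flash_attn, use_checkpointing);
+//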
+struct ggml_tensor * llama_build_train_graphs(
+        struct my_llama_model * model,
+        struct ggml_allocr    * alloc,
+        struct ggml_context   * ctx,
+        struct ggml_cgraph    * gf,
+        struct ggml_cgraph    * gb,
+        struct ggml_cgraph    * gb_tmp,
+        struct ggml_tensor   ** logits,
+        struct ggml_tensor    * tokens_input,
+        struct ggml_tensor    * targets,
+        const int               n_tokens,
+        const int               n_batch,
+        const bool              enable_flash_attn,
+        const bool              enable_checkpointing) {
+
+    ggml_set_scratch(ctx, { 0, 0, nullptr, });
+    const int n_past = 0;
+    const int N = n_tokens;
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_vocab = hparams.n_vocab;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+    const int n_ff    = get_n_ff(&hparams);
+    const int rope_mode = 0;
+
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+    struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch);  assert_shape_1d(t00, N*n_batch);
+    struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); assert_shape_2d(t01, n_embd, N*n_batch);
+
+    struct ggml_tensor * cur = t01;
+
+    std::vector<struct ggml_tensor *> checkpoints;
+    checkpoints.push_back(cur);
+
+    struct ggml_tensor * kv_scale = NULL;
+    if (!enable_flash_attn) {
+        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct my_llama_layer & layer = model->layers[il];
+        struct ggml_tensor * t02 = ggml_rms_norm     (ctx, cur, rms_norm_eps);                      assert_shape_2d(t02, n_embd, N*n_batch);
+        struct ggml_tensor * t03 = ggml_repeat       (ctx, layer.attention_norm, t02);              assert_shape_2d(t03, n_embd, N*n_batch);
+        struct ggml_tensor * t04 = ggml_mul          (ctx, t02, t03);                               assert_shape_2d(t04, n_embd, N*n_batch);
+        struct ggml_tensor * t05 = ggml_mul_mat      (ctx, layer.wq, t04);                          assert_shape_2d(t05, n_embd, N*n_batch);
+        struct ggml_tensor * t06 = ggml_reshape_4d   (ctx, t05, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx);   assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t08 = ggml_mul_mat      (ctx, layer.wk, t04);                          assert_shape_2d(t08, n_embd, N*n_batch);
+        struct ggml_tensor * t09 = ggml_reshape_4d   (ctx, t08, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t10 = ggml_rope_inplace (ctx, t09, n_past, n_rot, rope_mode, n_ctx);   assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t11 = ggml_mul_mat      (ctx, t04, layer.wv);                          assert_shape_2d(t11, N*n_batch, n_embd);
+        struct ggml_tensor * t12 = ggml_reshape_4d   (ctx, t11, N, n_batch, n_embd/n_head, n_head); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
+        struct ggml_tensor * t13 = ggml_permute      (ctx, t07, 0, 2, 1, 3);                        assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
+        struct ggml_tensor * t14 = ggml_permute      (ctx, t10, 0, 2, 1, 3);                        assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
+        struct ggml_tensor * t15 = ggml_permute      (ctx, t12, 0, 3, 1, 2);                        assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
+        struct ggml_tensor * t16;
+        if (enable_flash_attn) {
+            t16 = ggml_flash_attn(ctx, t13, t14, t15, true);                                        assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+        } else {
+            struct ggml_tensor * t16_0 = ggml_mul_mat              (ctx, t14, t13);                 assert_shape_4d(t16_0, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_1 = ggml_scale_inplace        (ctx, t16_0, kv_scale);          assert_shape_4d(t16_1, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past);            assert_shape_4d(t16_2, N, N, n_head, n_batch);
+            struct ggml_tensor * t16_3 = ggml_soft_max_inplace     (ctx, t16_2);                    assert_shape_4d(t16_3, N, N, n_head, n_batch);
+            t16 = ggml_mul_mat(ctx, t15, t16_3);                                                    assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+        }
+        struct ggml_tensor * t17 = ggml_permute      (ctx, t16, 0, 2, 1, 3);                        assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t18 = ggml_cont         (ctx, t17);                                    assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
+        struct ggml_tensor * t19 = ggml_reshape_2d   (ctx, t18, n_embd, N*n_batch);                 assert_shape_2d(t19, n_embd, N*n_batch);
+        struct ggml_tensor * t20 = ggml_mul_mat      (ctx, layer.wo, t19);                          assert_shape_2d(t20, n_embd, N*n_batch);
+        struct ggml_tensor * t21 = ggml_add          (ctx, t20, cur);                               assert_shape_2d(t21, n_embd, N*n_batch);
+        struct ggml_tensor * t22 = ggml_rms_norm     (ctx, t21, rms_norm_eps);                      assert_shape_2d(t22, n_embd, N*n_batch);
+        struct ggml_tensor * t23 = ggml_repeat       (ctx, layer.ffn_norm, t22);                    assert_shape_2d(t23, n_embd, N*n_batch);
+        struct ggml_tensor * t24 = ggml_mul          (ctx, t23, t22);                               assert_shape_2d(t24, n_embd, N*n_batch);
+        struct ggml_tensor * t25 = ggml_mul_mat      (ctx, layer.w3, t24);                          assert_shape_2d(t25, n_ff, N*n_batch);
+        struct ggml_tensor * t26 = ggml_mul_mat      (ctx, layer.w1, t24);                          assert_shape_2d(t26, n_ff, N*n_batch);
+        struct ggml_tensor * t27 = ggml_silu         (ctx, t26);                                    assert_shape_2d(t27, n_ff, N*n_batch);
+        struct ggml_tensor * t28 = ggml_mul          (ctx, t27, t25);                               assert_shape_2d(t28, n_ff, N*n_batch);
+        struct ggml_tensor * t29 = ggml_mul_mat      (ctx, layer.w2, t28);                          assert_shape_2d(t29, n_embd, N*n_batch);
+        struct ggml_tensor * t30 = ggml_add          (ctx, t21, t29);                               assert_shape_2d(t30, n_embd, N*n_batch);
+        cur = t30;
+        checkpoints.push_back(cur);
+    }
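+    // final norm, output projection and cross entropy loss; these tensors are also
+    // pushed as checkpoints below so that backward recomputation terminates at them
+    // (checkpoints map to themselves in ggml_recompute_graph_node)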
+    struct ggml_tensor * t31 = ggml_rms_norm          (ctx, cur, rms_norm_eps);        assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t32 = ggml_repeat            (ctx, model->norm, t31);         assert_shape_2d(t32, n_embd, N*n_batch);
+    struct ggml_tensor * t33 = ggml_mul               (ctx, t32, t31);                 assert_shape_2d(t33, n_embd, N*n_batch);
+    struct ggml_tensor * t34 = ggml_mul_mat           (ctx, model->output, t33);       assert_shape_2d(t34, n_vocab, N*n_batch);
+    struct ggml_tensor * t35 = ggml_reshape_3d        (ctx, t34, n_vocab, N, n_batch); assert_shape_3d(t35, n_vocab, N, n_batch);
+    struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets);             assert_shape_1d(t36, 1);
+
+    checkpoints.push_back(t31);
+    checkpoints.push_back(t32);
+    checkpoints.push_back(t33);
+    checkpoints.push_back(t34);
+    checkpoints.push_back(t35);
+    checkpoints.push_back(t36);
+
+    ggml_build_forward_expand(gf, t36);
+
+    if (enable_checkpointing) {
+        ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
+    } else {
+        *gb = *gf;
+        ggml_build_backward_expand(ctx, gf, gb, true);
+    }
+
+    if (alloc) {
+        // make sure t35 and t36 are not reallocated by inserting a new temporary node that depends on them
+        struct ggml_tensor * dep = ggml_scale_inplace(ctx, t35, t36);
+        int n_nodes_before = gb->n_nodes;
+        ggml_build_forward_expand(gb, dep);
+
+        int n_nodes_after = gb->n_nodes;
+        GGML_ASSERT(n_nodes_after == n_nodes_before + 1);
+
+        ggml_allocr_reset(alloc);
+        ggml_allocr_alloc_graph(alloc, gb);
+
+        // remove the additional node that was inserted
+        gb->nodes[n_nodes_after-1] = NULL;
+        gb->n_nodes = n_nodes_before;
+    }
+
+    *logits = t35;
+    return t36;
+}
+
+
 // expand the graph nodes without creating leafs.
 struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
     // check if already visited