add train function using automatic gradient checkpointing backward pass and allocator

xaedes 2023-08-06 23:07:57 +02:00
parent d43af4b543
commit 2bf422eafd

@@ -1,4 +1,5 @@
#include "ggml.h"
#include "ggml-alloc.h"
#include "llama.h"
#include <unordered_map>
#include <vector>
@@ -1342,6 +1343,291 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
return inpL;
}
static size_t hash(void * p) {
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
}
static size_t hash_find(void * hash_table[], void * p) {
size_t h = hash(p);
// linear probing
size_t i = h;
while (hash_table[i] != NULL && hash_table[i] != p) {
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
if (i == h) {
// visited all hash table entries -> not found
return GGML_GRAPH_HASHTABLE_SIZE;
}
}
return i;
}
static bool hash_insert(void * hash_table[], void * p) {
size_t h = hash(p);
size_t i = hash_find(hash_table, p);
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
if (hash_table[i] == p) {
return true;
}
// insert
GGML_ASSERT(hash_table[i] == NULL);
hash_table[i] = p;
return false;
}
static bool hash_contains(void * hash_table[], void * p) {
size_t i = hash_find(hash_table, p);
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
}
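
As a usage note: hash_insert doubles as a membership test, returning true when the pointer was already present and false when it was newly inserted. A minimal sketch of that pattern (count_unique_tensors is a hypothetical helper, not part of this commit, assumed to be compiled below the functions above):

static int count_unique_tensors(struct ggml_tensor ** tensors, int n) {
    // zero-initialized probe table of the size the helpers above expect
    std::vector<void *> table(GGML_GRAPH_HASHTABLE_SIZE, nullptr);
    int n_unique = 0;
    for (int i = 0; i < n; ++i) {
        // hash_insert returns false when the pointer was not in the table before
        if (!hash_insert(table.data(), tensors[i])) {
            ++n_unique;
        }
    }
    return n_unique;
}
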
struct hash_map {
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
};
static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
struct hash_map * new_hash_map(struct ggml_context * ctx, struct ggml_tensor * * out_buf) {
struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, HASH_MAP_SIZE);
if (out_buf) {
* out_buf = buf;
}
struct hash_map * result = (struct hash_map *) ((char *) buf->data);
*result = (struct hash_map) {
/*.keys =*/ { NULL },
/*.vals =*/ { NULL },
};
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
result->keys[i] = NULL;
result->vals[i] = NULL;
}
return result;
};
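
The map is placed inside a GGML_TYPE_I8 tensor rather than malloc'd, so its storage comes out of the ggml_context and, via out_buf, can be tracked like any other tensor. A minimal allocation sketch (hash_map_alloc_sketch and the sizes are hypothetical, not part of this commit):

static void hash_map_alloc_sketch() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ HASH_MAP_SIZE + 1024*1024,  // hypothetical: the map plus some headroom
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * map_buf = NULL;
    struct hash_map     * map = new_hash_map(ctx, &map_buf);
    GGML_ASSERT(map == (struct hash_map *) map_buf->data);
    // the map lives in map_buf->data and is released together with the context
    ggml_free(ctx);
}
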
struct ggml_tensor * ggml_recompute_graph_node(
struct ggml_context * ctx,
struct ggml_cgraph * graph,
struct hash_map * replacements,
struct ggml_tensor * node) {
if (node == NULL) {
return NULL;
}
if (node->is_param) {
return node;
}
if (!hash_contains(graph->visited_hash_table, node)) {
return node;
}
size_t i = hash_find(replacements->keys, node);
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
if (replacements->keys[i] == node) {
return (struct ggml_tensor *) replacements->vals[i];
}
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
// insert clone into replacements
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
replacements->keys[i] = node;
replacements->vals[i] = clone;
clone->op = node->op;
clone->grad = node->grad;
clone->is_param = node->is_param;
clone->extra = node->extra;
for (int k = 0; k < GGML_MAX_SRC; ++k) {
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
}
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
memcpy(clone->name, node->name, sizeof(node->name));
return clone;
};
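
ggml_recompute_graph_node clones a node and, recursively, its sources, stopping at parameters, at tensors that are not part of the forward graph, and at tensors already present in replacements; because the checkpoints are pre-seeded as mapping to themselves (see the next function), recomputation stops exactly at the checkpoints. A toy sketch of that behaviour (recompute_sketch, the chain and the memory size are hypothetical, not part of this commit):

static void recompute_sketch() {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // forward chain: x (param) -> a -> b (checkpoint) -> c -> d
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_param(ctx, x);
    struct ggml_tensor * a = ggml_sqr  (ctx, x);
    struct ggml_tensor * b = ggml_scale(ctx, a, ggml_new_f32(ctx, 2.0f));
    struct ggml_tensor * c = ggml_sqr  (ctx, b);
    struct ggml_tensor * d = ggml_sum  (ctx, c);

    static struct ggml_cgraph gf;  // static keeps the large cgraph struct off the stack
    gf = ggml_build_forward(d);

    // seed the checkpoint: b maps to itself, so the recursion terminates there
    struct hash_map * replacements = new_hash_map(ctx, NULL);
    size_t k = hash_find(replacements->keys, b);
    replacements->keys[k] = b;
    replacements->vals[k] = b;

    struct ggml_tensor * d_re = ggml_recompute_graph_node(ctx, &gf, replacements, d);
    GGML_ASSERT(d_re != d);                  // d and c were cloned for recomputation
    GGML_ASSERT(d_re->src[0]->src[0] == b);  // ... but the walk ends at the checkpoint b
    ggml_free(ctx);
}
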
void ggml_build_backward_gradient_checkpointing(
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
struct ggml_cgraph * gb_tmp,
struct ggml_tensor * * checkpoints,
int n_checkpoints) {
*gb_tmp = *gf;
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
if (n_checkpoints <= 0) {
*gb = *gb_tmp;
return;
}
struct hash_map * replacements = new_hash_map(ctx, NULL);
// insert checkpoints in replacements
for (int i = 0; i < n_checkpoints; ++i) {
size_t k = hash_find(replacements->keys, checkpoints[i]);
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
replacements->keys[k] = checkpoints[i];
replacements->vals[k] = checkpoints[i];
}
*gb = *gf;
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
// by recomputing them from checkpoints
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
struct ggml_tensor * node = gb_tmp->nodes[i];
for (int k = 0; k < GGML_MAX_SRC; ++k) {
// insert new tensors recomputing src, reusing already made replacements,
// remember replacements: remember new tensors with mapping from corresponding gf nodes
// recurse for input tensors,
// unless (i.e. terminating when) input tensors are checkpoints
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
}
// insert rewritten backward node with replacements made into resulting backward graph gb
ggml_build_forward_expand(gb, node);
}
}
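
ggml_build_backward_gradient_checkpointing first builds a plain backward pass into the scratch graph gb_tmp, then rewrites each backward node so that any forward activation it references is recomputed from the given checkpoints, and appends the rewritten nodes to gb, which starts out as a copy of gf. A minimal call sketch (checkpointing_sketch, the chain and the memory size are hypothetical, not part of this commit):

static void checkpointing_sketch() {
    struct ggml_init_params params = { /*.mem_size =*/ 32*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_param(ctx, x);
    struct ggml_tensor * h    = ggml_sqr(ctx, x);                 // activation chosen as checkpoint
    struct ggml_tensor * loss = ggml_sum(ctx, ggml_sqr(ctx, h));

    static struct ggml_cgraph gf, gb, gb_tmp;  // static keeps the large cgraph structs off the stack
    gf = ggml_build_forward(loss);

    struct ggml_tensor * checkpoints[] = { h };
    ggml_build_backward_gradient_checkpointing(ctx, &gf, &gb, &gb_tmp, checkpoints, 1);

    // gb now holds the forward nodes, recomputation nodes for the tensors between
    // checkpoints, and the rewritten backward nodes that consume them
    GGML_ASSERT(gb.n_nodes > gf.n_nodes);
    ggml_free(ctx);
}
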
struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model,
struct ggml_allocr * alloc,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
struct ggml_cgraph * gb_tmp,
struct ggml_tensor * * logits,
struct ggml_tensor * tokens_input,
struct ggml_tensor * targets,
const int n_tokens,
const int n_batch,
const bool enable_flash_attn,
const bool enable_checkpointing) {
ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
const int N = n_tokens;
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
const int rope_mode = 0;
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); assert_shape_2d(t01, n_embd, N*n_batch);
struct ggml_tensor * cur = t01;
std::vector<struct ggml_tensor *> checkpoints;
checkpoints.push_back(cur);
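// attention scale 1/sqrt(n_embd/n_head); the fused flash-attn path applies this scaling
// internally, so the explicit kv_scale tensor is only needed for the unfused path below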
struct ggml_tensor * kv_scale = NULL;
if (!enable_flash_attn) {
kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
}
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); assert_shape_2d(t02, n_embd, N*n_batch);
struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); assert_shape_2d(t03, n_embd, N*n_batch);
struct ggml_tensor * t04 = ggml_mul (ctx, t02, t03); assert_shape_2d(t04, n_embd, N*n_batch);
struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); assert_shape_2d(t05, n_embd, N*n_batch);
struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); assert_shape_2d(t08, n_embd, N*n_batch);
struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t10 = ggml_rope_inplace (ctx, t09, n_past, n_rot, rope_mode, n_ctx); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); assert_shape_2d(t11, N*n_batch, n_embd);
struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
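// attention: either the fused flash-attn kernel or the explicit scale -> causal mask -> softmax ->
// weighted sum over values; both paths produce a [n_embd/n_head, N, n_head, n_batch] tensor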
struct ggml_tensor * t16;
if (enable_flash_attn) {
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
} else {
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); assert_shape_4d(t16_0, N, N, n_head, n_batch);
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); assert_shape_4d(t16_1, N, N, n_head, n_batch);
struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); assert_shape_4d(t16_2, N, N, n_head, n_batch);
struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); assert_shape_4d(t16_3, N, N, n_head, n_batch);
t16 = ggml_mul_mat(ctx, t15, t16_3); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
}
struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t18 = ggml_cont (ctx, t17); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); assert_shape_2d(t19, n_embd, N*n_batch);
struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); assert_shape_2d(t20, n_embd, N*n_batch);
struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); assert_shape_2d(t21, n_embd, N*n_batch);
struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, rms_norm_eps); assert_shape_2d(t22, n_embd, N*n_batch);
struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); assert_shape_2d(t23, n_embd, N*n_batch);
struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); assert_shape_2d(t24, n_embd, N*n_batch);
struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); assert_shape_2d(t25, n_ff, N*n_batch);
struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); assert_shape_2d(t26, n_ff, N*n_batch);
struct ggml_tensor * t27 = ggml_silu (ctx, t26); assert_shape_2d(t27, n_ff, N*n_batch);
struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); assert_shape_2d(t28, n_ff, N*n_batch);
struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); assert_shape_2d(t29, n_embd, N*n_batch);
struct ggml_tensor * t30 = ggml_add (ctx, t21, t29); assert_shape_2d(t30, n_embd, N*n_batch);
cur = t30;
checkpoints.push_back(cur);
}
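// final RMS norm, projection to vocabulary logits and cross-entropy loss against the targets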
struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); assert_shape_2d(t33, n_embd, N*n_batch);
struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); assert_shape_2d(t34, n_vocab, N*n_batch);
struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); assert_shape_3d(t35, n_vocab, N, n_batch);
struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); assert_shape_1d(t36, 1);
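// the output-head tensors are also registered as checkpoints so that the backward rewrite
// keeps them instead of scheduling them for recomputation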
checkpoints.push_back(t31);
checkpoints.push_back(t32);
checkpoints.push_back(t33);
checkpoints.push_back(t34);
checkpoints.push_back(t35);
checkpoints.push_back(t36);
ggml_build_forward_expand(gf, t36);
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
*gb = *gf;
ggml_build_backward_expand(ctx, gf, gb, true);
}
if (alloc) {
// make sure t35 and t36 are not reallocated by inserting new temporary node depending on them
struct ggml_tensor * dep = ggml_scale_inplace(ctx, t35, t36);
int n_nodes_before = gb->n_nodes;
ggml_build_forward_expand(gb, dep);
int n_nodes_after = gb->n_nodes;
GGML_ASSERT(n_nodes_after == n_nodes_before + 1);
ggml_allocr_reset(alloc);
ggml_allocr_alloc_graph(alloc, gb);
// remove the additional node that was inserted
gb->nodes[n_nodes_after-1] = NULL;
gb->n_nodes = n_nodes_before;
}
*logits = t35;
return t36;
}
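
A condensed caller sketch (train_step_sketch and all sizes are hypothetical and not part of this commit; the surrounding training code additionally has to set up the allocator, the model and the input tensors, and to drive the optimizer — only the call shape is shown here):

static float train_step_sketch(
        struct my_llama_model * model,
        struct ggml_context   * ctx0,          // compute context created with no_alloc = true
        struct ggml_allocr    * alloc,         // allocator over a dedicated compute buffer (not a measure allocator)
        struct ggml_tensor    * tokens_input,  // I32 [n_tokens, n_batch], already allocated and filled
        struct ggml_tensor    * targets,       // F32 [n_vocab, n_tokens, n_batch], already allocated and filled
        int n_tokens, int n_batch, int n_threads) {
    // static keeps the large cgraph structs off the stack; reset them before reuse
    static struct ggml_cgraph s_gf, s_gb, s_gb_tmp;
    s_gf = {}; s_gb = {}; s_gb_tmp = {};

    struct ggml_tensor * logits = NULL;
    struct ggml_tensor * loss   = llama_build_train_graphs(
        model, alloc, ctx0, &s_gf, &s_gb, &s_gb_tmp,
        &logits, tokens_input, targets,
        n_tokens, n_batch,
        /*enable_flash_attn    =*/ false,
        /*enable_checkpointing =*/ true);
    (void) logits; // would feed sampling / evaluation

    // gb contains the forward nodes followed by the recomputation and backward nodes,
    // so one compute pass yields both the loss and the parameter gradients
    std::vector<uint8_t> work;
    struct ggml_cplan plan = ggml_graph_plan(&s_gb, n_threads);
    work.resize(plan.work_size);
    plan.work_data = work.data();
    ggml_graph_compute(&s_gb, &plan);
    return ggml_get_f32_1d(loss, 0);
}
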
// expand the graph nodes without creating leafs.
struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
// check if already visited