move gradient checkpointing code into ggml, new API function:
// build gradient checkpointing backward graph gb for gf using provided checkpoints // gb_tmp will contain original backward graph with rewritten backward process nodes, // but without the second forward pass nodes. GGML_API void ggml_build_backward_gradient_checkpointing( struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, struct ggml_cgraph * gb_tmp, struct ggml_tensor * * checkpoints, int n_checkpoints);
This commit is contained in:
parent
2392b6725b
commit
d487e0531f
4 changed files with 154 additions and 387 deletions
|
@ -623,179 +623,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
|
||||||
GGML_ASSERT(tensor->ne[3] == ne3);
|
GGML_ASSERT(tensor->ne[3] == ne3);
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t hash(void * p) {
|
|
||||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t hash_find(void * hash_table[], void * p) {
|
|
||||||
size_t h = hash(p);
|
|
||||||
|
|
||||||
// linear probing
|
|
||||||
size_t i = h;
|
|
||||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
|
||||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
if (i == h) {
|
|
||||||
// visited all hash table entries -> not found
|
|
||||||
return GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool hash_insert(void * hash_table[], void * p) {
|
|
||||||
size_t i = hash_find(hash_table, p);
|
|
||||||
|
|
||||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
|
|
||||||
if (hash_table[i] == p) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert
|
|
||||||
GGML_ASSERT(hash_table[i] == NULL);
|
|
||||||
hash_table[i] = p;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool hash_contains(void * hash_table[], void * p) {
|
|
||||||
size_t i = hash_find(hash_table, p);
|
|
||||||
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
|
||||||
}
|
|
||||||
|
|
||||||
// open-addressed pointer -> pointer map used to memoize tensor replacements;
// keys[i] == NULL marks an empty slot, vals[i] is the value stored for keys[i]
struct hash_map {
    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
};
|
|
||||||
|
|
||||||
struct hash_map * new_hash_map() {
|
|
||||||
struct hash_map * result = new struct hash_map;
|
|
||||||
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
|
||||||
result->keys[i] = NULL;
|
|
||||||
result->vals[i] = NULL;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
// release a map previously allocated with new_hash_map()
void free_hash_map(struct hash_map * map) {
    delete map;
}
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_recompute_graph_node(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_cgraph * graph,
|
|
||||||
struct hash_map * replacements,
|
|
||||||
struct ggml_tensor * node) {
|
|
||||||
|
|
||||||
if (node == NULL) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->is_param) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!hash_contains(graph->visited_hash_table, node)) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count_children = 0;
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
if (node->src[k]) {
|
|
||||||
++count_children;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (count_children == 0) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t i = hash_find(replacements->keys, node);
|
|
||||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
if (replacements->keys[i] == node) {
|
|
||||||
return (struct ggml_tensor *) replacements->vals[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
|
||||||
|
|
||||||
// insert clone into replacements
|
|
||||||
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
|
||||||
replacements->keys[i] = node;
|
|
||||||
replacements->vals[i] = clone;
|
|
||||||
|
|
||||||
clone->op = node->op;
|
|
||||||
clone->grad = node->grad;
|
|
||||||
clone->is_param = node->is_param;
|
|
||||||
clone->extra = node->extra;
|
|
||||||
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
|
||||||
clone->nb[k] = node->nb[k];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
|
||||||
}
|
|
||||||
if (node->view_src != NULL) {
|
|
||||||
// GGML_ASSERT(node->view_src->data != NULL);
|
|
||||||
clone->data = (node->view_src->data == NULL)
|
|
||||||
? NULL // view_src not yet allocated
|
|
||||||
: (char *) node->view_src->data // view_src already allocated
|
|
||||||
+ node->view_offs;
|
|
||||||
clone->view_src = node->view_src;
|
|
||||||
clone->view_offs = node->view_offs;
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
|
||||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
|
||||||
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
|
||||||
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
|
||||||
|
|
||||||
return clone;
|
|
||||||
};
|
|
||||||
|
|
||||||
void ggml_build_backward_gradient_checkpointing(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_cgraph * gf,
|
|
||||||
struct ggml_cgraph * gb,
|
|
||||||
struct ggml_cgraph * gb_tmp,
|
|
||||||
struct ggml_tensor * * checkpoints,
|
|
||||||
int n_checkpoints) {
|
|
||||||
*gb_tmp = *gf;
|
|
||||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
|
||||||
|
|
||||||
if (n_checkpoints <= 0) {
|
|
||||||
*gb = *gb_tmp;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct hash_map * replacements = new_hash_map();
|
|
||||||
|
|
||||||
// insert checkpoints in replacements
|
|
||||||
for (int i = 0; i < n_checkpoints; ++i) {
|
|
||||||
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
|
||||||
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
|
||||||
replacements->keys[k] = checkpoints[i];
|
|
||||||
replacements->vals[k] = checkpoints[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
*gb = *gf;
|
|
||||||
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
|
||||||
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
|
||||||
// by recomputing them from checkpoints
|
|
||||||
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
|
||||||
struct ggml_tensor * node = gb_tmp->nodes[i];
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
// insert new tensors recomputing src, reusing already made replacements,
|
|
||||||
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
|
||||||
// recurse for input tensors,
|
|
||||||
// unless (i.e. terminating when) input tensors are checkpoints
|
|
||||||
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
|
||||||
}
|
|
||||||
// insert rewritten backward node with replacements made into resulting backward graph gb
|
|
||||||
ggml_build_forward_expand(gb, node);
|
|
||||||
}
|
|
||||||
|
|
||||||
free_hash_map(replacements);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * llama_build_lora_finetune_graphs(
|
struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||||
struct my_llama_model * model,
|
struct my_llama_model * model,
|
||||||
struct my_llama_lora * lora,
|
struct my_llama_lora * lora,
|
||||||
|
|
|
@ -451,204 +451,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
|
||||||
GGML_ASSERT(tensor->ne[3] == ne3);
|
GGML_ASSERT(tensor->ne[3] == ne3);
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t hash(void * p) {
|
|
||||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t hash_find(void * hash_table[], void * p) {
|
|
||||||
size_t h = hash(p);
|
|
||||||
|
|
||||||
// linear probing
|
|
||||||
size_t i = h;
|
|
||||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
|
||||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
if (i == h) {
|
|
||||||
// visited all hash table entries -> not found
|
|
||||||
return GGML_GRAPH_HASHTABLE_SIZE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool hash_insert(void * hash_table[], void * p) {
|
|
||||||
//size_t h = hash(p);
|
|
||||||
size_t i = hash_find(hash_table, p);
|
|
||||||
|
|
||||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
|
|
||||||
if (hash_table[i] == p) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert
|
|
||||||
GGML_ASSERT(hash_table[i] == NULL);
|
|
||||||
hash_table[i] = p;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool hash_contains(void * hash_table[], void * p) {
|
|
||||||
size_t i = hash_find(hash_table, p);
|
|
||||||
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct hash_map {
|
|
||||||
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
|
||||||
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
|
||||||
};
|
|
||||||
//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
|
|
||||||
|
|
||||||
struct hash_map * new_hash_map() {
|
|
||||||
struct hash_map * result = new struct hash_map;
|
|
||||||
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
|
||||||
result->keys[i] = NULL;
|
|
||||||
result->vals[i] = NULL;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
void free_hash_map(struct hash_map * map) {
|
|
||||||
delete map;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
|
||||||
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
|
||||||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
|
||||||
switch (t->op) {
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
case GGML_OP_VIEW:
|
|
||||||
return t->src[0];
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
return t->src[1];
|
|
||||||
default:
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
|
||||||
struct ggml_tensor * parent = t;
|
|
||||||
do {
|
|
||||||
parent = get_view_parent(parent);
|
|
||||||
} while (ggml_is_view(parent));
|
|
||||||
return parent;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_recompute_graph_node(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_cgraph * graph,
|
|
||||||
struct hash_map * replacements,
|
|
||||||
struct ggml_tensor * node) {
|
|
||||||
|
|
||||||
if (node == NULL) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->is_param) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!hash_contains(graph->visited_hash_table, node)) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count_children = 0;
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
if (node->src[k]) {
|
|
||||||
++count_children;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (count_children == 0) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t i = hash_find(replacements->keys, node);
|
|
||||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
if (replacements->keys[i] == node) {
|
|
||||||
return (struct ggml_tensor *) replacements->vals[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
|
||||||
|
|
||||||
// insert clone into replacements
|
|
||||||
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
|
||||||
replacements->keys[i] = node;
|
|
||||||
replacements->vals[i] = clone;
|
|
||||||
|
|
||||||
clone->op = node->op;
|
|
||||||
clone->grad = node->grad;
|
|
||||||
clone->is_param = node->is_param;
|
|
||||||
clone->extra = node->extra;
|
|
||||||
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
|
||||||
clone->nb[k] = node->nb[k];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
|
||||||
}
|
|
||||||
if (ggml_is_view(clone)) {
|
|
||||||
struct ggml_tensor * source = get_view_source(clone);
|
|
||||||
GGML_ASSERT(source != NULL);
|
|
||||||
clone->data = source->data;
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
|
||||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
|
||||||
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
|
||||||
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
|
||||||
|
|
||||||
return clone;
|
|
||||||
};
|
|
||||||
|
|
||||||
void ggml_build_backward_gradient_checkpointing(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_cgraph * gf,
|
|
||||||
struct ggml_cgraph * gb,
|
|
||||||
struct ggml_cgraph * gb_tmp,
|
|
||||||
struct ggml_tensor * * checkpoints,
|
|
||||||
int n_checkpoints) {
|
|
||||||
*gb_tmp = *gf;
|
|
||||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
|
||||||
|
|
||||||
if (n_checkpoints <= 0) {
|
|
||||||
*gb = *gb_tmp;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct hash_map * replacements = new_hash_map();
|
|
||||||
|
|
||||||
// insert checkpoints in replacements
|
|
||||||
for (int i = 0; i < n_checkpoints; ++i) {
|
|
||||||
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
|
||||||
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
|
||||||
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
|
||||||
replacements->keys[k] = checkpoints[i];
|
|
||||||
replacements->vals[k] = checkpoints[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
*gb = *gf;
|
|
||||||
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
|
||||||
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
|
||||||
// by recomputing them from checkpoints
|
|
||||||
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
|
||||||
struct ggml_tensor * node = gb_tmp->nodes[i];
|
|
||||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
|
||||||
// insert new tensors recomputing src, reusing already made replacements,
|
|
||||||
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
|
||||||
// recurse for input tensors,
|
|
||||||
// unless (i.e. terminating when) input tensors are checkpoints
|
|
||||||
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
|
||||||
}
|
|
||||||
// insert rewritten backward node with replacements made into resulting backward graph gb
|
|
||||||
ggml_build_forward_expand(gb, node);
|
|
||||||
}
|
|
||||||
|
|
||||||
free_hash_map(replacements);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * llama_build_train_graphs(
|
struct ggml_tensor * llama_build_train_graphs(
|
||||||
struct my_llama_model * model,
|
struct my_llama_model * model,
|
||||||
struct ggml_allocr * alloc,
|
struct ggml_allocr * alloc,
|
||||||
|
@ -794,13 +596,13 @@ struct ggml_tensor * llama_build_train_graphs(
|
||||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
|
||||||
// input gradient
|
// input gradient
|
||||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
|
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
|
||||||
GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
|
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
|
||||||
ggml_allocr_alloc(alloc, t36->grad);
|
ggml_allocr_alloc(alloc, t36->grad);
|
||||||
|
|
||||||
// allocating checkpoints in one block to reduce memory fragmentation
|
// allocating checkpoints in one block to reduce memory fragmentation
|
||||||
// note: they will be freed in reverse order
|
// note: they will be freed in reverse order
|
||||||
for (int i = 0; i < (int) checkpoints.size(); ++i) {
|
for (int i = 0; i < (int) checkpoints.size(); ++i) {
|
||||||
if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
|
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
|
||||||
ggml_allocr_alloc(alloc, checkpoints[i]);
|
ggml_allocr_alloc(alloc, checkpoints[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
156
ggml.c
156
ggml.c
|
@ -16174,7 +16174,7 @@ static size_t hash(void * p) {
|
||||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool hash_insert(void * hash_table[], void * p) {
|
static size_t hash_find(void * hash_table[], void * p) {
|
||||||
size_t h = hash(p);
|
size_t h = hash(p);
|
||||||
|
|
||||||
// linear probing
|
// linear probing
|
||||||
|
@ -16182,38 +16182,166 @@ static bool hash_insert(void * hash_table[], void * p) {
|
||||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||||
if (i == h) {
|
if (i == h) {
|
||||||
// hash table is full
|
// visited all hash table entries -> not found
|
||||||
GGML_ASSERT(false);
|
return GGML_GRAPH_HASHTABLE_SIZE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hash_insert(void * hash_table[], void * p) {
|
||||||
|
size_t i = hash_find(hash_table, p);
|
||||||
|
|
||||||
|
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||||
|
|
||||||
if (hash_table[i] == p) {
|
if (hash_table[i] == p) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert
|
// insert
|
||||||
|
GGML_ASSERT(hash_table[i] == NULL);
|
||||||
hash_table[i] = p;
|
hash_table[i] = p;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool hash_contains(void * hash_table[], void * p) {
|
static bool hash_contains(void * hash_table[], void * p) {
|
||||||
size_t h = hash(p);
|
size_t i = hash_find(hash_table, p);
|
||||||
|
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
||||||
|
}
|
||||||
|
|
||||||
// linear probing
|
struct hash_map {
|
||||||
size_t i = h;
|
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
||||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
||||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
};
|
||||||
if (i == h) {
|
|
||||||
// hash table is full
|
struct hash_map * new_hash_map() {
|
||||||
return false;
|
struct hash_map * result = malloc(sizeof(struct hash_map));
|
||||||
|
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
||||||
|
result->keys[i] = NULL;
|
||||||
|
result->vals[i] = NULL;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
void free_hash_map(struct hash_map * map) {
|
||||||
|
free(map);
|
||||||
|
}
|
||||||
|
|
||||||
|
// gradient checkpointing
|
||||||
|
|
||||||
|
static struct ggml_tensor * ggml_recompute_graph_node(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_cgraph * graph,
|
||||||
|
struct hash_map * replacements,
|
||||||
|
struct ggml_tensor * node) {
|
||||||
|
|
||||||
|
if (node == NULL) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->is_param) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hash_contains(graph->visited_hash_table, node)) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count_children = 0;
|
||||||
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||||
|
if (node->src[k]) {
|
||||||
|
++count_children;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hash_table[i] == p) {
|
if (count_children == 0) {
|
||||||
return true;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
size_t i = hash_find(replacements->keys, node);
|
||||||
|
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||||
|
if (replacements->keys[i] == node) {
|
||||||
|
return (struct ggml_tensor *) replacements->vals[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
||||||
|
|
||||||
|
// insert clone into replacements
|
||||||
|
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
||||||
|
replacements->keys[i] = node;
|
||||||
|
replacements->vals[i] = clone;
|
||||||
|
|
||||||
|
clone->op = node->op;
|
||||||
|
clone->grad = node->grad;
|
||||||
|
clone->is_param = node->is_param;
|
||||||
|
clone->extra = node->extra;
|
||||||
|
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
||||||
|
clone->nb[k] = node->nb[k];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||||
|
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
||||||
|
}
|
||||||
|
if (node->view_src != NULL) {
|
||||||
|
clone->data = (node->view_src->data == NULL)
|
||||||
|
? NULL // view_src not yet allocated
|
||||||
|
: (char *) node->view_src->data // view_src already allocated
|
||||||
|
+ node->view_offs;
|
||||||
|
clone->view_src = node->view_src;
|
||||||
|
clone->view_offs = node->view_offs;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
||||||
|
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
||||||
|
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
||||||
|
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
||||||
|
|
||||||
|
return clone;
|
||||||
|
};
|
||||||
|
|
||||||
|
void ggml_build_backward_gradient_checkpointing(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
struct ggml_cgraph * gb,
|
||||||
|
struct ggml_cgraph * gb_tmp,
|
||||||
|
struct ggml_tensor * * checkpoints,
|
||||||
|
int n_checkpoints) {
|
||||||
|
*gb_tmp = *gf;
|
||||||
|
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
||||||
|
|
||||||
|
if (n_checkpoints <= 0) {
|
||||||
|
*gb = *gb_tmp;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct hash_map * replacements = new_hash_map();
|
||||||
|
|
||||||
|
// insert checkpoints in replacements
|
||||||
|
for (int i = 0; i < n_checkpoints; ++i) {
|
||||||
|
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
||||||
|
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||||
|
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
||||||
|
replacements->keys[k] = checkpoints[i];
|
||||||
|
replacements->vals[k] = checkpoints[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
*gb = *gf;
|
||||||
|
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
||||||
|
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
||||||
|
// by recomputing them from checkpoints
|
||||||
|
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
||||||
|
struct ggml_tensor * node = gb_tmp->nodes[i];
|
||||||
|
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||||
|
// insert new tensors recomputing src, reusing already made replacements,
|
||||||
|
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
||||||
|
// recurse for input tensors,
|
||||||
|
// unless (i.e. terminating when) input tensors are replacements (like checkpoints)
|
||||||
|
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
||||||
|
}
|
||||||
|
// insert rewritten backward node with replacements made into resulting backward graph gb
|
||||||
|
ggml_build_forward_expand(gb, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
free_hash_map(replacements);
|
||||||
}
|
}
|
||||||
|
|
||||||
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
||||||
|
|
10
ggml.h
10
ggml.h
|
@ -1664,6 +1664,16 @@ extern "C" {
|
||||||
// dump the graph into a file using the dot format
|
// dump the graph into a file using the dot format
|
||||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
|
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
|
||||||
|
|
||||||
|
// build gradient checkpointing backward graph gb for gf using provided checkpoints
|
||||||
|
// gb_tmp will contain original backward graph with rewritten backward process nodes,
|
||||||
|
// but without the second forward pass nodes.
|
||||||
|
GGML_API void ggml_build_backward_gradient_checkpointing(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
struct ggml_cgraph * gb,
|
||||||
|
struct ggml_cgraph * gb_tmp,
|
||||||
|
struct ggml_tensor * * checkpoints,
|
||||||
|
int n_checkpoints);
|
||||||
//
|
//
|
||||||
// optimization
|
// optimization
|
||||||
//
|
//
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue