move gradient checkpointing code into ggml, new API function:
// build gradient checkpointing backward graph gb for gf using provided checkpoints // gb_tmp will contain original backward graph with rewritten backward process nodes, // but without the second forward pass nodes. GGML_API void ggml_build_backward_gradient_checkpointing( struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, struct ggml_cgraph * gb_tmp, struct ggml_tensor * * checkpoints, int n_checkpoints);
This commit is contained in:
parent
2392b6725b
commit
d487e0531f
4 changed files with 154 additions and 387 deletions
|
@ -623,179 +623,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
|
|||
GGML_ASSERT(tensor->ne[3] == ne3);
|
||||
}
|
||||
|
||||
static size_t hash(void * p) {
|
||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
|
||||
static size_t hash_find(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// visited all hash table entries -> not found
|
||||
return GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
size_t i = hash_find(hash_table, p);
|
||||
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// insert
|
||||
GGML_ASSERT(hash_table[i] == NULL);
|
||||
hash_table[i] = p;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool hash_contains(void * hash_table[], void * p) {
|
||||
size_t i = hash_find(hash_table, p);
|
||||
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
||||
}
|
||||
|
||||
struct hash_map {
|
||||
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
};
|
||||
|
||||
struct hash_map * new_hash_map() {
|
||||
struct hash_map * result = new struct hash_map;
|
||||
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
||||
result->keys[i] = NULL;
|
||||
result->vals[i] = NULL;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
void free_hash_map(struct hash_map * map) {
|
||||
delete map;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_recompute_graph_node(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
struct hash_map * replacements,
|
||||
struct ggml_tensor * node) {
|
||||
|
||||
if (node == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (node->is_param) {
|
||||
return node;
|
||||
}
|
||||
|
||||
if (!hash_contains(graph->visited_hash_table, node)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
int count_children = 0;
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
if (node->src[k]) {
|
||||
++count_children;
|
||||
}
|
||||
}
|
||||
|
||||
if (count_children == 0) {
|
||||
return node;
|
||||
}
|
||||
|
||||
size_t i = hash_find(replacements->keys, node);
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
if (replacements->keys[i] == node) {
|
||||
return (struct ggml_tensor *) replacements->vals[i];
|
||||
}
|
||||
|
||||
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
||||
|
||||
// insert clone into replacements
|
||||
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[i] = node;
|
||||
replacements->vals[i] = clone;
|
||||
|
||||
clone->op = node->op;
|
||||
clone->grad = node->grad;
|
||||
clone->is_param = node->is_param;
|
||||
clone->extra = node->extra;
|
||||
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
||||
clone->nb[k] = node->nb[k];
|
||||
}
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
||||
}
|
||||
if (node->view_src != NULL) {
|
||||
// GGML_ASSERT(node->view_src->data != NULL);
|
||||
clone->data = (node->view_src->data == NULL)
|
||||
? NULL // view_src not yet allocated
|
||||
: (char *) node->view_src->data // view_src already allocated
|
||||
+ node->view_offs;
|
||||
clone->view_src = node->view_src;
|
||||
clone->view_offs = node->view_offs;
|
||||
}
|
||||
|
||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
||||
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
||||
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
||||
|
||||
return clone;
|
||||
};
|
||||
|
||||
void ggml_build_backward_gradient_checkpointing(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
struct ggml_cgraph * gb_tmp,
|
||||
struct ggml_tensor * * checkpoints,
|
||||
int n_checkpoints) {
|
||||
*gb_tmp = *gf;
|
||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
||||
|
||||
if (n_checkpoints <= 0) {
|
||||
*gb = *gb_tmp;
|
||||
return;
|
||||
}
|
||||
|
||||
struct hash_map * replacements = new_hash_map();
|
||||
|
||||
// insert checkpoints in replacements
|
||||
for (int i = 0; i < n_checkpoints; ++i) {
|
||||
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
||||
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[k] = checkpoints[i];
|
||||
replacements->vals[k] = checkpoints[i];
|
||||
}
|
||||
|
||||
*gb = *gf;
|
||||
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
||||
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
||||
// by recomputing them from checkpoints
|
||||
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
||||
struct ggml_tensor * node = gb_tmp->nodes[i];
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
// insert new tensors recomputing src, reusing already made replacements,
|
||||
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
||||
// recurse for input tensors,
|
||||
// unless (i.e. terminating when) input tensors are checkpoints
|
||||
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
||||
}
|
||||
// insert rewritten backward node with replacements made into resulting backward graph gb
|
||||
ggml_build_forward_expand(gb, node);
|
||||
}
|
||||
|
||||
free_hash_map(replacements);
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_build_lora_finetune_graphs(
|
||||
struct my_llama_model * model,
|
||||
struct my_llama_lora * lora,
|
||||
|
|
|
@ -451,204 +451,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
|
|||
GGML_ASSERT(tensor->ne[3] == ne3);
|
||||
}
|
||||
|
||||
static size_t hash(void * p) {
|
||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
|
||||
static size_t hash_find(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// visited all hash table entries -> not found
|
||||
return GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
//size_t h = hash(p);
|
||||
size_t i = hash_find(hash_table, p);
|
||||
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// insert
|
||||
GGML_ASSERT(hash_table[i] == NULL);
|
||||
hash_table[i] = p;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool hash_contains(void * hash_table[], void * p) {
|
||||
size_t i = hash_find(hash_table, p);
|
||||
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
||||
}
|
||||
|
||||
struct hash_map {
|
||||
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
};
|
||||
//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
|
||||
|
||||
struct hash_map * new_hash_map() {
|
||||
struct hash_map * result = new struct hash_map;
|
||||
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
||||
result->keys[i] = NULL;
|
||||
result->vals[i] = NULL;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
void free_hash_map(struct hash_map * map) {
|
||||
delete map;
|
||||
}
|
||||
|
||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
||||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_VIEW:
|
||||
return t->src[0];
|
||||
case GGML_OP_CPY:
|
||||
return t->src[1];
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
||||
struct ggml_tensor * parent = t;
|
||||
do {
|
||||
parent = get_view_parent(parent);
|
||||
} while (ggml_is_view(parent));
|
||||
return parent;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_recompute_graph_node(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
struct hash_map * replacements,
|
||||
struct ggml_tensor * node) {
|
||||
|
||||
if (node == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (node->is_param) {
|
||||
return node;
|
||||
}
|
||||
|
||||
if (!hash_contains(graph->visited_hash_table, node)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
int count_children = 0;
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
if (node->src[k]) {
|
||||
++count_children;
|
||||
}
|
||||
}
|
||||
|
||||
if (count_children == 0) {
|
||||
return node;
|
||||
}
|
||||
|
||||
size_t i = hash_find(replacements->keys, node);
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
if (replacements->keys[i] == node) {
|
||||
return (struct ggml_tensor *) replacements->vals[i];
|
||||
}
|
||||
|
||||
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
||||
|
||||
// insert clone into replacements
|
||||
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[i] = node;
|
||||
replacements->vals[i] = clone;
|
||||
|
||||
clone->op = node->op;
|
||||
clone->grad = node->grad;
|
||||
clone->is_param = node->is_param;
|
||||
clone->extra = node->extra;
|
||||
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
||||
clone->nb[k] = node->nb[k];
|
||||
}
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
||||
}
|
||||
if (ggml_is_view(clone)) {
|
||||
struct ggml_tensor * source = get_view_source(clone);
|
||||
GGML_ASSERT(source != NULL);
|
||||
clone->data = source->data;
|
||||
}
|
||||
|
||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
||||
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
||||
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
||||
|
||||
return clone;
|
||||
};
|
||||
|
||||
void ggml_build_backward_gradient_checkpointing(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
struct ggml_cgraph * gb_tmp,
|
||||
struct ggml_tensor * * checkpoints,
|
||||
int n_checkpoints) {
|
||||
*gb_tmp = *gf;
|
||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
||||
|
||||
if (n_checkpoints <= 0) {
|
||||
*gb = *gb_tmp;
|
||||
return;
|
||||
}
|
||||
|
||||
struct hash_map * replacements = new_hash_map();
|
||||
|
||||
// insert checkpoints in replacements
|
||||
for (int i = 0; i < n_checkpoints; ++i) {
|
||||
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
||||
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[k] = checkpoints[i];
|
||||
replacements->vals[k] = checkpoints[i];
|
||||
}
|
||||
|
||||
*gb = *gf;
|
||||
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
||||
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
||||
// by recomputing them from checkpoints
|
||||
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
||||
struct ggml_tensor * node = gb_tmp->nodes[i];
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
// insert new tensors recomputing src, reusing already made replacements,
|
||||
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
||||
// recurse for input tensors,
|
||||
// unless (i.e. terminating when) input tensors are checkpoints
|
||||
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
||||
}
|
||||
// insert rewritten backward node with replacements made into resulting backward graph gb
|
||||
ggml_build_forward_expand(gb, node);
|
||||
}
|
||||
|
||||
free_hash_map(replacements);
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_build_train_graphs(
|
||||
struct my_llama_model * model,
|
||||
struct ggml_allocr * alloc,
|
||||
|
@ -794,13 +596,13 @@ struct ggml_tensor * llama_build_train_graphs(
|
|||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
|
||||
// input gradient
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
|
||||
GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
|
||||
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
|
||||
ggml_allocr_alloc(alloc, t36->grad);
|
||||
|
||||
// allocating checkpoints in one block to reduce memory fragmentation
|
||||
// note: they will be freed in reverse order
|
||||
for (int i = 0; i < (int) checkpoints.size(); ++i) {
|
||||
if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
|
||||
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
|
||||
ggml_allocr_alloc(alloc, checkpoints[i]);
|
||||
}
|
||||
}
|
||||
|
|
156
ggml.c
156
ggml.c
|
@ -16174,7 +16174,7 @@ static size_t hash(void * p) {
|
|||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
static size_t hash_find(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
|
@ -16182,38 +16182,166 @@ static bool hash_insert(void * hash_table[], void * p) {
|
|||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// hash table is full
|
||||
GGML_ASSERT(false);
|
||||
// visited all hash table entries -> not found
|
||||
return GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
size_t i = hash_find(hash_table, p);
|
||||
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// insert
|
||||
GGML_ASSERT(hash_table[i] == NULL);
|
||||
hash_table[i] = p;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool hash_contains(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
size_t i = hash_find(hash_table, p);
|
||||
return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
|
||||
}
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// hash table is full
|
||||
return false;
|
||||
struct hash_map {
|
||||
void * keys[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
void * vals[GGML_GRAPH_HASHTABLE_SIZE];
|
||||
};
|
||||
|
||||
struct hash_map * new_hash_map() {
|
||||
struct hash_map * result = malloc(sizeof(struct hash_map));
|
||||
for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
|
||||
result->keys[i] = NULL;
|
||||
result->vals[i] = NULL;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
void free_hash_map(struct hash_map * map) {
|
||||
free(map);
|
||||
}
|
||||
|
||||
// gradient checkpointing
|
||||
|
||||
static struct ggml_tensor * ggml_recompute_graph_node(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * graph,
|
||||
struct hash_map * replacements,
|
||||
struct ggml_tensor * node) {
|
||||
|
||||
if (node == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (node->is_param) {
|
||||
return node;
|
||||
}
|
||||
|
||||
if (!hash_contains(graph->visited_hash_table, node)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
int count_children = 0;
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
if (node->src[k]) {
|
||||
++count_children;
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
if (count_children == 0) {
|
||||
return node;
|
||||
}
|
||||
|
||||
return false;
|
||||
size_t i = hash_find(replacements->keys, node);
|
||||
GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
if (replacements->keys[i] == node) {
|
||||
return (struct ggml_tensor *) replacements->vals[i];
|
||||
}
|
||||
|
||||
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
|
||||
|
||||
// insert clone into replacements
|
||||
GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[i] = node;
|
||||
replacements->vals[i] = clone;
|
||||
|
||||
clone->op = node->op;
|
||||
clone->grad = node->grad;
|
||||
clone->is_param = node->is_param;
|
||||
clone->extra = node->extra;
|
||||
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
|
||||
clone->nb[k] = node->nb[k];
|
||||
}
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
|
||||
}
|
||||
if (node->view_src != NULL) {
|
||||
clone->data = (node->view_src->data == NULL)
|
||||
? NULL // view_src not yet allocated
|
||||
: (char *) node->view_src->data // view_src already allocated
|
||||
+ node->view_offs;
|
||||
clone->view_src = node->view_src;
|
||||
clone->view_offs = node->view_offs;
|
||||
}
|
||||
|
||||
GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
|
||||
GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
|
||||
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
|
||||
ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
|
||||
|
||||
return clone;
|
||||
};
|
||||
|
||||
void ggml_build_backward_gradient_checkpointing(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
struct ggml_cgraph * gb_tmp,
|
||||
struct ggml_tensor * * checkpoints,
|
||||
int n_checkpoints) {
|
||||
*gb_tmp = *gf;
|
||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
||||
|
||||
if (n_checkpoints <= 0) {
|
||||
*gb = *gb_tmp;
|
||||
return;
|
||||
}
|
||||
|
||||
struct hash_map * replacements = new_hash_map();
|
||||
|
||||
// insert checkpoints in replacements
|
||||
for (int i = 0; i < n_checkpoints; ++i) {
|
||||
size_t k = hash_find(replacements->keys, checkpoints[i]);
|
||||
GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
|
||||
GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
|
||||
replacements->keys[k] = checkpoints[i];
|
||||
replacements->vals[k] = checkpoints[i];
|
||||
}
|
||||
|
||||
*gb = *gf;
|
||||
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
|
||||
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
|
||||
// by recomputing them from checkpoints
|
||||
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
|
||||
struct ggml_tensor * node = gb_tmp->nodes[i];
|
||||
for (int k = 0; k < GGML_MAX_SRC; ++k) {
|
||||
// insert new tensors recomputing src, reusing already made replacements,
|
||||
// remember replacements: remember new tensors with mapping from corresponding gf nodes
|
||||
// recurse for input tensors,
|
||||
// unless (i.e. terminating when) input tensors are replacments (like checkpoints)
|
||||
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
|
||||
}
|
||||
// insert rewritten backward node with replacements made into resulting backward graph gb
|
||||
ggml_build_forward_expand(gb, node);
|
||||
}
|
||||
|
||||
free_hash_map(replacements);
|
||||
}
|
||||
|
||||
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
||||
|
|
10
ggml.h
10
ggml.h
|
@ -1664,6 +1664,16 @@ extern "C" {
|
|||
// dump the graph into a file using the dot format
|
||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
|
||||
|
||||
// build gradient checkpointing backward graph gb for gf using provided checkpoints
|
||||
// gb_tmp will contain original backward graph with rewritten backward process nodes,
|
||||
// but without the second forward pass nodes.
|
||||
GGML_API void ggml_build_backward_gradient_checkpointing(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
struct ggml_cgraph * gb_tmp,
|
||||
struct ggml_tensor * * checkpoints,
|
||||
int n_checkpoints);
|
||||
//
|
||||
// optimization
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue