move gradient checkpointing code into ggml, new API function:

// build gradient checkpointing backward graph gb for gf using provided checkpoints
// gb_tmp will contain original backward graph with rewritten backward process nodes,
// but without the second forward pass nodes.
GGML_API void ggml_build_backward_gradient_checkpointing(
        struct ggml_context   * ctx,
        struct ggml_cgraph    * gf,
        struct ggml_cgraph    * gb,
        struct ggml_cgraph    * gb_tmp,
        struct ggml_tensor  * * checkpoints,
        int                     n_checkpoints);
xaedes 2023-08-30 15:21:10 +02:00
parent 2392b6725b
commit d487e0531f
4 changed files with 154 additions and 387 deletions
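
For illustration, a minimal sketch of how a caller might use the new API (hypothetical tensor names t_loss, t_input, t_layer_out; in practice one checkpoint per transformer layer is typical):

// build the forward graph ending in the scalar loss
struct ggml_cgraph gf = ggml_build_forward(t_loss);

// checkpoints: tensors kept resident across the backward pass;
// everything between consecutive checkpoints is recomputed on demand
struct ggml_tensor * checkpoints[2] = { t_input, t_layer_out };

struct ggml_cgraph gb;     // receives the checkpointed backward graph
struct ggml_cgraph gb_tmp; // scratch: backward graph before rewriting

ggml_build_backward_gradient_checkpointing(ctx, &gf, &gb, &gb_tmp, checkpoints, 2);

// gb now holds forward + recomputation + backward nodes and is computed
// like any other graph (e.g. with ggml_graph_compute_with_ctx)

The usual checkpointing tradeoff applies: with k evenly spaced checkpoints over n forward nodes, activation memory drops from roughly O(n) to O(n/k + k), at the cost of about one extra forward pass of recomputation.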


@@ -623,179 +623,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
     GGML_ASSERT(tensor->ne[3] == ne3);
 }
 
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static size_t hash_find(void * hash_table[], void * p) {
-    size_t h = hash(p);
-    // linear probing
-    size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // visited all hash table entries -> not found
-            return GGML_GRAPH_HASHTABLE_SIZE;
-        }
-    }
-    return i;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
-    size_t i = hash_find(hash_table, p);
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-    if (hash_table[i] == p) {
-        return true;
-    }
-    // insert
-    GGML_ASSERT(hash_table[i] == NULL);
-    hash_table[i] = p;
-    return false;
-}
-
-static bool hash_contains(void * hash_table[], void * p) {
-    size_t i = hash_find(hash_table, p);
-    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
-}
-
-struct hash_map {
-    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
-    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
-};
-
-struct hash_map * new_hash_map() {
-    struct hash_map * result = new struct hash_map;
-    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
-        result->keys[i] = NULL;
-        result->vals[i] = NULL;
-    }
-    return result;
-};
-
-void free_hash_map(struct hash_map * map) {
-    delete map;
-}
-
-struct ggml_tensor * ggml_recompute_graph_node(
-        struct ggml_context * ctx,
-        struct ggml_cgraph  * graph,
-        struct hash_map     * replacements,
-        struct ggml_tensor  * node) {
-
-    if (node == NULL) {
-        return NULL;
-    }
-
-    if (node->is_param) {
-        return node;
-    }
-
-    if (!hash_contains(graph->visited_hash_table, node)) {
-        return node;
-    }
-
-    int count_children = 0;
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        if (node->src[k]) {
-            ++count_children;
-        }
-    }
-
-    if (count_children == 0) {
-        return node;
-    }
-
-    size_t i = hash_find(replacements->keys, node);
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-    if (replacements->keys[i] == node) {
-        return (struct ggml_tensor *) replacements->vals[i];
-    }
-
-    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
-
-    // insert clone into replacements
-    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
-    replacements->keys[i] = node;
-    replacements->vals[i] = clone;
-
-    clone->op       = node->op;
-    clone->grad     = node->grad;
-    clone->is_param = node->is_param;
-    clone->extra    = node->extra;
-    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
-        clone->nb[k] = node->nb[k];
-    }
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
-    }
-    if (node->view_src != NULL) {
-        // GGML_ASSERT(node->view_src->data != NULL);
-        clone->data = (node->view_src->data == NULL)
-                        ? NULL // view_src not yet allocated
-                        : (char *) node->view_src->data // view_src already allocated
-                                 + node->view_offs;
-        clone->view_src  = node->view_src;
-        clone->view_offs = node->view_offs;
-    }
-
-    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
-    GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
-    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
-    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
-
-    return clone;
-};
-
-void ggml_build_backward_gradient_checkpointing(
-        struct ggml_context   * ctx,
-        struct ggml_cgraph    * gf,
-        struct ggml_cgraph    * gb,
-        struct ggml_cgraph    * gb_tmp,
-        struct ggml_tensor  * * checkpoints,
-        int                     n_checkpoints) {
-    *gb_tmp = *gf;
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
-
-    if (n_checkpoints <= 0) {
-        *gb = *gb_tmp;
-        return;
-    }
-
-    struct hash_map * replacements = new_hash_map();
-
-    // insert checkpoints in replacements
-    for (int i = 0; i < n_checkpoints; ++i) {
-        size_t k = hash_find(replacements->keys, checkpoints[i]);
-        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
-        replacements->keys[k] = checkpoints[i];
-        replacements->vals[k] = checkpoints[i];
-    }
-
-    *gb = *gf;
-    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
-    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
-    // by recomputing them from checkpoints
-    for (int i = gf->n_nodes; i < gb_tmp->n_nodes; ++i) {
-        struct ggml_tensor * node = gb_tmp->nodes[i];
-        for (int k = 0; k < GGML_MAX_SRC; ++k) {
-            // insert new tensors recomputing src, reusing already made replacements,
-            // remember replacements: remember new tensors with mapping from corresponding gf nodes
-            // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are checkpoints
-            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
-        }
-        // insert rewritten backward node with replacements made into resulting backward graph gb
-        ggml_build_forward_expand(gb, node);
-    }
-
-    free_hash_map(replacements);
-}
-
 struct ggml_tensor * llama_build_lora_finetune_graphs(
     struct my_llama_model * model,
     struct my_llama_lora  * lora,


@@ -451,204 +451,6 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6
     GGML_ASSERT(tensor->ne[3] == ne3);
 }
 
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static size_t hash_find(void * hash_table[], void * p) {
-    size_t h = hash(p);
-    // linear probing
-    size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // visited all hash table entries -> not found
-            return GGML_GRAPH_HASHTABLE_SIZE;
-        }
-    }
-    return i;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
-    //size_t h = hash(p);
-    size_t i = hash_find(hash_table, p);
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-    if (hash_table[i] == p) {
-        return true;
-    }
-    // insert
-    GGML_ASSERT(hash_table[i] == NULL);
-    hash_table[i] = p;
-    return false;
-}
-
-static bool hash_contains(void * hash_table[], void * p) {
-    size_t i = hash_find(hash_table, p);
-    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
-}
-
-struct hash_map {
-    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
-    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
-};
-//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map);
-
-struct hash_map * new_hash_map() {
-    struct hash_map * result = new struct hash_map;
-    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
-        result->keys[i] = NULL;
-        result->vals[i] = NULL;
-    }
-    return result;
-};
-
-void free_hash_map(struct hash_map * map) {
-    delete map;
-}
-
-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
-           t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
-}
-
-static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
-    switch (t->op) {
-        case GGML_OP_PERMUTE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_VIEW:
-            return t->src[0];
-        case GGML_OP_CPY:
-            return t->src[1];
-        default:
-            return NULL;
-    }
-}
-
-static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
-    struct ggml_tensor * parent = t;
-    do {
-        parent = get_view_parent(parent);
-    } while (ggml_is_view(parent));
-    return parent;
-}
-
-struct ggml_tensor * ggml_recompute_graph_node(
-        struct ggml_context * ctx,
-        struct ggml_cgraph  * graph,
-        struct hash_map     * replacements,
-        struct ggml_tensor  * node) {
-
-    if (node == NULL) {
-        return NULL;
-    }
-
-    if (node->is_param) {
-        return node;
-    }
-
-    if (!hash_contains(graph->visited_hash_table, node)) {
-        return node;
-    }
-
-    int count_children = 0;
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        if (node->src[k]) {
-            ++count_children;
-        }
-    }
-
-    if (count_children == 0) {
-        return node;
-    }
-
-    size_t i = hash_find(replacements->keys, node);
-    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-    if (replacements->keys[i] == node) {
-        return (struct ggml_tensor *) replacements->vals[i];
-    }
-
-    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
-
-    // insert clone into replacements
-    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
-    replacements->keys[i] = node;
-    replacements->vals[i] = clone;
-
-    clone->op       = node->op;
-    clone->grad     = node->grad;
-    clone->is_param = node->is_param;
-    clone->extra    = node->extra;
-    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
-        clone->nb[k] = node->nb[k];
-    }
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
-    }
-    if (ggml_is_view(clone)) {
-        struct ggml_tensor * source = get_view_source(clone);
-        GGML_ASSERT(source != NULL);
-        clone->data = source->data;
-    }
-
-    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
-    GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
-    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
-    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
-
-    return clone;
-};
-
-void ggml_build_backward_gradient_checkpointing(
-        struct ggml_context   * ctx,
-        struct ggml_cgraph    * gf,
-        struct ggml_cgraph    * gb,
-        struct ggml_cgraph    * gb_tmp,
-        struct ggml_tensor  * * checkpoints,
-        int                     n_checkpoints) {
-    *gb_tmp = *gf;
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
-
-    if (n_checkpoints <= 0) {
-        *gb = *gb_tmp;
-        return;
-    }
-
-    struct hash_map * replacements = new_hash_map();
-
-    // insert checkpoints in replacements
-    for (int i = 0; i < n_checkpoints; ++i) {
-        size_t k = hash_find(replacements->keys, checkpoints[i]);
-        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
-        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
-        replacements->keys[k] = checkpoints[i];
-        replacements->vals[k] = checkpoints[i];
-    }
-
-    *gb = *gf;
-    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
-    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
-    // by recomputing them from checkpoints
-    for (int i = gf->n_nodes; i < gb_tmp->n_nodes; ++i) {
-        struct ggml_tensor * node = gb_tmp->nodes[i];
-        for (int k = 0; k < GGML_MAX_SRC; ++k) {
-            // insert new tensors recomputing src, reusing already made replacements,
-            // remember replacements: remember new tensors with mapping from corresponding gf nodes
-            // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are checkpoints
-            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
-        }
-        // insert rewritten backward node with replacements made into resulting backward graph gb
-        ggml_build_forward_expand(gb, node);
-    }
-
-    free_hash_map(replacements);
-}
-
 struct ggml_tensor * llama_build_train_graphs(
     struct my_llama_model * model,
     struct ggml_allocr    * alloc,
@@ -794,13 +596,13 @@ struct ggml_tensor * llama_build_train_graphs(
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
     // input gradient
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
-    GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
+    GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
     ggml_allocr_alloc(alloc, t36->grad);
     // allocating checkpoints in one block to reduce memory fragmentation
     // note: they will be freed in reverse order
     for (int i = 0; i < (int) checkpoints.size(); ++i) {
-        if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
+        if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
             ggml_allocr_alloc(alloc, checkpoints[i]);
         }
     }

ggml.c

@@ -16174,7 +16174,7 @@ static size_t hash(void * p) {
     return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
 }
 
-static bool hash_insert(void * hash_table[], void * p) {
+static size_t hash_find(void * hash_table[], void * p) {
     size_t h = hash(p);
     // linear probing
     size_t i = h;
@@ -16182,38 +16182,166 @@ static bool hash_insert(void * hash_table[], void * p) {
     while (hash_table[i] != NULL && hash_table[i] != p) {
         i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
         if (i == h) {
-            // hash table is full
-            GGML_ASSERT(false);
+            // visited all hash table entries -> not found
+            return GGML_GRAPH_HASHTABLE_SIZE;
         }
     }
+    return i;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t i = hash_find(hash_table, p);
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
     if (hash_table[i] == p) {
         return true;
     }
     // insert
+    GGML_ASSERT(hash_table[i] == NULL);
     hash_table[i] = p;
     return false;
 }
 
 static bool hash_contains(void * hash_table[], void * p) {
-    size_t h = hash(p);
-    // linear probing
-    size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // hash table is full
-            return false;
-        }
-    }
-    if (hash_table[i] == p) {
-        return true;
-    }
-    return false;
+    size_t i = hash_find(hash_table, p);
+    return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+}
+
+struct hash_map {
+    void * keys[GGML_GRAPH_HASHTABLE_SIZE];
+    void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+};
+
+struct hash_map * new_hash_map() {
+    struct hash_map * result = malloc(sizeof(struct hash_map));
+    for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
+        result->keys[i] = NULL;
+        result->vals[i] = NULL;
+    }
+    return result;
+};
+
+void free_hash_map(struct hash_map * map) {
+    free(map);
+}
+
+// gradient checkpointing
+
+static struct ggml_tensor * ggml_recompute_graph_node(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * graph,
+        struct hash_map     * replacements,
+        struct ggml_tensor  * node) {
+
+    if (node == NULL) {
+        return NULL;
+    }
+
+    if (node->is_param) {
+        return node;
+    }
+
+    if (!hash_contains(graph->visited_hash_table, node)) {
+        return node;
+    }
+
+    int count_children = 0;
+    for (int k = 0; k < GGML_MAX_SRC; ++k) {
+        if (node->src[k]) {
+            ++count_children;
+        }
+    }
+
+    if (count_children == 0) {
+        return node;
+    }
+
+    size_t i = hash_find(replacements->keys, node);
+    GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+    if (replacements->keys[i] == node) {
+        return (struct ggml_tensor *) replacements->vals[i];
+    }
+
+    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
+
+    // insert clone into replacements
+    GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
+    replacements->keys[i] = node;
+    replacements->vals[i] = clone;
+
+    clone->op       = node->op;
+    clone->grad     = node->grad;
+    clone->is_param = node->is_param;
+    clone->extra    = node->extra;
+    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+        clone->nb[k] = node->nb[k];
+    }
+    for (int k = 0; k < GGML_MAX_SRC; ++k) {
+        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+    }
+    if (node->view_src != NULL) {
+        clone->data = (node->view_src->data == NULL)
+                        ? NULL // view_src not yet allocated
+                        : (char *) node->view_src->data // view_src already allocated
+                                 + node->view_offs;
+        clone->view_src  = node->view_src;
+        clone->view_offs = node->view_offs;
+    }
+
+    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+    GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
+
+    return clone;
+};
+
+void ggml_build_backward_gradient_checkpointing(
+        struct ggml_context   * ctx,
+        struct ggml_cgraph    * gf,
+        struct ggml_cgraph    * gb,
+        struct ggml_cgraph    * gb_tmp,
+        struct ggml_tensor  * * checkpoints,
+        int                     n_checkpoints) {
+    *gb_tmp = *gf;
+    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+    if (n_checkpoints <= 0) {
+        *gb = *gb_tmp;
+        return;
+    }
+
+    struct hash_map * replacements = new_hash_map();
+
+    // insert checkpoints in replacements
+    for (int i = 0; i < n_checkpoints; ++i) {
+        size_t k = hash_find(replacements->keys, checkpoints[i]);
+        GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+        GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
+        replacements->keys[k] = checkpoints[i];
+        replacements->vals[k] = checkpoints[i];
+    }
+
+    *gb = *gf;
+    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+    // by recomputing them from checkpoints
+    for (int i = gf->n_nodes; i < gb_tmp->n_nodes; ++i) {
+        struct ggml_tensor * node = gb_tmp->nodes[i];
+        for (int k = 0; k < GGML_MAX_SRC; ++k) {
+            // insert new tensors recomputing src, reusing already made replacements,
+            // remember replacements: remember new tensors with mapping from corresponding gf nodes
+            // recurse for input tensors,
+            // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
+            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+        }
+        // insert rewritten backward node with replacements made into resulting backward graph gb
+        ggml_build_forward_expand(gb, node);
+    }
+
+    free_hash_map(replacements);
+}
 
 // functions to change gradients considering the case that input a might be initial gradient with zero value

ggml.h

@@ -1664,6 +1664,16 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
+    // build gradient checkpointing backward graph gb for gf using provided checkpoints
+    // gb_tmp will contain original backward graph with rewritten backward process nodes,
+    // but without the second forward pass nodes.
+    GGML_API void ggml_build_backward_gradient_checkpointing(
+            struct ggml_context   * ctx,
+            struct ggml_cgraph    * gf,
+            struct ggml_cgraph    * gb,
+            struct ggml_cgraph    * gb_tmp,
+            struct ggml_tensor  * * checkpoints,
+            int                     n_checkpoints);
+
     //
     // optimization
     //