use tensor->view_src instead of ggml_is_view and get_view_source

xaedes 2023-08-30 14:46:12 +02:00
parent b1709f2d25
commit 2392b6725b
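In short: ggml tensors now carry view_src (the tensor that owns the data) and view_offs (byte offset into it), so the finetune example can read these fields directly instead of walking the op chain. A minimal sketch of that idea, assuming only those two fields; the helper names below are hypothetical and not part of the commit:

    #include "ggml.h"

    // Hypothetical helpers illustrating the change; not part of the diff.
    // Old approach: switch on t->op (RESHAPE/VIEW/TRANSPOSE/PERMUTE/CPY) and
    // follow src pointers until reaching a non-view tensor.
    // New approach: read the view_src/view_offs fields stored on the tensor.
    static bool tensor_is_view(const struct ggml_tensor * t) {
        return t->view_src != NULL;
    }

    static void * tensor_view_data(const struct ggml_tensor * t) {
        if (t->view_src == NULL || t->view_src->data == NULL) {
            return NULL; // not a view, or the source is not yet allocated
        }
        return (char *) t->view_src->data + t->view_offs;
    }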


@@ -680,33 +680,6 @@ void free_hash_map(struct hash_map * map) {
     delete map;
 }
 
-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
-           t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
-}
-
-static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
-    switch (t->op) {
-        case GGML_OP_PERMUTE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_VIEW:
-            return t->src[0];
-        case GGML_OP_CPY:
-            return t->src[1];
-        default:
-            return NULL;
-    }
-}
-
-static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
-    struct ggml_tensor * parent = t;
-    do {
-        parent = get_view_parent(parent);
-    } while (ggml_is_view(parent));
-    return parent;
-}
-
 struct ggml_tensor * ggml_recompute_graph_node(
         struct ggml_context * ctx,
         struct ggml_cgraph * graph,
@@ -759,10 +732,14 @@ struct ggml_tensor * ggml_recompute_graph_node(
     for (int k = 0; k < GGML_MAX_SRC; ++k) {
         clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
     }
-    if (ggml_is_view(clone)) {
-        struct ggml_tensor * source = get_view_source(clone);
-        GGML_ASSERT(source != NULL);
-        clone->data = source->data;
+    if (node->view_src != NULL) {
+        // GGML_ASSERT(node->view_src->data != NULL);
+        clone->data = (node->view_src->data == NULL)
+                        ? NULL // view_src not yet allocated
+                        : (char *) node->view_src->data // view_src already allocated
+                                 + node->view_offs;
+        clone->view_src  = node->view_src;
+        clone->view_offs = node->view_offs;
     }
 
     GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
@@ -1002,7 +979,7 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
     // input gradient
     ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
-    GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
+    GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
     ggml_allocr_alloc(alloc, t36->grad);
     // make sure base model tensors data cannot be used in viewable operations
@@ -1025,7 +1002,7 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
     // allocating checkpoints in one block to reduce memory fragmentation
     // note: they will be freed in reverse order
     for (int i = 0; i < checkpoints.size(); ++i) {
-        if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
+        if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
             ggml_allocr_alloc(alloc, checkpoints[i]);
         }
     }
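The last two hunks apply the same field check when deciding what to allocate. A sketch of that guard, using the ggml_allocr_alloc call already present in the diff; the wrapper name is hypothetical:

    // Hypothetical wrapper mirroring the checks above: a tensor needs its own
    // buffer only if it has no data yet and does not view another tensor.
    static void alloc_if_needed(struct ggml_allocr * alloc, struct ggml_tensor * t) {
        if (t->data == NULL && t->view_src == NULL) {
            ggml_allocr_alloc(alloc, t);
        }
    }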