ggml : fix handling of "view" ops in ggml_graph_import()
This commit is contained in:
parent
b2fd06c6aa
commit
6af6a05663
1 changed files with 51 additions and 29 deletions
64
ggml.c
64
ggml.c
|
@ -14970,6 +14970,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
|
||||||
op = *(const uint32_t *) ptr; ptr += sizeof(op);
|
op = *(const uint32_t *) ptr; ptr += sizeof(op);
|
||||||
n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
|
n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
|
||||||
|
|
||||||
|
enum ggml_op eop = (enum ggml_op) op;
|
||||||
|
|
||||||
int64_t ne[GGML_MAX_DIMS];
|
int64_t ne[GGML_MAX_DIMS];
|
||||||
size_t nb[GGML_MAX_DIMS];
|
size_t nb[GGML_MAX_DIMS];
|
||||||
|
|
||||||
|
@ -14984,42 +14986,62 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
|
||||||
nb[j] = nb_cur;
|
nb[j] = nb_cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
|
uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
|
||||||
|
|
||||||
tensor->op = (enum ggml_op) op;
|
const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
|
||||||
|
|
||||||
uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
|
const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
|
||||||
|
|
||||||
memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
|
struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
|
||||||
|
|
||||||
for (int j = 0; j < GGML_MAX_DIMS; ++j) {
|
|
||||||
tensor->nb[j] = nb[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse args
|
// parse args
|
||||||
{
|
|
||||||
struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
|
|
||||||
&tensor->src0,
|
|
||||||
&tensor->src1,
|
|
||||||
};
|
|
||||||
|
|
||||||
for (int j = 0; j < GGML_MAX_OPT; ++j) {
|
|
||||||
args[2 + j] = &tensor->opt[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
|
for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
|
||||||
const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
|
const int32_t arg_idx = ptr_arg_idx[j];
|
||||||
|
|
||||||
if (arg_idx == -1) {
|
if (arg_idx == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (arg_idx < GGML_MAX_NODES) {
|
if (arg_idx < GGML_MAX_NODES) {
|
||||||
*args[j] = result.leafs[arg_idx];
|
args[j] = result.leafs[arg_idx];
|
||||||
} else {
|
} else {
|
||||||
*args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
|
args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create the tensor
|
||||||
|
// "view" operations are handled differently
|
||||||
|
|
||||||
|
struct ggml_tensor * tensor = NULL;
|
||||||
|
|
||||||
|
switch (eop) {
|
||||||
|
// TODO: implement other view ops
|
||||||
|
case GGML_OP_RESHAPE:
|
||||||
|
{
|
||||||
|
// TODO: implement other dims
|
||||||
|
tensor = ggml_reshape_3d(*ctx_eval, args[0], ne[0], ne[1], ne[2]);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
|
||||||
|
|
||||||
|
tensor->op = eop;
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
|
||||||
|
|
||||||
|
// TODO: double-check this is needed
|
||||||
|
for (int j = 0; j < GGML_MAX_DIMS; ++j) {
|
||||||
|
tensor->nb[j] = nb[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor->src0 = args[0];
|
||||||
|
tensor->src1 = args[1];
|
||||||
|
|
||||||
|
for (int j = 0; j < GGML_MAX_OPT; ++j) {
|
||||||
|
tensor->opt[j] = args[2 + j];
|
||||||
}
|
}
|
||||||
|
|
||||||
result.nodes[i] = tensor;
|
result.nodes[i] = tensor;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue