ggml : fix handling of "view" ops in ggml_graph_import()

This commit is contained in:
Georgi Gerganov 2023-05-31 22:28:15 +03:00
parent b2fd06c6aa
commit 6af6a05663
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

Changed file: ggml.c (80 changed lines) — view file

@ -11156,7 +11156,7 @@ static void ggml_compute_forward_rope_f32(
theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
@ -11177,7 +11177,7 @@ static void ggml_compute_forward_rope_f32(
const int64_t i0 = ib*n_dims + ic/2;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
@ -14970,6 +14970,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
op = *(const uint32_t *) ptr; ptr += sizeof(op);
n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
enum ggml_op eop = (enum ggml_op) op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
@ -14984,42 +14986,62 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
nb[j] = nb_cur;
}
struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
tensor->op = (enum ggml_op) op;
const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
// parse args
for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
const int32_t arg_idx = ptr_arg_idx[j];
if (arg_idx == -1) {
continue;
}
if (arg_idx < GGML_MAX_NODES) {
args[j] = result.leafs[arg_idx];
} else {
args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
}
}
// create the tensor
// "view" operations are handled differently
struct ggml_tensor * tensor = NULL;
switch (eop) {
// TODO: implement other view ops
case GGML_OP_RESHAPE:
{
// TODO: implement other dims
tensor = ggml_reshape_3d(*ctx_eval, args[0], ne[0], ne[1], ne[2]);
} break;
default:
{
tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
tensor->op = eop;
} break;
}
memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
// TODO: double-check this is needed
for (int j = 0; j < GGML_MAX_DIMS; ++j) {
tensor->nb[j] = nb[j];
}
// parse args
{
struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
&tensor->src0,
&tensor->src1,
};
tensor->src0 = args[0];
tensor->src1 = args[1];
for (int j = 0; j < GGML_MAX_OPT; ++j) {
args[2 + j] = &tensor->opt[j];
}
for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
if (arg_idx == -1) {
continue;
}
if (arg_idx < GGML_MAX_NODES) {
*args[j] = result.leafs[arg_idx];
} else {
*args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
}
}
for (int j = 0; j < GGML_MAX_OPT; ++j) {
tensor->opt[j] = args[2 + j];
}
result.nodes[i] = tensor;