ggml : fix op params handling

This commit is contained in:
Georgi Gerganov 2024-05-27 15:55:35 +03:00
parent ec96ee57f4
commit b9a63636c0
No known key found for this signature in database
GPG key ID: BF970631944C16B7

7
ggml.c
View file

@ -4886,6 +4886,8 @@ struct ggml_tensor * ggml_concat(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
int dim) { int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
int64_t ne[GGML_MAX_DIMS]; int64_t ne[GGML_MAX_DIMS];
for (int d = 0; d < GGML_MAX_DIMS; ++d) { for (int d = 0; d < GGML_MAX_DIMS; ++d) {
if (d == dim) { if (d == dim) {
@ -4904,7 +4906,7 @@ struct ggml_tensor * ggml_concat(
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
ggml_set_op_params(result, &dim, sizeof(dim)); ggml_set_op_params_i32(result, 0, dim);
result->op = GGML_OP_CONCAT; result->op = GGML_OP_CONCAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -10979,8 +10981,7 @@ static void ggml_compute_forward_concat_f32(
GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float));
int32_t dim; const int32_t dim = ggml_get_op_params_i32(dst, 0);
memcpy(&dim, dst->op_params, sizeof(int32_t));
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);