ggml : fix op params handling
This commit is contained in:
parent
ec96ee57f4
commit
b9a63636c0
1 changed files with 4 additions and 3 deletions
7
ggml.c
7
ggml.c
|
@ -4886,6 +4886,8 @@ struct ggml_tensor * ggml_concat(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int dim) {
|
||||
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||
if (d == dim) {
|
||||
|
@ -4904,7 +4906,7 @@ struct ggml_tensor * ggml_concat(
|
|||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
|
||||
|
||||
ggml_set_op_params(result, &dim, sizeof(dim));
|
||||
ggml_set_op_params_i32(result, 0, dim);
|
||||
|
||||
result->op = GGML_OP_CONCAT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -10979,8 +10981,7 @@ static void ggml_compute_forward_concat_f32(
|
|||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
int32_t dim;
|
||||
memcpy(&dim, dst->op_params, sizeof(int32_t));
|
||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue