ggml : fix op params handling
This commit is contained in:
parent
ec96ee57f4
commit
b9a63636c0
1 changed files with 4 additions and 3 deletions
7
ggml.c
7
ggml.c
|
@ -4886,6 +4886,8 @@ struct ggml_tensor * ggml_concat(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int dim) {
|
int dim) {
|
||||||
|
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
||||||
|
|
||||||
int64_t ne[GGML_MAX_DIMS];
|
int64_t ne[GGML_MAX_DIMS];
|
||||||
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
|
||||||
if (d == dim) {
|
if (d == dim) {
|
||||||
|
@ -4904,7 +4906,7 @@ struct ggml_tensor * ggml_concat(
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
|
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
|
||||||
|
|
||||||
ggml_set_op_params(result, &dim, sizeof(dim));
|
ggml_set_op_params_i32(result, 0, dim);
|
||||||
|
|
||||||
result->op = GGML_OP_CONCAT;
|
result->op = GGML_OP_CONCAT;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
@ -10979,8 +10981,7 @@ static void ggml_compute_forward_concat_f32(
|
||||||
GGML_ASSERT(nb00 == sizeof(float));
|
GGML_ASSERT(nb00 == sizeof(float));
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
GGML_ASSERT(nb10 == sizeof(float));
|
||||||
|
|
||||||
int32_t dim;
|
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||||
memcpy(&dim, dst->op_params, sizeof(int32_t));
|
|
||||||
|
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue