diff --git a/otherarch/rwkv_v3.cpp b/otherarch/rwkv_v3.cpp index 48e0414ed..cbabab5a8 100644 --- a/otherarch/rwkv_v3.cpp +++ b/otherarch/rwkv_v3.cpp @@ -304,6 +304,14 @@ struct rwkv_tensor { uint8_t * data; }; +//rwkv relied on the old ggml_nbytes implementation, so backport it here. Fixes breaking change in PR 2874 +size_t rwkv_nbytes_old(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + auto a = tensor->ne[3]*tensor->nb[3]; + auto b = (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type); + return ((a) > (b) ? (a) : (b)); +} + bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) { RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header)); header.height = 1; @@ -371,7 +379,7 @@ bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); ggml_set_name(tensor, name.c_str()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_nbytes_old(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); return true; } @@ -539,7 +547,7 @@ struct rwkv_future_tensor { decoy.ne[1] = height; decoy.ne[2] = 1; decoy.ne[3] = 1; - return ggml_nbytes(&decoy); + return rwkv_nbytes_old(&decoy); } rwkv_future_tensor() {} @@ -1595,7 +1603,7 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { if (state_in) { - memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state)); + memcpy(ctx->input_state->data, state_in, rwkv_nbytes_old(ctx->input_state)); } else { rwkv_init_state(ctx, (float *) ctx->input_state->data); } @@ -1603,11 +1611,11 @@ void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float * logits_out) { if (state_out) { - memcpy(state_out, ctx->output_state->data, ggml_nbytes(ctx->output_state)); + memcpy(state_out, ctx->output_state->data, rwkv_nbytes_old(ctx->output_state)); } if (logits_out) { - memcpy(logits_out, ctx->logits->data, ggml_nbytes(ctx->logits)); + memcpy(logits_out, ctx->logits->data, rwkv_nbytes_old(ctx->logits)); } }