fix for RWKV

This commit is contained in:
Concedo 2023-09-01 17:02:11 +08:00
parent 81abd3cb1f
commit 5925c23d51

View file

@ -304,6 +304,14 @@ struct rwkv_tensor {
uint8_t * data;
};
//rwkv relied on the old ggml_nbytes implementation, so backport it here. Fixes breaking change in PR 2874
size_t rwkv_nbytes_old(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
auto a = tensor->ne[3]*tensor->nb[3];
auto b = (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
return ((a) > (b) ? (a) : (b));
}
bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header));
header.height = 1;
@ -371,7 +379,7 @@ bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header &
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
ggml_set_name(tensor, name.c_str());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_nbytes_old(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
return true;
}
@ -539,7 +547,7 @@ struct rwkv_future_tensor {
decoy.ne[1] = height;
decoy.ne[2] = 1;
decoy.ne[3] = 1;
return ggml_nbytes(&decoy);
return rwkv_nbytes_old(&decoy);
}
rwkv_future_tensor() {}
@ -1595,7 +1603,7 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers)
void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) {
if (state_in) {
memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state));
memcpy(ctx->input_state->data, state_in, rwkv_nbytes_old(ctx->input_state));
} else {
rwkv_init_state(ctx, (float *) ctx->input_state->data);
}
@ -1603,11 +1611,11 @@ void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) {
void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float * logits_out) {
if (state_out) {
memcpy(state_out, ctx->output_state->data, ggml_nbytes(ctx->output_state));
memcpy(state_out, ctx->output_state->data, rwkv_nbytes_old(ctx->output_state));
}
if (logits_out) {
memcpy(logits_out, ctx->logits->data, ggml_nbytes(ctx->logits));
memcpy(logits_out, ctx->logits->data, rwkv_nbytes_old(ctx->logits));
}
}