fix for RWKV
parent 81abd3cb1f
commit 5925c23d51
1 changed file with 13 additions and 5 deletions
@@ -304,6 +304,14 @@ struct rwkv_tensor {
     uint8_t * data;
 };
 
+//rwkv relied on the old ggml_nbytes implementation, so backport it here. Fixes breaking change in PR 2874
+size_t rwkv_nbytes_old(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+    auto a = tensor->ne[3]*tensor->nb[3];
+    auto b = (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
+    return ((a) > (b) ? (a) : (b));
+}
+
 bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) {
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header));
     header.height = 1;
@@ -371,7 +379,7 @@ bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header &
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
     ggml_set_name(tensor, name.c_str());
 
-    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
+    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_nbytes_old(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
     return true;
 }
 
@@ -539,7 +547,7 @@ struct rwkv_future_tensor {
         decoy.ne[1] = height;
         decoy.ne[2] = 1;
         decoy.ne[3] = 1;
-        return ggml_nbytes(&decoy);
+        return rwkv_nbytes_old(&decoy);
     }
 
     rwkv_future_tensor() {}
@@ -1595,7 +1603,7 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers)
 
 void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) {
     if (state_in) {
-        memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state));
+        memcpy(ctx->input_state->data, state_in, rwkv_nbytes_old(ctx->input_state));
     } else {
         rwkv_init_state(ctx, (float *) ctx->input_state->data);
     }
@@ -1603,11 +1611,11 @@ void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) {
 
 void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float * logits_out) {
     if (state_out) {
-        memcpy(state_out, ctx->output_state->data, ggml_nbytes(ctx->output_state));
+        memcpy(state_out, ctx->output_state->data, rwkv_nbytes_old(ctx->output_state));
     }
 
     if (logits_out) {
-        memcpy(logits_out, ctx->logits->data, ggml_nbytes(ctx->logits));
+        memcpy(logits_out, ctx->logits->data, rwkv_nbytes_old(ctx->logits));
     }
 }
 
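Note (not part of the patch): the standalone sketch below illustrates the byte-size formula that rwkv_nbytes_old backports, i.e. the larger of "outermost stride times outermost dimension" and "element count times type size divided by block size". The tensor shape, strides, and type size used here are hypothetical, chosen only to make the arithmetic concrete; the real code computes the same quantities from a ggml_tensor.

// sketch.cpp -- illustrative only, compiles standalone without ggml
#include <cstdint>
#include <cstdio>
#include <algorithm>

int main() {
    // Hypothetical F32 tensor of shape 8 x 4 x 1 x 1 with contiguous row-major strides.
    // ne[i]: elements along dimension i; nb[i]: stride of dimension i in bytes.
    const int64_t ne[4] = {8, 4, 1, 1};
    const size_t  type_size = 4;   // sizeof(float) for F32
    const size_t  blck_size = 1;   // F32 is not block-quantized
    const size_t  nb[4] = {type_size, 8 * type_size, 32 * type_size, 32 * type_size};

    // a: size implied by the outermost stride, as the old ggml_nbytes computed it
    const size_t a = (size_t) ne[3] * nb[3];
    // b: size implied by the element count and per-element (or per-block) size
    const size_t b = (size_t) (ne[0] * ne[1] * ne[2] * ne[3]) * type_size / blck_size;

    // rwkv_nbytes_old returns the larger of the two values
    printf("old nbytes = %zu\n", std::max(a, b)); // prints 128 for this tensor
    return 0;
}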