Returning 0 for some cases, instead of asserting.

Martin Evans 2024-03-27 16:31:27 +00:00
parent b8e8facb0e
commit b182f8f67f
2 changed files with 19 additions and 5 deletions

llama.cpp

@@ -15227,7 +15227,9 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama
     uint32_t size_t_size;
     memcpy(&size_t_size, inp, sizeof(size_t_size));
     inp += sizeof(size_t_size);
-    GGML_ASSERT(size_t_size == sizeof(size_t));
+    if (size_t_size != sizeof(size_t)) {
+        return 0;
+    }
 
     // Read the cell count
     uint32_t cell_count;
@@ -15244,6 +15246,18 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama
     memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
     inp += sizeof(n_embd_v_gqa_ref);
 
+    // Sanity check model compatibility
+    const auto& hparams = ctx->model.hparams;
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+    if (n_layer != n_layer_ref) {
+        return 0;
+    }
+    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+        return 0;
+    }
+
     // Allocate the new cells for the slot
     {
         llama_batch batch = llama_batch_init(cell_count, 0, 1);
@@ -15274,10 +15288,6 @@ size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama
         llama_batch_free(batch);
     }
 
-    const auto& hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
     const uint32_t kv_size = kv_self.size;
     const uint32_t kv_head = kv_self.head;
     GGML_ASSERT(n_layer == n_layer_ref);
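
The practical effect of the changes above is that an incompatible state blob (wrong sizeof(size_t), or a mismatched n_layer / n_embd_v_gqa) now fails gracefully instead of aborting the process, so callers can check the return value. A minimal caller-side sketch, assuming the third parameter of llama_set_seq_data is the destination llama_seq_id (the declaration is truncated in the header diff below); the helper name restore_sequence is hypothetical:

#include <cstdio>
#include "llama.h"

// Hypothetical helper, not part of this commit. `state_buf` is assumed to hold
// bytes previously written by llama_copy_seq_data for a compatible model.
static bool restore_sequence(llama_context * ctx, const uint8_t * state_buf, llama_seq_id dest_seq_id) {
    // With this commit, llama_set_seq_data returns 0 for an incompatible or
    // malformed blob instead of hitting a GGML_ASSERT.
    const size_t n_read = llama_set_seq_data(ctx, state_buf, dest_seq_id);
    if (n_read == 0) {
        fprintf(stderr, "failed to restore sequence %d: incompatible or corrupt state data\n", dest_seq_id);
        return false;
    }
    return true;
}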

llama.h

@@ -632,6 +632,10 @@ extern "C" {
             uint8_t * dst,
             llama_seq_id seq_id);
 
+    // Copy the sequence data (originally copied with `llama_copy_seq_data`) into a sequence.
+    // Returns:
+    //  - Positive: Ok
+    //  - Zero: Failed to load
     LLAMA_API size_t llama_set_seq_data(
             struct llama_context * ctx,
             const uint8_t * src,
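
For context, the two calls are meant to be used as a pair: llama_copy_seq_data serializes one sequence's KV data into a caller-provided buffer, and llama_set_seq_data loads it back into a destination sequence, returning 0 when the data cannot be accepted. A round-trip sketch under the same assumptions as above; the fixed buffer size is a placeholder, since the size-query side of the API is not shown in this diff:

#include <cstdint>
#include <vector>
#include "llama.h"

// Round-trip sketch (hypothetical helper). Buffer sizing is a placeholder
// assumption; real code should ask the library for the required size.
static bool clone_sequence(llama_context * src_ctx, llama_context * dst_ctx,
                           llama_seq_id src_seq, llama_seq_id dst_seq) {
    std::vector<uint8_t> buf(64u * 1024u * 1024u);  // placeholder upper bound

    // Serialize the source sequence's cells into buf.
    const size_t n_written = llama_copy_seq_data(src_ctx, buf.data(), src_seq);
    if (n_written == 0) {
        return false;  // defensive: nothing was written
    }

    // Per the new comment, a positive return means the data was accepted and
    // zero means it was rejected (e.g. mismatched n_layer / n_embd_v_gqa).
    return llama_set_seq_data(dst_ctx, buf.data(), dst_seq) > 0;
}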