not allow adding duplicated tensor name
This commit is contained in:
parent
b4e4b8a935
commit
9cd09aa79b
2 changed files with 9 additions and 0 deletions
4
ggml.c
4
ggml.c
|
@ -21360,6 +21360,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
|
|||
void gguf_add_tensor(
|
||||
struct gguf_context * ctx,
|
||||
const struct ggml_tensor * tensor) {
|
||||
if (gguf_find_tensor(ctx, tensor->name) != -1) {
|
||||
GGML_ASSERT(false && "duplicated tensor name");
|
||||
}
|
||||
|
||||
const int idx = ctx->header.n_tensors;
|
||||
ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
|
||||
|
||||
|
|
|
@ -63,6 +63,7 @@ class GGUFWriter:
|
|||
self.kv_data_count = 0
|
||||
self.ti_data = bytearray()
|
||||
self.ti_data_count = 0
|
||||
self.ti_names = set()
|
||||
self.use_temp_file = use_temp_file
|
||||
self.temp_file = None
|
||||
self.tensors = []
|
||||
|
@ -197,6 +198,10 @@ class GGUFWriter:
|
|||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
||||
if name in self.ti_names:
|
||||
raise ValueError(f'Duplicated tensor name {name}')
|
||||
self.ti_names.add(name)
|
||||
|
||||
encoded_name = name.encode("utf8")
|
||||
self.ti_data += self._pack("Q", len(encoded_name))
|
||||
self.ti_data += encoded_name
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue