bug fixes for convert-train-checkpoint-to-gguf.py loading checkpoints with opt_version=0

This commit is contained in:
xaedes 2023-08-28 18:33:00 +02:00
parent e8df9e6815
commit 31c093c2cc
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@@ -116,8 +116,8 @@ class OptimizationParamsV0:
self.n_threads = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4 self.n_threads = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
self.past = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4 self.past = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
self.delta = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4 self.delta = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
self.print_forward_graph = struct.unpack('<?', bytes(data[offset:offset + 4]))[0]; offset += 4 # 32bit-aligned self.print_forward_graph = struct.unpack('<?', bytes(data[offset:offset + 1]))[0]; offset += 4 # 32bit-aligned
self.print_backward_graph = struct.unpack('<?', bytes(data[offset:offset + 4]))[0]; offset += 4 # 32bit-aligned self.print_backward_graph = struct.unpack('<?', bytes(data[offset:offset + 1]))[0]; offset += 4 # 32bit-aligned
self.adam_n_iter = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4 self.adam_n_iter = struct.unpack('<i', bytes(data[offset:offset + 4]))[0]; offset += 4
self.adam_sched = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4 self.adam_sched = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
self.adam_decay = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4 self.adam_decay = struct.unpack('<f', bytes(data[offset:offset + 4]))[0]; offset += 4
@@ -177,7 +177,7 @@ class OptimizationContext:
g = Tensor('f', [self.nx]) g = Tensor('f', [self.nx])
g2 = Tensor('f', [self.nx]) g2 = Tensor('f', [self.nx])
mh = Tensor('f', [self.nx]) mh = Tensor('f', [self.nx])
mv = Tensor('f', [self.nx]) vh = Tensor('f', [self.nx])
offset = x.load(data, offset) offset = x.load(data, offset)
offset = g.load(data, offset) offset = g.load(data, offset)