gguf : add support for 64-bit (no backwards comp yet)

This commit is contained in:
Georgi Gerganov 2023-08-26 22:05:14 +03:00
parent 5f1fffd2d4
commit 4f0547e4a3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 95 additions and 4 deletions

View file

@ -30,6 +30,9 @@ bool gguf_ex_write(const std::string & fname) {
gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678);
gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679);
gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f);
gguf_set_val_u64 (ctx, "some.parameter.uint64", 0x123456789abcdef0ull);
gguf_set_val_i64 (ctx, "some.parameter.int64", -0x123456789abcdef1ll);
gguf_set_val_f64 (ctx, "some.parameter.float64", 0.1234567890123456789);
gguf_set_val_bool(ctx, "some.parameter.bool", true);
gguf_set_val_str (ctx, "some.parameter.string", "hello world");

63
ggml.c
View file

@ -19408,9 +19408,12 @@ static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
[GGUF_TYPE_FLOAT32] = sizeof(float),
[GGUF_TYPE_BOOL] = sizeof(bool),
[GGUF_TYPE_STRING] = sizeof(struct gguf_str),
[GGUF_TYPE_UINT64] = sizeof(uint64_t),
[GGUF_TYPE_INT64] = sizeof(int64_t),
[GGUF_TYPE_FLOAT64] = sizeof(double),
[GGUF_TYPE_ARRAY] = 0, // undefined
};
static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10");
static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
[GGUF_TYPE_UINT8] = "u8",
@ -19423,8 +19426,11 @@ static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
[GGUF_TYPE_BOOL] = "bool",
[GGUF_TYPE_STRING] = "str",
[GGUF_TYPE_ARRAY] = "arr",
[GGUF_TYPE_UINT64] = "u64",
[GGUF_TYPE_INT64] = "i64",
[GGUF_TYPE_FLOAT64] = "f64",
};
static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10");
static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
union gguf_value {
uint8_t uint8;
@ -19434,6 +19440,9 @@ union gguf_value {
uint32_t uint32;
int32_t int32;
float float32;
uint64_t uint64;
int64_t int64;
double float64;
bool bool_;
struct gguf_str str;
@ -19466,7 +19475,7 @@ struct gguf_tensor_info {
struct gguf_str name;
uint32_t n_dims;
uint32_t ne[GGML_MAX_DIMS];
uint64_t ne[GGML_MAX_DIMS];
enum ggml_type type;
@ -19599,6 +19608,9 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break;
case GGUF_TYPE_ARRAY:
@ -19614,6 +19626,9 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32:
case GGUF_TYPE_FLOAT32:
case GGUF_TYPE_UINT64:
case GGUF_TYPE_INT64:
case GGUF_TYPE_FLOAT64:
case GGUF_TYPE_BOOL:
{
kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]);
@ -19954,6 +19969,18 @@ float gguf_get_val_f32(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.float32;
}
uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.uint64;
}
int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.int64;
}
double gguf_get_val_f64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.float64;
}
bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.bool_;
}
@ -20056,6 +20083,27 @@ void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
ctx->kv[idx].value.float32 = val;
}
void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
const int idx = gguf_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_TYPE_UINT64;
ctx->kv[idx].value.uint64 = val;
}
void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
const int idx = gguf_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_TYPE_INT64;
ctx->kv[idx].value.int64 = val;
}
void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
const int idx = gguf_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_TYPE_FLOAT64;
ctx->kv[idx].value.float64 = val;
}
void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
const int idx = gguf_get_or_add_key(ctx, key);
@ -20106,6 +20154,9 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
case GGUF_TYPE_ARRAY:
@ -20267,6 +20318,9 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break;
case GGUF_TYPE_ARRAY:
@ -20282,6 +20336,9 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32:
case GGUF_TYPE_FLOAT32:
case GGUF_TYPE_UINT64:
case GGUF_TYPE_INT64:
case GGUF_TYPE_FLOAT64:
case GGUF_TYPE_BOOL:
{
gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]);

9
ggml.h
View file

@ -1827,6 +1827,9 @@ extern "C" {
GGUF_TYPE_BOOL = 7,
GGUF_TYPE_STRING = 8,
GGUF_TYPE_ARRAY = 9,
GGUF_TYPE_UINT64 = 10,
GGUF_TYPE_INT64 = 11,
GGUF_TYPE_FLOAT64 = 12,
GGUF_TYPE_COUNT, // marks the end of the enum
};
@ -1867,6 +1870,9 @@ extern "C" {
GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
@ -1886,6 +1892,9 @@ extern "C" {
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);

View file

@ -365,6 +365,9 @@ class GGUFValueType(IntEnum):
BOOL = 7
STRING = 8
ARRAY = 9
UINT64 = 10
INT64 = 11
FLOAT64 = 12
@staticmethod
def get_type(val):
@ -378,6 +381,7 @@ class GGUFValueType(IntEnum):
return GGUFValueType.BOOL
elif isinstance(val, int):
return GGUFValueType.INT32
# TODO: need help with 64-bit types in Python
else:
print("Unknown type: "+str(type(val)))
sys.exit()
@ -444,6 +448,18 @@ class GGUFWriter:
self.add_key(key)
self.add_val(val, GGUFValueType.FLOAT32)
def add_uint64(self, key: str, val: int):
self.add_key(key)
self.add_val(val, GGUFValueType.UINT64)
def add_int64(self, key: str, val: int):
self.add_key(key)
self.add_val(val, GGUFValueType.INT64)
def add_float64(self, key: str, val: float):
self.add_key(key)
self.add_val(val, GGUFValueType.FLOAT64)
def add_bool(self, key: str, val: bool):
self.add_key(key)
self.add_val(val, GGUFValueType.BOOL)
@ -483,6 +499,12 @@ class GGUFWriter:
self.kv_data += struct.pack("<i", val)
elif vtype == GGUFValueType.FLOAT32:
self.kv_data += struct.pack("<f", val)
elif vtype == GGUFValueType.UINT64:
self.kv_data += struct.pack("<Q", val)
elif vtype == GGUFValueType.INT64:
self.kv_data += struct.pack("<q", val)
elif vtype == GGUFValueType.FLOAT64:
self.kv_data += struct.pack("<d", val)
elif vtype == GGUFValueType.BOOL:
self.kv_data += struct.pack("?", val)
elif vtype == GGUFValueType.STRING:
@ -512,7 +534,7 @@ class GGUFWriter:
n_dims = len(tensor_shape)
self.ti_data += struct.pack("<I", n_dims)
for i in range(n_dims):
self.ti_data += struct.pack("<I", tensor_shape[n_dims - 1 - i])
self.ti_data += struct.pack("<Q", tensor_shape[n_dims - 1 - i])
if raw_dtype is None:
dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
else: