From 9912b9efc8922321fe7202ab42ba913833cbe9cd Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Tue, 5 Sep 2023 18:21:10 -0400 Subject: [PATCH 01/12] build : add LLAMA_METAL_NDEBUG flag (#3033) --- CMakeLists.txt | 5 ++++- Makefile | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e872ae310..d4ed6179e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) +option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) @@ -174,7 +175,9 @@ if (LLAMA_METAL) set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) add_compile_definitions(GGML_USE_METAL) - #add_compile_definitions(GGML_METAL_NDEBUG) + if (LLAMA_METAL_NDEBUG) + add_compile_definitions(GGML_METAL_NDEBUG) + endif() # get full path to the file #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") diff --git a/Makefile b/Makefile index fe7ddc9ef..4334761a4 100644 --- a/Makefile +++ b/Makefile @@ -360,6 +360,9 @@ ifdef LLAMA_METAL MK_CPPFLAGS += -DGGML_USE_METAL MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit OBJS += ggml-metal.o +ifdef LLAMA_METAL_NDEBUG + MK_CPPFLAGS += -DGGML_METAL_NDEBUG +endif endif # LLAMA_METAL ifdef LLAMA_METAL From ea2c85d5d2a93d39d0172222917f3195f0e456ff Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Wed, 6 Sep 2023 02:49:11 -0600 Subject: [PATCH 02/12] convert-llama-ggml-to-gguf: Try to handle files older than GGJTv3 (#3023) * convert-llama-ggmlv3-to-gguf: Try to handle files older than GGJTv3 * Better error messages for files that cannot be converted * Add file type to GGUF output * Rename to convert-llama-ggml-to-gguf.py * Include original file type information in description * Improve some informational output --- ...o-gguf.py => convert-llama-ggml-to-gguf.py | 168 ++++++++++++++---- 1 file changed, 133 insertions(+), 35 deletions(-) rename convert-llama-ggmlv3-to-gguf.py => convert-llama-ggml-to-gguf.py (68%) diff --git a/convert-llama-ggmlv3-to-gguf.py b/convert-llama-ggml-to-gguf.py similarity index 68% rename from convert-llama-ggmlv3-to-gguf.py rename to convert-llama-ggml-to-gguf.py index 08ba0c490..b5d3e0b3c 100755 --- a/convert-llama-ggmlv3-to-gguf.py +++ b/convert-llama-ggml-to-gguf.py @@ -5,6 +5,7 @@ import argparse import math import struct import sys +from enum import IntEnum from pathlib import Path import numpy as np @@ -34,10 +35,35 @@ GGML_QUANT_SIZES = { gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8), } +class GGMLFormat(IntEnum): + GGML = 0 + GGMF = 1 + GGJT = 2 + +class GGMLFType(IntEnum): + ALL_F32 = 0 + MOSTLY_F16 = 1 + MOSTLY_Q4_0 = 2 + MOSTLY_Q4_1 = 3 + MOSTLY_Q4_1_SOME_F16 = 4 + MOSTLY_Q8_0 = 7 + MOSTLY_Q5_0 = 8 + MOSTLY_Q5_1 = 9 + MOSTLY_Q2_K = 10 + MOSTLY_Q3_K_S = 11 + MOSTLY_Q3_K_M = 12 + MOSTLY_Q3_K_L = 13 + MOSTLY_Q4_K_S = 14 + MOSTLY_Q4_K_M = 15 + MOSTLY_Q5_K_S = 16 + MOSTLY_Q5_K_M = 17 + MOSTLY_Q6_K = 18 + class Hyperparameters: def __init__(self): - self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0 - self.n_ff = 0 + self.n_vocab = self.n_embd = self.n_mult = self.n_head = 0 + self.n_layer = self.n_rot = self.n_ff = 0 + self.ftype = GGMLFType.ALL_F32 def set_n_ff(self, model): ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') @@ -53,16 +79,21 @@ class Hyperparameters: self.n_head, self.n_layer, self.n_rot, - self.ftype, + ftype, ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) + try: + self.ftype = GGMLFType(ftype) + except ValueError: + raise ValueError(f'Invalid ftype {ftype}') return 4 * 7 def __str__(self): - return f'' + return f'' class Vocab: - def __init__(self): + def __init__(self, load_scores = True): self.items = [] + self.load_scores = load_scores def load(self, data, offset, n_vocab): orig_offset = offset @@ -70,20 +101,24 @@ class Vocab: itemlen = struct.unpack(' 3: + raise ValueError(f'Cannot handle unexpected GGJT file version {version}') + self.file_format = GGMLFormat.GGJT + self.format_version = version + return 8 + raise ValueError(f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file.") + + def validate_conversion(self, ftype): + err = '' + if (self.file_format < GGMLFormat.GGJT or self.format_version < 2): + if ftype not in (GGMLFType.ALL_F32, GGMLFType.MOSTLY_F16): + err = 'Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2.' + elif (self.file_format == GGMLFormat.GGJT and self.format_version == 2): + if ftype in ( GGMLFType.MOSTLY_Q4_0, GGMLFType.MOSTLY_Q4_1, + GGMLFType.MOSTLY_Q4_1_SOME_F16, GGMLFType.MOSTLY_Q8_0): + err = 'Q4 and Q8 quantizations changed in GGJTv3.' + if len(err) > 0: + raise ValueError(f'{err} Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.') def load(self, data, offset): offset += self.validate_header(data, offset) hp = Hyperparameters() offset += hp.load(data, offset) - vocab = Vocab() + print(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}') + self.validate_conversion(hp.ftype) + vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML) offset += vocab.load(data, offset, hp.n_vocab) tensors: list[Tensor] = [] tensor_map = {} while offset < len(data): - tensor = Tensor() + tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF) offset += tensor.load(data, offset) tensor_map[tensor.name] = len(tensors) tensors.append(tensor) @@ -168,7 +235,10 @@ class GGMLToGGUF: def save(self): print('* Preparing to save GGUF file') - gguf_writer = gguf.GGUFWriter(self.cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], use_temp_file = False) + gguf_writer = gguf.GGUFWriter( + self.cfg.output, + gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], + use_temp_file = False ) self.add_params(gguf_writer) self.add_vocab(gguf_writer) if self.special_vocab is not None: @@ -185,7 +255,10 @@ class GGMLToGGUF: def add_params(self, gguf_writer): hp = self.model.hyperparameters cfg = self.cfg - desc = cfg.desc if cfg.desc is not None else 'converted from legacy GGJTv3 format' + if cfg.desc is not None: + desc = cfg.desc + else: + desc = f'converted from legacy {self.model.file_format.name}v{self.model.format_version} {hp.ftype.name} format' try: # Filenames aren't necessarily valid UTF8. name = cfg.name if cfg.name is not None else cfg.input.name @@ -195,6 +268,7 @@ class GGMLToGGUF: if name is not None: gguf_writer.add_name(name) gguf_writer.add_description(desc) + gguf_writer.add_file_type(int(hp.ftype)) if self.params_override is not None: po = self.params_override assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch' @@ -231,7 +305,8 @@ class GGMLToGGUF: tokens.append(vbytes) scores.append(score) toktypes.append(ttype) - assert len(tokens) == hp.n_vocab, f'Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}' + assert len(tokens) == hp.n_vocab, \ + f'Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}' gguf_writer.add_token_list(tokens) gguf_writer.add_token_scores(scores) if len(toktypes) > 0: @@ -283,7 +358,11 @@ class GGMLToGGUF: tempdims[1] = tempdims[0] tempdims[0] = temp # print(f'+ {tensor.name} | {mapped_name} {tensor.dims} :: {tempdims}') - gguf_writer.add_tensor(mapped_name, data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], raw_shape = tempdims, raw_dtype = tensor.dtype) + gguf_writer.add_tensor( + mapped_name, + data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], + raw_shape = tempdims, + raw_dtype = tensor.dtype ) def handle_metadata(cfg, hp): import convert @@ -305,32 +384,46 @@ def handle_metadata(cfg, hp): params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) else: raise ValueError('Unable to load metadata') - vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype) + vocab = convert.load_vocab( + cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, + cfg.vocabtype ) # FIXME: Respect cfg.vocab_dir? svocab = gguf.SpecialVocab(cfg.model_metadata_dir) convert.check_vocab_size(params, vocab) return (params, vocab, svocab) def handle_args(): - parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF') - parser.add_argument('--input', '-i', type = Path, required = True, help = 'Input GGMLv3 filename') - parser.add_argument('--output', '-o', type = Path, required = True, help ='Output GGUF filename') - parser.add_argument('--name', help = 'Set model name') - parser.add_argument('--desc', help = 'Set model description') - parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') - parser.add_argument('--eps', default = '5.0e-06', help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') - parser.add_argument('--context-length', '-c', type=int, default = 2048, help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') - parser.add_argument('--model-metadata-dir', '-m', type = Path, help ='Load HuggingFace/.pth vocab and metadata from the specified directory') - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") - parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)", default="spm") + parser = argparse.ArgumentParser(description = 'Convert GGML models to GGUF') + parser.add_argument('--input', '-i', type = Path, required = True, + help = 'Input GGMLv3 filename') + parser.add_argument('--output', '-o', type = Path, required = True, + help ='Output GGUF filename') + parser.add_argument('--name', + help = 'Set model name') + parser.add_argument('--desc', + help = 'Set model description') + parser.add_argument('--gqa', type = int, default = 1, + help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') + parser.add_argument('--eps', default = '5.0e-06', + help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') + parser.add_argument('--context-length', '-c', type=int, default = 2048, + help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') + parser.add_argument('--model-metadata-dir', '-m', type = Path, + help ='Load HuggingFace/.pth vocab and metadata from the specified directory') + parser.add_argument("--vocab-dir", type=Path, + help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") + parser.add_argument("--vocabtype", choices=["spm", "bpe"], default="spm", + help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)") return parser.parse_args() def main(): cfg = handle_args() print(f'* Using config: {cfg}') print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n') + if cfg.model_metadata_dir is None and (cfg.gqa == 1 or cfg.eps == '5.0e-06'): + print('- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".') data = np.memmap(cfg.input, mode = 'r') - model = GGMLV3Model() + model = GGMLModel() print('* Scanning GGML input file') offset = model.load(data, 0) print(f'* GGML model hyperparameters: {model.hyperparameters}') @@ -345,7 +438,12 @@ def main(): print(f'* Special vocab: {special_vocab}') else: print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') - converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab) + if model.file_format == GGMLFormat.GGML: + print('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!') + converter = GGMLToGGUF(model, data, cfg, + params_override = params_override, + vocab_override = vocab_override, + special_vocab = special_vocab ) converter.save() print(f'* Successful completion. Output saved to: {cfg.output}') From 178b1850ebd21b349cebbee887950e435c5aa2d3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Sep 2023 12:40:57 +0300 Subject: [PATCH 03/12] k-quants : fix zero-weight guard in Q6_K (ref #3040) --- k_quants.c | 1 + 1 file changed, 1 insertion(+) diff --git a/k_quants.c b/k_quants.c index 8742d4aee..eb702ce86 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1089,6 +1089,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict if (!max_abs_scale) { memset(&y[i], 0, sizeof(block_q6_K)); y[i].d = ggml_fp32_to_fp16(0.f); + x += QK_K; continue; } From fec2fb19e4229aac58c98171c46e77144b99f8a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Pawe=C5=82czyk?= Date: Thu, 7 Sep 2023 10:15:06 +0200 Subject: [PATCH 04/12] ggml : posixify madvise and pagesize (#3037) * llama : use posix_madvise() instead of madvise() derived from BSD sed -i 's,\,posix_&,g;s,\ 0) { // Advise the kernel to preload the mapped memory - if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) { - fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", + if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", strerror(errno)); } } if (numa) { // advise the kernel not to use readahead // (because the next page might not belong on the same node) - if (madvise(addr, file->size, MADV_RANDOM)) { - fprintf(stderr, "warning: madvise(.., MADV_RANDOM) failed: %s\n", + if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", strerror(errno)); } } From c4f496648c1e32efeb714200e7eae7fc7cfbb223 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 Sep 2023 15:49:09 +0300 Subject: [PATCH 05/12] metal : fix kernel_norm (fixes Falcon on Metal) (#3057) * metal : fix kernel_norm ggml-ci * metal : put warning in kernel_norm to not combine the loops * metal : restore original F16 mat-vec multiplication It works after the norm fixes * common : don't do warm-up with more than n_batch tokens (close #3058) ggml-ci * metal : minor --- common/common.cpp | 2 +- ggml-metal.metal | 43 +++++++++++++++++++++++-------------------- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 22f65ac46..28b7c6300 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -773,7 +773,7 @@ std::tuple llama_init_from_gpt_par LOG("warming up the model with an empty run\n"); const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_eval(lctx, tmp.data(), tmp.size(), 0, params.n_threads); + llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); llama_reset_timings(lctx); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 119fcbeb6..d66ff340a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -220,27 +220,32 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); const float mean = sum[0]; - // recenter and VARIANCE + // recenter device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { y[i00] = x[i00] - mean; + } + + // VARIANCE + // parallel sum + // + // WARNING: combining this loop with the one above will give you wrong results for nth == 256 + // I have no idea why, so for now I am keeping them separate. But this behavior is very concerning. + // Tested with: + // ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4 + // + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { sum[tpitg] += y[i00] * y[i00]; } - //// VARIANCE - //// parallel sum - //sum[tpitg] = 0.0f; - //for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - // sum[tpitg] += y[i00] * y[i00]; - //} // reduce threadgroup_barrier(mem_flags::mem_threadgroup); for (uint i = ntg/2; i > 0; i /= 2) { @@ -249,11 +254,11 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); const float variance = sum[0]; const float scale = 1.0f/sqrt(variance + eps); @@ -262,7 +267,6 @@ kernel void kernel_norm( } } - kernel void kernel_rms_norm( device const void * src0, device float * dst, @@ -630,7 +634,6 @@ kernel void kernel_mul_mat_f16_f32( } } } - } kernel void kernel_alibi_f32( From be6beeb8d75294552c4918fce06d7b84eebf3d79 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:42:42 +0200 Subject: [PATCH 06/12] metal : correct fix of kernel_norm (#3060) Co-authored-by: Iwan Kawrakow Co-authored-by: Georgi Gerganov --- ggml-metal.metal | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index d66ff340a..5edf6d521 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -220,29 +220,14 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE threadgroup_barrier(mem_flags::mem_threadgroup); - const float mean = sum[0]; - - // recenter device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - } - - // VARIANCE - // parallel sum - // - // WARNING: combining this loop with the one above will give you wrong results for nth == 256 - // I have no idea why, so for now I am keeping them separate. But this behavior is very concerning. - // Tested with: - // ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4 - // sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; sum[tpitg] += y[i00] * y[i00]; } @@ -254,12 +239,7 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - const float variance = sum[0]; + const float variance = sum[0] / ne00; const float scale = 1.0f/sqrt(variance + eps); for (int i00 = tpitg; i00 < ne00; i00 += ntg) { From be8c9c245bd129ebabb80e0a7a8dd7daeb4d30af Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:45:01 +0200 Subject: [PATCH 07/12] metal : parallel RoPE on Metal (#3024) * Parallel RoPE on metal * PR suggestion --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 2 +- ggml-metal.metal | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 521ca180f..7e2355ce6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 5edf6d521..5070561fb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -682,25 +682,27 @@ kernel void kernel_rope( constant int & mode, constant float & freq_base, constant float & freq_scale, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i3 = tpig[2]; - const int64_t i2 = tpig[1]; - const int64_t i1 = tpig[0]; + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; const bool is_neox = mode & 2; - const float theta_scale = pow(freq_base, -2.0f/n_dims); const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); - float theta = freq_scale * (float)p; + const float theta_0 = freq_scale * (float)p; + const float inv_ndims = -1.f/n_dims; if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*i0); const float cos_theta = cos(theta); const float sin_theta = sin(theta); - theta *= theta_scale; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -712,12 +714,12 @@ kernel void kernel_rope( } } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { + for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); const float cos_theta = cos(theta); const float sin_theta = sin(theta); - theta *= theta_scale; - const int64_t i0 = ib*n_dims + ic/2; device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); From 15b67a66c2f2d6032415b28a699b5131962318f1 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 7 Sep 2023 15:52:34 +0200 Subject: [PATCH 08/12] llama-bench : use two tokens in the warmup run for prompt evals (#3059) --- examples/llama-bench/llama-bench.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 72a025077..dedaa34fd 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -986,7 +986,12 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); // warmup run - test_gen(ctx, 1, 0, t.n_threads); + if (t.n_prompt > 0) { + test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads); + } + if (t.n_gen > 0) { + test_gen(ctx, 1, 0, t.n_threads); + } for (int i = 0; i < params.reps; i++) { uint64_t t_start = get_time_ns(); From 5ffab089a54bc06ae4a9ab533893b558756a1e80 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 7 Sep 2023 10:13:50 -0400 Subject: [PATCH 09/12] make : fix CPPFLAGS (#3035) --- Makefile | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 4334761a4..5d76bce87 100644 --- a/Makefile +++ b/Makefile @@ -91,8 +91,8 @@ else OPT = -O3 endif MK_CPPFLAGS = -I. -Icommon -MK_CFLAGS = $(CPPFLAGS) $(OPT) -std=c11 -fPIC -MK_CXXFLAGS = $(CPPFLAGS) $(OPT) -std=c++11 -fPIC +MK_CFLAGS = $(OPT) -std=c11 -fPIC +MK_CXXFLAGS = $(OPT) -std=c++11 -fPIC MK_LDFLAGS = ifdef LLAMA_DEBUG @@ -381,9 +381,8 @@ k_quants.o: k_quants.c k_quants.h endif # LLAMA_NO_K_QUANTS # combine build flags with cmdline overrides -override CPPFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) -override CFLAGS := $(MK_CFLAGS) $(CFLAGS) -override CXXFLAGS := $(MK_CXXFLAGS) $(CXXFLAGS) +override CFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS) +override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS) override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS) # From 4fa2cc1750b861880de42515cb19c13b2d776ee2 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 7 Sep 2023 10:15:01 -0400 Subject: [PATCH 10/12] make : improve test target (#3031) --- Makefile | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 5d76bce87..4f311ee2c 100644 --- a/Makefile +++ b/Makefile @@ -42,9 +42,9 @@ endif default: $(BUILD_TARGETS) -test: - @echo "Running tests..." - @for test_target in $(TEST_TARGETS); do \ +test: $(TEST_TARGETS) + @failures=0; \ + for test_target in $(TEST_TARGETS); do \ if [ "$$test_target" = "tests/test-tokenizer-0-llama" ]; then \ ./$$test_target $(CURDIR)/models/ggml-vocab-llama.gguf; \ elif [ "$$test_target" = "tests/test-tokenizer-0-falcon" ]; then \ @@ -52,10 +52,21 @@ test: elif [ "$$test_target" = "tests/test-tokenizer-1" ]; then \ continue; \ else \ + echo "Running test $$test_target..."; \ ./$$test_target; \ fi; \ - done - @echo "All tests have been run." + if [ $$? -ne 0 ]; then \ + printf 'Test $$test_target FAILED!\n\n' $$test_target; \ + failures=$$(( failures + 1 )); \ + else \ + printf 'Test %s passed.\n\n' $$test_target; \ + fi; \ + done; \ + if [ $$failures -gt 0 ]; then \ + printf '\n%s tests failed.\n' $$failures; \ + exit 1; \ + fi + @echo 'All tests passed.' all: $(BUILD_TARGETS) $(TEST_TARGETS) From 00d62adb79bf914a95fb9a2e8f42f3029e76d62c Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 7 Sep 2023 13:22:29 -0400 Subject: [PATCH 11/12] fix some warnings from gcc and clang-tidy (#3038) Co-authored-by: xaedes --- .clang-tidy | 5 ++ CMakeLists.txt | 2 +- Makefile | 2 +- common/common.cpp | 2 +- common/common.h | 3 ++ common/grammar-parser.cpp | 1 + .../convert-llama2c-to-ggml.cpp | 8 ++-- examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gptneox-wip/falcon-main.cpp | 2 +- examples/gptneox-wip/gptneox-main.cpp | 2 +- examples/main/main.cpp | 19 ++++---- examples/perplexity/perplexity.cpp | 2 +- examples/quantize-stats/quantize-stats.cpp | 2 +- examples/quantize/quantize.cpp | 7 ++- examples/save-load-state/save-load-state.cpp | 4 +- examples/server/server.cpp | 8 ++-- .../train-text-from-scratch.cpp | 46 ++++--------------- ggml-alloc.c | 6 +-- ggml.c | 10 ++-- llama.cpp | 27 ++--------- tests/test-quantize-perf.cpp | 2 +- 22 files changed, 63 insertions(+), 101 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 1a42b9abc..3078beacc 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -3,6 +3,7 @@ Checks: > bugprone-*, -bugprone-easily-swappable-parameters, -bugprone-implicit-widening-of-multiplication-result, + -bugprone-misplaced-widening-cast, -bugprone-narrowing-conversions, readability-*, -readability-avoid-unconditional-preprocessor-if, @@ -15,4 +16,8 @@ Checks: > -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, performance-*, portability-*, + misc-*, + -misc-const-correctness, + -misc-non-private-member-variables-in-classes, + -misc-no-recursion, FormatStyle: none diff --git a/CMakeLists.txt b/CMakeLists.txt index d4ed6179e..d4fa5c261 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -426,7 +426,7 @@ if (LLAMA_ALL_WARNINGS) ) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # g++ only - set(cxx_flags ${cxx_flags} -Wno-format-truncation) + set(cxx_flags ${cxx_flags} -Wno-format-truncation -Wno-array-bounds) endif() else() # todo : msvc diff --git a/Makefile b/Makefile index 4f311ee2c..86e36ba52 100644 --- a/Makefile +++ b/Makefile @@ -134,7 +134,7 @@ MK_CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-m ifeq '' '$(findstring clang++,$(CXX))' # g++ only - MK_CXXFLAGS += -Wno-format-truncation + MK_CXXFLAGS += -Wno-format-truncation -Wno-array-bounds endif # OS specific diff --git a/common/common.cpp b/common/common.cpp index 28b7c6300..6e5d5b4d5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -57,7 +57,7 @@ int32_t get_num_physical_cores() { siblings.insert(line); } } - if (siblings.size() > 0) { + if (!siblings.empty()) { return static_cast(siblings.size()); } #elif defined(__APPLE__) && defined(__MACH__) diff --git a/common/common.h b/common/common.h index 85ac0df9b..012bf5e13 100644 --- a/common/common.h +++ b/common/common.h @@ -20,6 +20,9 @@ #define DIRECTORY_SEPARATOR '/' #endif // _WIN32 +#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) +#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", ##__VA_ARGS__); exit(1); } while (0) + // // CLI argument parsing // diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index e76bd11c3..177d1e3a8 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -415,6 +415,7 @@ namespace grammar_parser { std::vector parse_state::c_rules() { std::vector ret; + ret.reserve(rules.size()); for (const auto & rule : rules) { ret.push_back(rule.data()); } diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 9e856c21a..293b455d0 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -1,5 +1,6 @@ #include "ggml.h" #include "llama.h" +#include "common.h" #include #include @@ -499,10 +500,10 @@ struct llama_file { errno = 0; std::size_t ret = std::fread(ptr, size, 1, fp); if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); + die_fmt("fread failed: %s", strerror(errno)); } if (ret != 1) { - throw std::runtime_error(std::string("unexpectedly reached end of file")); + die("unexpectedly reached end of file"); } } @@ -597,8 +598,7 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename); llama_file file(filename, "rb"); if (!file.fp) { - fprintf(stderr, "error: %s: %s\n", strerror(errno), filename); - exit(1); + die_fmt("%s: %s", strerror(errno), filename); } const int n_vocab = config->vocab_size; /* uint32_t max_token_length = */ file.read_u32(); // unused diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 036bdb398..87aac3479 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -23,7 +23,7 @@ extern "C" { struct MyModel* create_mymodel(int argc, char ** argv) { gpt_params params; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return nullptr; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 93d583b5c..49ab3e063 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -11,7 +11,7 @@ int main(int argc, char ** argv) { gpt_params params; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/gptneox-wip/falcon-main.cpp b/examples/gptneox-wip/falcon-main.cpp index d4b130b25..7f9a1620b 100644 --- a/examples/gptneox-wip/falcon-main.cpp +++ b/examples/gptneox-wip/falcon-main.cpp @@ -953,7 +953,7 @@ int main(int argc, char ** argv) { gpt_params params; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/gptneox-wip/gptneox-main.cpp b/examples/gptneox-wip/gptneox-main.cpp index b6cc46c5f..55eba0cdc 100644 --- a/examples/gptneox-wip/gptneox-main.cpp +++ b/examples/gptneox-wip/gptneox-main.cpp @@ -925,7 +925,7 @@ int main(int argc, char ** argv) { gpt_params params; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9201b53bd..c9ca7719b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -48,8 +48,9 @@ static bool is_interacting = false; void write_logfile( const llama_context * ctx, const gpt_params & params, const llama_model * model, - const std::vector input_tokens, const std::string output, const std::vector output_tokens) { - + const std::vector & input_tokens, const std::string & output, + const std::vector & output_tokens +) { if (params.logdir.empty()) { return; } @@ -109,7 +110,7 @@ int main(int argc, char ** argv) { gpt_params params; g_params = ¶ms; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } @@ -303,7 +304,7 @@ int main(int argc, char ** argv) { // debug message about similarity of saved session, if applicable size_t n_matching_session_tokens = 0; - if (session_tokens.size() > 0) { + if (!session_tokens.empty()) { for (llama_token id : session_tokens) { if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) { break; @@ -401,7 +402,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s: interactive mode on.\n", __func__); - if (params.antiprompt.size()) { + if (!params.antiprompt.empty()) { for (const auto & antiprompt : params.antiprompt) { LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str()); } @@ -499,7 +500,7 @@ int main(int argc, char ** argv) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict - if (embd.size() > 0) { + if (!embd.empty()) { // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via // --prompt or --file which uses the same value. int max_embd_size = n_ctx - 4; @@ -624,7 +625,7 @@ int main(int argc, char ** argv) { LOG("n_past = %d\n", n_past); } - if (embd.size() > 0 && !path_session.empty()) { + if (!embd.empty() && !path_session.empty()) { session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); n_session_consumed = session_tokens.size(); } @@ -695,7 +696,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // check for reverse prompt - if (params.antiprompt.size()) { + if (!params.antiprompt.empty()) { std::string last_output; for (auto id : last_tokens) { last_output += llama_token_to_piece(ctx, id); @@ -732,7 +733,7 @@ int main(int argc, char ** argv) { LOG("found EOS token\n"); if (params.interactive) { - if (params.antiprompt.size() != 0) { + if (!params.antiprompt.empty()) { // tokenize and inject first reverse prompt const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 843b2ae35..1b760683b 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -655,7 +655,7 @@ int main(int argc, char ** argv) { gpt_params params; params.n_batch = 512; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 06ce18f09..6ce03ba7b 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -71,7 +71,7 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) { } // Check if a layer is included/excluded by command line -bool layer_included(const quantize_stats_params params, const std::string & layer) { +bool layer_included(const quantize_stats_params & params, const std::string & layer) { for (const auto& excluded : params.exclude_layers) { if (std::regex_search(layer, std::regex(excluded))) { return false; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c174be069..1bf182482 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -143,10 +143,9 @@ int main(int argc, char ** argv) { if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) { fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]); return 1; - } else { - if (ftype_str == "COPY") { - params.only_copy = true; - } + } + if (ftype_str == "COPY") { + params.only_copy = true; } arg_idx++; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 573bc4ef9..14e9501ca 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -13,7 +13,7 @@ int main(int argc, char ** argv) { params.repeat_last_n = 64; params.prompt = "The quick brown fox"; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } @@ -44,7 +44,7 @@ int main(int argc, char ** argv) { llama_free_model(model); return 1; } - auto tokens = llama_tokenize(ctx, params.prompt.c_str(), true); + auto tokens = llama_tokenize(ctx, params.prompt, true); auto n_prompt_tokens = tokens.size(); if (n_prompt_tokens < 1) { fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6b606447d..3f3c64650 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -139,7 +139,7 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c } // convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context *ctx, const std::vector probs) +static json probs_vector_to_json(const llama_context *ctx, const std::vector & probs) { json out = json::array(); for (const auto &prob : probs) @@ -271,7 +271,7 @@ struct llama_server_context return true; } - std::vector tokenize(json json_prompt, bool add_bos) + std::vector tokenize(const json & json_prompt, bool add_bos) const { // If `add_bos` is true, we only add BOS, when json_prompt is a string, // or the first element of the json_prompt array is a string. @@ -611,7 +611,7 @@ struct llama_server_context completion_token_output doCompletion() { - const completion_token_output token_with_probs = nextToken(); + auto token_with_probs = nextToken(); const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); generated_text += token_text; @@ -1255,7 +1255,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) { struct token_translator { llama_context * ctx; std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } - std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); } + std::string operator()(const completion_token_output & cto) const { return (*this)(cto.tok); } }; void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) { diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 6fe85d419..947aa7ed3 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -169,10 +169,6 @@ struct my_llama_hparams { float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; - - bool operator!=(const my_llama_hparams& other) const { - return memcmp(this, &other, sizeof(my_llama_hparams)); - } }; struct my_llama_layer { @@ -929,28 +925,6 @@ void get_example_targets_batch(struct llama_context * lctx, const int * train_sa } } - -#ifdef __GNUC__ -#ifdef __MINGW32__ -__attribute__((format(gnu_printf, 1, 2))) -#else -__attribute__((format(printf, 1, 2))) -#endif -#endif -static std::string format(const char * fmt, ...) { - va_list ap, ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - int tokenize_file(struct llama_context * lctx, const char * filename, std::vector& out) { FILE * fp = std::fopen(filename, "rb"); if (fp == NULL) { @@ -983,10 +957,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto out.resize(size+1); if (std::fread(buf.data(), size, 1, fp) != 1) { - throw std::runtime_error(std::string("unexpectedly reached end of file")); + die("unexpectedly reached end of file"); } if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); + die_fmt("fread failed: %s", strerror(errno)); } buf[size] = '\0'; @@ -1047,11 +1021,11 @@ void shuffle_ints(int * begin, int * end) { if (kid >= 0) { \ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ if (ktype != (type)) { \ - throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ + die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \ } \ (dst) = func(ctx, kid); \ } else if (req) { \ - throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ + die_fmt("key not found in model: %s", skey.c_str()); \ } \ } @@ -1136,7 +1110,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); } else { - throw std::runtime_error("unknown optimizer type\n"); + die("unknown optimizer type"); } } @@ -1315,20 +1289,20 @@ void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_mod const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST)); if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + die("cannot find tokenizer vocab in model file"); } const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx); const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES)); if (score_idx == -1) { - throw std::runtime_error("cannot find tokenizer scores in model file\n"); + die("cannot find tokenizer scores in model file"); } const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx); const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE)); if (toktype_idx == -1) { - throw std::runtime_error("cannot find token type list in GGUF file\n"); + die("cannot find token type list in GGUF file"); } const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx); @@ -1356,7 +1330,7 @@ void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_mod // read and copy bpe merges const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES)); if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); + die("cannot find tokenizer merges in model file"); } const int n_merges = gguf_get_arr_n(vctx, merges_keyidx); @@ -1988,7 +1962,7 @@ void opt_callback(void * vdata, float * sched) { float min_sched = params->adam_min_alpha / params->adam_alpha; *sched = min_sched + *sched * (1.0f - min_sched); - int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); + int impr_plot = std::isnan(opt->loss_after) ? 0 : -std::lround(1 + (opt->loss_before - opt->loss_after) * 10.0f); printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0); if (data->shuffle_countdown < n_batch) { diff --git a/ggml-alloc.c b/ggml-alloc.c index c1939a4b7..a896601d1 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -138,7 +138,7 @@ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_ten void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { #ifdef GGML_ALLOCATOR_DEBUG - GGML_ASSERT(ggml_is_view(tensor) == false); // views generally get data pointer from one of their sources + GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated #endif size_t size = ggml_allocr_get_alloc_size(alloc, tensor); @@ -165,14 +165,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) if (best_fit_block == -1) { // the last block is our last resort struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); if (block->size >= size) { best_fit_block = alloc->n_free_blocks - 1; - max_avail = MAX(max_avail, block->size); } else { fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", __func__, size, max_avail); GGML_ASSERT(!"not enough space in the buffer"); - return; + return; } } struct free_block * block = &alloc->free_blocks[best_fit_block]; diff --git a/ggml.c b/ggml.c index 50adf18ec..8a677ab2a 100644 --- a/ggml.c +++ b/ggml.c @@ -4768,7 +4768,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( size_t obj_alloc_size = 0; - if (view_src == NULL && ctx->no_alloc == false) { + if (view_src == NULL && !ctx->no_alloc) { if (ctx->scratch.data != NULL) { // allocate tensor data in the scratch buffer if (ctx->scratch.offs + data_size > ctx->scratch.size) { @@ -5469,7 +5469,7 @@ static struct ggml_tensor * ggml_mul_impl( } if (inplace) { - GGML_ASSERT(is_node == false); + GGML_ASSERT(!is_node); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -5512,7 +5512,7 @@ static struct ggml_tensor * ggml_div_impl( } if (inplace) { - GGML_ASSERT(is_node == false); + GGML_ASSERT(!is_node); } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -19957,7 +19957,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p struct ggml_tensor * data = NULL; - if (params.no_alloc == false) { + if (!params.no_alloc) { data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); ok = ok && data != NULL; @@ -19998,7 +19998,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } // point the data member to the appropriate location in the binary blob using the tensor infos - if (params.no_alloc == false) { + if (!params.no_alloc) { //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data } diff --git a/llama.cpp b/llama.cpp index 2c9071a8f..208dcef0e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3052,33 +3052,10 @@ static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; } -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; -} - -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNUSED; -} - static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; } -static bool llama_is_bos_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(llama_is_control_token(vocab, id)); - return id == vocab.special_bos_id; -} - -static bool llama_is_eos_token(const llama_vocab & vocab, llama_token id ) { - GGML_ASSERT(llama_is_control_token(vocab, id)); - return id == vocab.special_eos_id; -} - -static bool llama_is_pad_token(const llama_vocab & vocab, llama_token id ) { - GGML_ASSERT(id < 0 || llama_is_control_token(vocab, id)); - return id == vocab.special_pad_id; -} - static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); @@ -4800,9 +4777,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector workers; std::mutex mutex; +#ifdef GGML_USE_K_QUANTS auto use_more_bits = [] (int i_layer, int num_layers) -> bool { return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; }; +#endif int idx = 0; @@ -5947,7 +5926,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { rng_ss.str(std::string(&rng_buf[0], rng_size)); rng_ss >> ctx->rng; - GGML_ASSERT(rng_ss.fail() == false); + GGML_ASSERT(!rng_ss.fail()); } // set logits diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index 0bb9537f6..cbea7d452 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -76,7 +76,7 @@ void * align_with_offset(void * ptr, int offset) { return (char *) std::align(MAX_ALIGNMENT, MAX_ALIGNMENT, ptr, dummy_size) + offset; } -void benchmark_function(size_t size, size_t q_size, int64_t iterations, std::function function) { +void benchmark_function(size_t size, size_t q_size, int64_t iterations, const std::function & function) { int64_t min_time_us = INT64_MAX; int64_t total_time_us = 0; int64_t min_time_cycles = INT64_MAX; From 6336d834ec7bff3e93e24182c0f609d2f2bdce26 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 7 Sep 2023 14:27:42 -0400 Subject: [PATCH 12/12] convert : fix F32 ftype not being saved (#3048) --- convert.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/convert.py b/convert.py index 59d75141d..79a7cd52b 100755 --- a/convert.py +++ b/convert.py @@ -266,7 +266,7 @@ class Params: f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None # hack to determine LLaMA v1 vs v2 vs CodeLlama - if f_rope_freq_base and f_rope_freq_base == 1000000: + if f_rope_freq_base == 1000000: # CodeLlama n_ctx = 16384 elif config["norm_eps"] == 1e-05: @@ -841,9 +841,9 @@ class OutputFile: name = "LLaMA" # TODO: better logic to determine model name - if (params.n_ctx == 4096): + if params.n_ctx == 4096: name = "LLaMA v2" - elif params.path_model: + elif params.path_model is not None: name = str(params.path_model.parent).split('/')[-1] self.gguf.add_name (name) @@ -856,13 +856,13 @@ class OutputFile: self.gguf.add_head_count_kv (params.n_head_kv) self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) - if params.f_rope_freq_base: + if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) - if params.f_rope_scale: + if params.f_rope_scale is not None: self.gguf.add_rope_scale_linear(params.f_rope_scale) - if params.ftype: + if params.ftype is not None: self.gguf.add_file_type(params.ftype) def add_meta_vocab(self, vocab: Vocab) -> None: