Merge 'origin/master' into hipblas
commit df7346ccd5
5 changed files with 92 additions and 57 deletions

CMakeLists.txt

@@ -251,6 +251,15 @@ if (LLAMA_CUBLAS)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
         endif()
 
+        if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+            if (LLAMA_CUDA_DMMV_F16)
+                set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
+            else()
+                set(CMAKE_CUDA_ARCHITECTURES "52") # lowest CUDA 12 standard
+            endif()
+        endif()
+        message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
     else()
         message(WARNING "cuBLAS not found")
     endif()
@@ -525,22 +534,6 @@ if (BUILD_SHARED_LIBS)
         endif()
     endif()
 
-if (GGML_SOURCES_CUDA)
-    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
-    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES "native")
-    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-
-    set_property(TARGET ggml_static PROPERTY CUDA_ARCHITECTURES "native")
-    set_property(TARGET ggml_static PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-
-    if (BUILD_SHARED_LIBS)
-        set_property(TARGET ggml_shared PROPERTY CUDA_ARCHITECTURES "native")
-        set_property(TARGET ggml_shared PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
-    endif()
-
-    set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES "native")
-endif()
-
 
 #
 # programs, examples and tests

README.md (10 changed lines)

@@ -9,12 +9,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 **Hot topics:**
 
+- p1 : LLM-based code completion engine at the edge : https://github.com/ggml-org/p1/discussions/1
 - Roadmap June 2023: https://github.com/ggerganov/llama.cpp/discussions/1729
-- GPU support with Metal (Apple Silicon): https://github.com/ggerganov/llama.cpp/pull/1642
-- High-quality 2,3,4,5,6-bit quantization: https://github.com/ggerganov/llama.cpp/pull/1684
-- Multi-GPU support: https://github.com/ggerganov/llama.cpp/pull/1607
-- Training LLaMA models from scratch: https://github.com/ggerganov/llama.cpp/pull/1652
-- CPU threading improvements: https://github.com/ggerganov/llama.cpp/pull/1632
 
 <details>
   <summary>Table of Contents</summary>
@@ -344,7 +340,7 @@ Building the program with BLAS support may lead to some performance improvements
 | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
 | LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
 | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
-| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value 2 1 can improve performance for slow GPUs. |
+| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 
 - #### CLBlast
 
@@ -378,7 +374,7 @@ Building the program with BLAS support may lead to some performance improvements
   ```sh
   git clone https://github.com/CNugteren/CLBlast.git
   mkdir CLBlast/build
-  cd CLBLast/build
+  cd CLBlast/build
   cmake .. -DBUILD_SHARED_LIBS=OFF -DTUNERS=OFF
   cmake --build . --config Release
   cmake --install . --prefix /some/path

convert.py (91 changed lines)

@@ -130,6 +130,14 @@ TENSORS_LIST = make_tensors_list()
 TENSORS_SET = set(TENSORS_LIST)
 
+def find_n_mult(n_ff: int, n_embd: int) -> int:
+    # hardcoded magic range
+    for n_mult in range(256, 1, -1):
+        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
+        if calc_ff == n_ff:
+            return n_mult
+    return 1
+
 
 @dataclass
 class Params:
     n_vocab: int
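A quick sanity check on the rounding rule that `find_n_mult` inverts, using the well-known LLaMA-7B dimensions (an illustrative sketch, not part of the diff; the 7B values n_embd = 4096 and n_ff = 11008 are assumed here):

```python
# Illustrative: recover n_mult for LLaMA-7B (assumed n_embd = 4096, n_ff = 11008).
n_embd, n_ff = 4096, 11008
n_mult = 256  # first candidate tried by find_n_mult's descending search
# SwiGLU feed-forward width: roughly 8/3 * n_embd, rounded up to a multiple of n_mult
calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
assert calc_ff == n_ff  # 10922 rounded up to the next multiple of 256 is 11008
# so find_n_mult(11008, 4096) returns 256 on its first iteration
```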
@@ -137,21 +145,61 @@ class Params:
     n_mult: int
     n_head: int
     n_layer: int
-    file_type: GGMLFileType
 
     @staticmethod
-    def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
-        n_vocab, n_embd = model["tok_embeddings.weight"].shape
+    def guessed(model: 'LazyModel') -> 'Params':
+        # try transformer naming first
+        n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
+
+        # try transformer naming first
+        if "model.layers.0.self_attn.q_proj.weight" in model:
+            n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
+        else:
+            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
+
+        n_head=n_embd // 128 # guessed
 
         return Params(
             n_vocab=n_vocab,
             n_embd=n_embd,
             n_mult=256,
-            n_head=n_embd // 128,
-            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model),
-            file_type=file_type,
+            n_head=n_head,
+            n_layer=n_layer,
         )
 
+    @staticmethod
+    def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"];
+        n_embd = config["hidden_size"];
+        n_head = config["num_attention_heads"];
+        n_layer = config["num_hidden_layers"];
+        n_ff = config["intermediate_size"];
+
+        n_mult = find_n_mult(n_ff, n_embd);
+
+        return Params(
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
+        )
+
+    @staticmethod
+    def load(model_plus: 'ModelPlus') -> 'Params':
+        orig_config_path = model_plus.paths[0].parent / "params.json"
+        hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
+
+        if hf_transformer_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+        else:
+            params = Params.guessed(model_plus.model)
+
+        print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd} n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}')
+        return params
+
 
 class SentencePieceVocab:
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
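For a concrete picture of the new HF code path, here is roughly what `loadHFTransformerJson` would derive from a LLaMA-7B `config.json` (a hedged sketch; the dict below merely stands in for a real config file and uses the commonly published 7B values):

```python
# Illustrative: the subset of an HF LLaMA-7B config.json that the new loader reads.
config = {
    "vocab_size": 32000,
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "intermediate_size": 11008,
}
# Following the same arithmetic as loadHFTransformerJson, this yields:
#   Params(n_vocab=32000, n_embd=4096, n_mult=256, n_head=32, n_layer=32)
# Params.guessed() would reach the same n_head via the head-size heuristic:
assert config["hidden_size"] // 128 == config["num_attention_heads"]
```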
@@ -595,18 +643,17 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
 
 
-def convert_transformers_to_orig(model: LazyModel) -> LazyModel:
+def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
     out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
     out["norm.weight"] = model["model.norm.weight"]
     out["output.weight"] = model["lm_head.weight"]
 
-    n_head = model["model.layers.0.self_attn.q_proj.weight"].shape[1] // 128
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
             break
-        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], n_head)
-        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], n_head)
+        out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+        out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
         out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
 
@@ -920,7 +967,7 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.fout = open(fname_out, "wb")
 
-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1]) # magic
         values = [
             1, # file version
 
@@ -930,7 +977,7 @@ class OutputFile:
             params.n_head,
             params.n_layer,
             params.n_embd // params.n_head, # rot (obsolete)
-            params.file_type.value,
+            file_type.value,
         ]
         self.fout.write(struct.pack("i" * len(values), *values))
 
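For reference, the header written by `write_file_header` is the reversed magic followed by a short list of packed int32 values; a sketch of what that would contain for a hypothetical 7B model written as f16 (this assumes `GGMLFileType.MostlyF16` has value 1 in this version of convert.py):

```python
import struct

# Illustrative ggjt header for an assumed 7B model, file_type = GGMLFileType.MostlyF16 (value 1).
values = [
    1,           # file version
    32000,       # params.n_vocab
    4096,        # params.n_embd
    256,         # params.n_mult
    32,          # params.n_head
    32,          # params.n_layer
    4096 // 32,  # rot (obsolete) = n_embd // n_head
    1,           # file_type.value, now passed in explicitly instead of read from params
]
header = b"ggjt"[::-1] + struct.pack("i" * len(values), *values)
assert len(header) == 4 + 4 * len(values)  # 4-byte magic + eight 4-byte ints
```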
@@ -958,10 +1005,10 @@ class OutputFile:
         of.fout.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(params, file_type)
         print("Writing vocab...")
         of.write_vocab(vocab)
 
@@ -997,11 +1044,11 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     raise Exception(f"Unexpected combination of types: {name_to_type}")
 
 
-def do_necessary_conversions(model: LazyModel) -> LazyModel:
+def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
     model = handle_quantization(model)
 
     if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model)
+        model = convert_transformers_to_orig(model, params)
         model = filter_and_sort_tensors(model)
 
     return model
@@ -1107,14 +1154,14 @@ def load_vocab(path: Path) -> SentencePieceVocab:
     return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
 
 
-def default_outfile(model_paths: List[Path], params: Params) -> Path:
+def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32: "f32",
         GGMLFileType.MostlyF16: "f16",
         GGMLFileType.MostlyQ4_0: "q4_0",
         GGMLFileType.MostlyQ4_1: "q4_1",
         GGMLFileType.PerLayerIsQ4_1: "q4_1",
-    }[params.file_type]
+    }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
     if ret in model_paths:
         sys.stderr.write(
@@ -1164,13 +1211,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     else:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
         vocab = load_vocab(vocab_dir)
+    params = Params.load(model_plus)
     model = model_plus.model
-    model = do_necessary_conversions(model)
+    model = do_necessary_conversions(model, params)
     output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
-    params = Params.guessed(model, output_type)
-    outfile = args.outfile or default_outfile(model_plus.paths, params)
-    OutputFile.write_all(outfile, params, model, vocab)
+    outfile = args.outfile or default_outfile(model_plus.paths, output_type)
+    OutputFile.write_all(outfile, params, output_type, model, vocab)
     print(f"Wrote {outfile}")
 

llama.cpp
@@ -925,21 +925,21 @@ static bool kv_cache_init(
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
+        /*.seed                        =*/ -1,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ {0},
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
-        /*.seed                        =*/ -1,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
         /*.embedding                   =*/ false,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;

llama.h (17 changed lines)

@@ -71,28 +71,27 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
+        int seed;         // RNG seed, -1 for random
         int n_ctx;        // text context
         int n_batch;      // prompt processing batch size
         int n_gpu_layers; // number of layers to store in VRAM
         int main_gpu;     // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
-        bool low_vram;    // if true, reduce VRAM usage at the cost of performance
-        int seed;         // RNG seed, -1 for random
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
+
+        // Keep the booleans together to avoid misalignment during copy-by-value.
+        bool low_vram;    // if true, reduce VRAM usage at the cost of performance
         bool f16_kv;      // use fp16 for KV cache
         bool logits_all;  // the llama_eval() call computes all logits, not just the last one
         bool vocab_only;  // only load the vocabulary, no weights
         bool use_mmap;    // use mmap if possible
         bool use_mlock;   // force system to keep model in RAM
         bool embedding;   // embedding mode only
-
-        // called with a progress value between 0 and 1, pass NULL to disable
-        llama_progress_callback progress_callback;
-        // context pointer passed to the progress callback
-        void * progress_callback_user_data;
     };
 
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32 = 0,