Configurable sparse prediction threshold (#7)

* remove warning at gpu split

* remove dead code

* adaptive sparsity threshold reading from model file

* convert models with sparse threshold
Holden X, 2023-12-18 16:36:24 +08:00 (committed by GitHub)
parent 597ef34ba1 · commit 603c771974
9 changed files with 96 additions and 41 deletions
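At a glance, the plumbing is: the conversion script reads sparse_threshold from the predictor's config.json and writes it into the GGUF file under powerinfer.sparse_threshold; at load time llama.cpp reads that key into llama_hparams (an LLAMA_SPARSE_PRED_THRESHOLD environment variable seeds the default) and pushes the value into a CPU-side global and a CUDA __constant__. The kernels then compare each row's predictor score against this value instead of a hard-coded 0.0f. A minimal, self-contained C sketch of that gating rule (names are illustrative, not the repo's API):

    #include <stdio.h>

    /* Hypothetical sketch of the thresholded axpy the sparse kernels implement:
     * each FFN row r has a predictor score idx[r]; rows scoring below the
     * threshold are predicted inactive and skipped entirely. */
    static void sparse_axpy(float *dst, const float *w, const float *x,
                            const float *idx, float threshold,
                            int nrows, int ncols) {
        for (int r = 0; r < nrows; r++) {
            if (idx[r] < threshold) continue;       /* predicted inactive: skip row */
            for (int c = 0; c < ncols; c++) {
                dst[c] += w[r * ncols + c] * x[r];  /* dst += row_r * x[r] */
            }
        }
    }

    int main(void) {
        float w[2][3] = {{1, 2, 3}, {4, 5, 6}};
        float x[2]    = {1.0f, 1.0f};
        float idx[2]  = {0.9f, -0.5f};  /* row 1 scores below the threshold */
        float dst[3]  = {0};
        sparse_axpy(dst, &w[0][0], x, idx, 0.0f, 2, 3);
        printf("%g %g %g\n", dst[0], dst[1], dst[2]);  /* prints: 1 2 3 */
        return 0;
    }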

convert.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import argparse
 import concurrent.futures
+import dataclasses
 import enum
 import faulthandler
 import functools
@@ -138,6 +139,28 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
 # hparams loading
 #
 
+@dataclass
+class PredictorParams:
+    sparse_threshold: float | None = None
+
+    @staticmethod
+    def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams:
+        config = json.load(open(config_path))
+        return PredictorParams(
+            sparse_threshold = config.get("sparse_threshold"),
+        )
+
+    @staticmethod
+    def load(model_plus: ModelPlus) -> PredictorParams:
+        config_path = model_plus.paths[0].parent / "config.json"
+
+        if config_path.exists():
+            params = PredictorParams.loadPredictorJson(model_plus.model, config_path)
+        else:
+            params = PredictorParams()
+
+        return params
+
+
 @dataclass
 class Params:
     n_vocab: int
@@ -160,6 +183,9 @@ class Params:
     # path to the directory containing the model files
     path_model: Path | None = None
 
+    # MLP predictor parameters
+    predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams)
+
     @staticmethod
     def guessed(model: LazyModel) -> Params:
         # try transformer naming first
@@ -843,6 +869,9 @@ class OutputFile:
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
 
+        if params.predictor_params.sparse_threshold is not None:
+            self.gguf.add_sparse_threshold(params.predictor_params.sparse_threshold)
+
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -1181,10 +1210,13 @@ def main(args_in: list[str] | None = None) -> None:
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
+        params = Params.load(model_plus)
         mlp_predictor_plus = load_mlp_model(args.mlp_model)
+        params.predictor_params = PredictorParams.load(mlp_predictor_plus)
         model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+        params = Params.load(model_plus)
 
     if args.dump:
         do_dump_model(model_plus)
@@ -1193,7 +1225,6 @@ def main(args_in: list[str] | None = None) -> None:
     if args.bigendian:
         endianess = gguf.GGUFEndian.BIG
 
-    params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
             raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
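The predictor directory is expected to carry a config.json whose sparse_threshold entry — a float, e.g. {"sparse_threshold": 0.5}, an illustrative value — flows into the GGUF metadata via add_sparse_threshold above; if the key is absent, PredictorParams.sparse_threshold stays None and no KV is written.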

ggml-cuda.cu

@@ -108,6 +108,8 @@
 // max batch size to use MMQ kernels when tensor cores are available
 #define MMQ_MAX_BATCH_SIZE 32
 
+__constant__ float dev_sparse_threshold;
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -4483,7 +4485,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__
     //     printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
     // }
     // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
@@ -4552,12 +4554,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
         return;
     }
     int id = lst[row];
-    // int id = row;
-    // if (idx[id] < 0.0f) {
-    //     return;
-    // }
     const int bid = blockIdx.y;
-    // if (bid == 0) global_lock = 0;
 
     extern __shared__ float shared_dst[]; // TODO:dynamic
@@ -4578,7 +4575,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
     // __syncthreads();
     for (int col_id = 0; col_id < src1_ncols; col_id++) {
         __syncthreads();
-        if (loop_idx[id] < 0.0f) {
+        if (loop_idx[id] < dev_sparse_threshold) {
             loop_dst += ncols;
             loop_idx += src1_ne0;
             loop_y += src1_ne0;
@@ -4640,7 +4637,7 @@ static __global__ void dequantize_axpy_sparse(const void * __restrict__ vx, cons
         return;
     }
     int id = lst[row];
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
@@ -4689,8 +4686,7 @@ static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ v
         return;
     }
     int id = lst[row];
-    // int id = row;
-    if (idx[id] < 0.0f) {
+    if (idx[id] < dev_sparse_threshold) {
         return;
     }
@@ -4782,7 +4778,7 @@ static __global__ void dequantize_mul_mat_batch_sparse(const void * __restrict__
     {
         __syncthreads();
         tmp = 0.0f;
-        if (loop_idx[id] < 0.0f)
+        if (loop_idx[id] < dev_sparse_threshold)
         {
             loop_dst += dst_ne0;
             loop_idx += dst_ne0;
@@ -9618,3 +9614,6 @@ ggml_backend_t ggml_backend_cuda_init() {
     return cuda_backend;
 }
+
+void ggml_cuda_set_device_constants(float sparse_pred_threshold) {
+    CUDA_CHECK(cudaMemcpyToSymbol(dev_sparse_threshold, &sparse_pred_threshold, sizeof(float)));
+}
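The pattern above — a __constant__ symbol written once from the host with cudaMemcpyToSymbol, then read by every kernel — is standard CUDA. A self-contained sketch (kernel and variable names are illustrative, not the repo's):

    #include <cstdio>
    #include <cuda_runtime.h>

    __constant__ float dev_threshold;  // lives in constant memory, set from the host

    __global__ void gate(const float *idx, int *keep, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) keep[i] = idx[i] >= dev_threshold;  // same comparison the kernels above skip on
    }

    int main() {
        float h_thr = 0.25f;  // illustrative threshold
        cudaMemcpyToSymbol(dev_threshold, &h_thr, sizeof(float));  // what ggml_cuda_set_device_constants does

        float h_idx[4] = {0.9f, 0.1f, 0.3f, -0.2f};
        float *d_idx; int *d_keep, h_keep[4];
        cudaMalloc(&d_idx, sizeof(h_idx));
        cudaMalloc(&d_keep, sizeof(h_keep));
        cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
        gate<<<1, 4>>>(d_idx, d_keep, 4);
        cudaMemcpy(h_keep, d_keep, sizeof(h_keep), cudaMemcpyDeviceToHost);
        printf("%d %d %d %d\n", h_keep[0], h_keep[1], h_keep[2], h_keep[3]);  // prints: 1 0 1 0
        cudaFree(d_idx);
        cudaFree(d_keep);
        return 0;
    }

Constant memory is cached and broadcast to all threads, so the per-row comparison is essentially free — presumably why a single device constant was chosen over threading the value through every sparse kernel's signature.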

ggml-cuda.h

@@ -53,6 +53,8 @@ GGML_API int ggml_cuda_get_device_count(void);
 GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 GGML_API size_t ggml_cuda_get_free_memory(int device);
 
+GGML_API void ggml_cuda_set_device_constants(float sparse_pred_threshold);
+
 // backend API
 GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use

ggml.c

@@ -14059,6 +14059,8 @@ static void ggml_compute_forward_mul_mat_sparse(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
     GGML_ASSERT(ne2 == ne12);
@@ -14262,7 +14264,7 @@ static void ggml_compute_forward_mul_mat_sparse(
             float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
             // if (ffdata[ir0] <= 0.0f) {
-            if (gid[ir0] == 1 || ffdata[ir0] < -0.0f) {
+            if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
                 dst_col[ir0] = 0;
                 continue;
             }
@@ -14413,11 +14415,6 @@ static void ggml_compute_forward_mul_mat_axpy_dense(
         const int ir0 = atomic_fetch_add(params->aic, dr);
         for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
             if (ir1 >= nr) break;
-            // if (gid[ir1] == 1)
-            //     continue;
-            // if (idx[ir1] < 0.0f)
-            //     continue;
-            // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
         }
         if (ir0 + dr >= nr)
@@ -14482,6 +14479,8 @@ static void ggml_compute_forward_mul_mat_axpy(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14569,7 +14568,7 @@ static void ggml_compute_forward_mul_mat_axpy(
             if (gid[ir1] == 1) {
                 continue;
             }
-            if (idx[ir1] < -0.0f)
+            if (idx[ir1] < threshold)
                 continue;
             // ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
             ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
@@ -14632,6 +14631,8 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+    const float threshold = sparse_pred_threshold;
+
     // GGML_ASSERT(ne0 == ne01);
     // GGML_ASSERT(ne1 == ne11);
     // GGML_ASSERT(ne2 == ne12);
@@ -14713,7 +14714,7 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
                 break;
             if (gid[ir1] == 1)
                 continue;
-            if (idx[ir1] < 0.0f)
+            if (idx[ir1] < threshold)
                 continue;
            int bid = ir1 / QK8_0;
            int qsid = ir1 % QK8_0;
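Worth noting: each CPU kernel snapshots the global into a local const float threshold once per op rather than re-reading it per row, and the previously inconsistent hard-coded cutoffs (0.0f in some kernels, -0.0f in others) now all resolve to the same configurable value.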

ggml.h

@@ -2196,6 +2196,12 @@
     GGML_API int ggml_cpu_has_ssse3 (void);
     GGML_API int ggml_cpu_has_vsx (void);
 
+    //
+    // global variables
+    //
+
+    // TODO: these should be moved to the context
+    extern float sparse_pred_threshold;
+
     //
     // Internal types and functions exposed for tests and benchmarks
     //

gguf-py/gguf/constants.py

@@ -70,6 +70,9 @@ class Keys:
         ADD_EOS = "tokenizer.ggml.add_eos_token"
         HF_JSON = "tokenizer.huggingface.json"
         RWKV    = "tokenizer.rwkv.world"
 
+    class PowerInfer:
+        SPARSE_THRESHOLD = "powerinfer.sparse_threshold"
+
 #

gguf-py/gguf/gguf_writer.py

@@ -399,6 +399,9 @@ class GGUFWriter:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_sparse_threshold(self, value: float) -> None:
+        self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
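The writer method is a thin wrapper: a single float32 KV under the powerinfer.sparse_threshold key defined in constants.py above. The requirements.txt change at the bottom of this commit switches installs from the released gguf package to the vendored ./gguf-py, which is presumably what makes this new method visible to convert.py.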

llama.cpp

@@ -93,6 +93,13 @@
 #define LLAMA_MAX_NODES 4096
 
+//
+// global variables
+//
+
+// sparsity threshold for sparse matrix multiplication prediction
+float sparse_pred_threshold = 0.;
+
 //
 // logging
 //
@@ -257,6 +264,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PAD_ID,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
+
+    LLM_KV_SPARSE_THRESHOLD,
 };
 
 static std::map<llm_kv, std::string> LLM_KV_NAMES = {
@@ -305,6 +314,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_PAD_ID,  "tokenizer.ggml.padding_token_id" },
     { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,    "tokenizer.rwkv.world" },
+
+    { LLM_KV_SPARSE_THRESHOLD,  "powerinfer.sparse_threshold" },
 };
 
 struct LLM_KV {
@@ -1150,6 +1161,9 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;
 
+    // sparse predictor threshold if sparse inference is enabled
+    float sparse_pred_threshold = atof(getenv("LLAMA_SPARSE_PRED_THRESHOLD") ?: "0.0");
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -2220,6 +2234,11 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }
 
+    if (gguf_get_sparse_deriv(ctx)) {
+        // read sparse threshold override if sparse deriv is enabled
+        GGUF_GET_KEY(ctx, hparams.sparse_pred_threshold, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_SPARSE_THRESHOLD));
+    }
+
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
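Taken together with the hparams default above, the resolution order appears to be: GGUF key (when present) over the LLAMA_SPARSE_PRED_THRESHOLD environment variable over the built-in 0.0 — GGUF_GET_KEY is called with required=false, so it only overwrites the field when the key exists. A portable sketch of that order (the ?: above is a GNU extension, so this uses a plain ternary; the GGUF lookup here is mocked):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        /* default seeded from the environment, as in llama_hparams */
        const char *env = getenv("LLAMA_SPARSE_PRED_THRESHOLD");
        float threshold = env ? (float)atof(env) : 0.0f;

        /* mocked GGUF lookup: only overwrites when the KV is present */
        int   gguf_has_key = 1;     /* pretend the model file carries the key */
        float gguf_value   = 0.5f;  /* illustrative value */
        if (gguf_has_key) {
            threshold = gguf_value;
        }

        printf("effective sparse_pred_threshold = %.2f\n", threshold);
        return 0;
    }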
@@ -2607,6 +2626,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
     if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
     if (vocab.linefeed_id != -1)    { LLAMA_LOG_INFO( "%s: LF token  = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    // sparse inference
+    LLAMA_LOG_INFO("%s: sparse_pred_threshold = %.2f\n", __func__, hparams.sparse_pred_threshold);
 }
@@ -2808,7 +2830,7 @@ struct llama_augmentation_model_loader {
             return NULL;
         }
         // allocate and copy selected weights to gpu
 #ifdef GGML_USE_CUBLAS
         int64_t row_len = src->ne[0];
         int64_t gpu_rows = gpu_bucket->ne[0];
         if (gpu_rows == 0)
@@ -2841,10 +2863,9 @@ struct llama_augmentation_model_loader {
         ggml_set_no_alloc(aux_ctx, false);
         return gpu_dst;
 #else
-        printf("As you do not support CUDA. Split to GPU is not allowed.\n");
         return NULL;
 #endif
     }
 
     void slice_ffn_mat_to_gpu(llama_layer & layer) {
@@ -2882,22 +2903,11 @@ struct llama_augmentation_model_loader {
         const int64_t t_start_aug_us = ggml_time_us();
         std::vector<uint8_t> work_buffer;
 
-        // transpose ffn_down to use axpy
-        // ggml_cgraph * tmp_transpose_gf = ggml_new_graph(aux_ctx);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     // gpu_w2 transpose load
-        //     ggml_tensor * ffn_down_t = ggml_cont(aux_ctx, ggml_transpose(aux_ctx, model_layer.ffn_down));
-        //     ggml_build_forward_expand(tmp_transpose_gf, ffn_down_t);
-        //     model_layer.ffn_down_t = ffn_down_t;
-        //     LLAMA_LOG_INFO(".");
-        // }
-        // ggml_graph_compute_helper(work_buffer, tmp_transpose_gf, 2);
-        // for (llama_layer &model_layer : model -> layers) {
-        //     model_layer.ffn_down_t->op = GGML_OP_NONE;
-        //     model_layer.ffn_down_t->src[0] = NULL;
-        //     model_layer.ffn_down_t->src[1] = NULL;
-        //     model_layer.ffn_down_t->src[2] = NULL;
-        // }
+        // set sparsity threshold via global variables
+        sparse_pred_threshold = model->hparams.sparse_pred_threshold;
+#if defined (GGML_USE_CUBLAS)
+        ggml_cuda_set_device_constants(model->hparams.sparse_pred_threshold);
+#endif
+
         // load gpu_idx and slice mat to gpu
         for (llama_layer &model_layer : model -> layers) {
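Augmentation runs once at model load, before any graph is computed, so this is where both backends pick up the final value: the ggml.c kernels read the CPU global, and ggml_cuda_set_device_constants uploads the same number to dev_sparse_threshold. Both paths should therefore agree on the threshold from the first sparse matmul on.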

requirements.txt

@@ -1,3 +1,3 @@
 numpy==1.24.4
 sentencepiece==0.1.98
-gguf>=0.1.0
+-e ./gguf-py