Configurable sparse prediction threshold (#7)
* remove warning at gpu split * remove dead code * adaptive sparsity threshold reading from model file * convert models with sparse threshold
This commit is contained in:
parent
597ef34ba1
commit
603c771974
9 changed files with 96 additions and 41 deletions
33
convert.py
33
convert.py
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import faulthandler
|
import faulthandler
|
||||||
import functools
|
import functools
|
||||||
|
@ -138,6 +139,28 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
|
||||||
# hparams loading
|
# hparams loading
|
||||||
#
|
#
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PredictorParams:
|
||||||
|
sparse_threshold: float | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams:
|
||||||
|
config = json.load(open(config_path))
|
||||||
|
return PredictorParams(
|
||||||
|
sparse_threshold = config.get("sparse_threshold"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(model_plus: ModelPlus) -> PredictorParams:
|
||||||
|
config_path = model_plus.paths[0].parent / "config.json"
|
||||||
|
|
||||||
|
if config_path.exists():
|
||||||
|
params = PredictorParams.loadPredictorJson(model_plus.model, config_path)
|
||||||
|
else:
|
||||||
|
params = PredictorParams()
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params:
|
class Params:
|
||||||
n_vocab: int
|
n_vocab: int
|
||||||
|
@ -160,6 +183,9 @@ class Params:
|
||||||
# path to the directory containing the model files
|
# path to the directory containing the model files
|
||||||
path_model: Path | None = None
|
path_model: Path | None = None
|
||||||
|
|
||||||
|
# MLP predictor parameters
|
||||||
|
predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def guessed(model: LazyModel) -> Params:
|
def guessed(model: LazyModel) -> Params:
|
||||||
# try transformer naming first
|
# try transformer naming first
|
||||||
|
@ -843,6 +869,9 @@ class OutputFile:
|
||||||
if params.ftype is not None:
|
if params.ftype is not None:
|
||||||
self.gguf.add_file_type(params.ftype)
|
self.gguf.add_file_type(params.ftype)
|
||||||
|
|
||||||
|
if params.predictor_params.sparse_threshold is not None:
|
||||||
|
self.gguf.add_sparse_threshold(params.predictor_params.sparse_threshold)
|
||||||
|
|
||||||
def add_meta_vocab(self, vocab: Vocab) -> None:
|
def add_meta_vocab(self, vocab: Vocab) -> None:
|
||||||
tokens = []
|
tokens = []
|
||||||
scores = []
|
scores = []
|
||||||
|
@ -1181,10 +1210,13 @@ def main(args_in: list[str] | None = None) -> None:
|
||||||
|
|
||||||
if not args.vocab_only:
|
if not args.vocab_only:
|
||||||
model_plus = load_some_model(args.model)
|
model_plus = load_some_model(args.model)
|
||||||
|
params = Params.load(model_plus)
|
||||||
mlp_predictor_plus = load_mlp_model(args.mlp_model)
|
mlp_predictor_plus = load_mlp_model(args.mlp_model)
|
||||||
|
params.predictor_params = PredictorParams.load(mlp_predictor_plus)
|
||||||
model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
|
model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
|
||||||
else:
|
else:
|
||||||
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
|
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
|
||||||
|
params = Params.load(model_plus)
|
||||||
|
|
||||||
if args.dump:
|
if args.dump:
|
||||||
do_dump_model(model_plus)
|
do_dump_model(model_plus)
|
||||||
|
@ -1193,7 +1225,6 @@ def main(args_in: list[str] | None = None) -> None:
|
||||||
if args.bigendian:
|
if args.bigendian:
|
||||||
endianess = gguf.GGUFEndian.BIG
|
endianess = gguf.GGUFEndian.BIG
|
||||||
|
|
||||||
params = Params.load(model_plus)
|
|
||||||
if params.n_ctx == -1:
|
if params.n_ctx == -1:
|
||||||
if args.ctx is None:
|
if args.ctx is None:
|
||||||
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
|
raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
|
||||||
|
|
21
ggml-cuda.cu
21
ggml-cuda.cu
|
@ -108,6 +108,8 @@
|
||||||
// max batch size to use MMQ kernels when tensor cores are available
|
// max batch size to use MMQ kernels when tensor cores are available
|
||||||
#define MMQ_MAX_BATCH_SIZE 32
|
#define MMQ_MAX_BATCH_SIZE 32
|
||||||
|
|
||||||
|
__constant__ float dev_sparse_threshold;
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#define __CUDA_ARCH__ 1300
|
#define __CUDA_ARCH__ 1300
|
||||||
|
|
||||||
|
@ -4483,7 +4485,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse(const void * __restrict__
|
||||||
// printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
|
// printf("row in gpu %d cols %d, value %d %d %d\n", id, ncols, *d, *(d+1), *(d+4095));
|
||||||
// }
|
// }
|
||||||
// int id = row;
|
// int id = row;
|
||||||
if (idx[id] < 0.0f) {
|
if (idx[id] < dev_sparse_threshold) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4552,12 +4554,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int id = lst[row];
|
int id = lst[row];
|
||||||
// int id = row;
|
|
||||||
// if (idx[id] < 0.0f) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
const int bid = blockIdx.y;
|
const int bid = blockIdx.y;
|
||||||
// if (bid == 0) global_lock = 0;
|
|
||||||
|
|
||||||
extern __shared__ float shared_dst[]; // TODO:dynamic
|
extern __shared__ float shared_dst[]; // TODO:dynamic
|
||||||
|
|
||||||
|
@ -4578,7 +4575,7 @@ static __global__ void dequantize_mul_mat_axpy_sparse_batch(const void * __restr
|
||||||
// __syncthreads();
|
// __syncthreads();
|
||||||
for (int col_id = 0; col_id < src1_ncols; col_id++) {
|
for (int col_id = 0; col_id < src1_ncols; col_id++) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (loop_idx[id] < 0.0f) {
|
if (loop_idx[id] < dev_sparse_threshold) {
|
||||||
loop_dst += ncols;
|
loop_dst += ncols;
|
||||||
loop_idx += src1_ne0;
|
loop_idx += src1_ne0;
|
||||||
loop_y += src1_ne0;
|
loop_y += src1_ne0;
|
||||||
|
@ -4640,7 +4637,7 @@ static __global__ void dequantize_axpy_sparse(const void * __restrict__ vx, cons
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int id = lst[row];
|
int id = lst[row];
|
||||||
if (idx[id] < 0.0f) {
|
if (idx[id] < dev_sparse_threshold) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4689,8 +4686,7 @@ static __global__ void dequantize_mul_mat_vec_sparse(const void * __restrict__ v
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int id = lst[row];
|
int id = lst[row];
|
||||||
// int id = row;
|
if (idx[id] < dev_sparse_threshold) {
|
||||||
if (idx[id] < 0.0f) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4782,7 +4778,7 @@ static __global__ void dequantize_mul_mat_batch_sparse(const void * __restrict__
|
||||||
{
|
{
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
tmp = 0.0f;
|
tmp = 0.0f;
|
||||||
if (loop_idx[id] < 0.0f)
|
if (loop_idx[id] < dev_sparse_threshold)
|
||||||
{
|
{
|
||||||
loop_dst += dst_ne0;
|
loop_dst += dst_ne0;
|
||||||
loop_idx += dst_ne0;
|
loop_idx += dst_ne0;
|
||||||
|
@ -9618,3 +9614,6 @@ ggml_backend_t ggml_backend_cuda_init() {
|
||||||
return cuda_backend;
|
return cuda_backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_set_device_constants(float sparse_pred_threshold) {
|
||||||
|
CUDA_CHECK(cudaMemcpyToSymbol(dev_sparse_threshold, &sparse_pred_threshold, sizeof(float)));
|
||||||
|
}
|
||||||
|
|
|
@ -53,6 +53,8 @@ GGML_API int ggml_cuda_get_device_count(void);
|
||||||
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
||||||
GGML_API size_t ggml_cuda_get_free_memory(int device);
|
GGML_API size_t ggml_cuda_get_free_memory(int device);
|
||||||
|
|
||||||
|
GGML_API void ggml_cuda_set_device_constants(float sparse_pred_threshold);
|
||||||
|
|
||||||
// backend API
|
// backend API
|
||||||
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
|
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
|
||||||
|
|
||||||
|
|
17
ggml.c
17
ggml.c
|
@ -14059,6 +14059,8 @@ static void ggml_compute_forward_mul_mat_sparse(
|
||||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||||
|
|
||||||
|
const float threshold = sparse_pred_threshold;
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == ne01);
|
GGML_ASSERT(ne0 == ne01);
|
||||||
GGML_ASSERT(ne1 == ne11);
|
GGML_ASSERT(ne1 == ne11);
|
||||||
GGML_ASSERT(ne2 == ne12);
|
GGML_ASSERT(ne2 == ne12);
|
||||||
|
@ -14262,7 +14264,7 @@ static void ggml_compute_forward_mul_mat_sparse(
|
||||||
float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||||
|
|
||||||
// if (ffdata[ir0] <= 0.0f) {
|
// if (ffdata[ir0] <= 0.0f) {
|
||||||
if (gid[ir0] == 1 || ffdata[ir0] < -0.0f) {
|
if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
|
||||||
dst_col[ir0] = 0;
|
dst_col[ir0] = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -14413,11 +14415,6 @@ static void ggml_compute_forward_mul_mat_axpy_dense(
|
||||||
const int ir0 = atomic_fetch_add(params->aic, dr);
|
const int ir0 = atomic_fetch_add(params->aic, dr);
|
||||||
for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
|
for (int64_t ir1 = ir0; ir1 < ir0+dr; ir1++) {
|
||||||
if (ir1 >= nr) break;
|
if (ir1 >= nr) break;
|
||||||
// if (gid[ir1] == 1)
|
|
||||||
// continue;
|
|
||||||
// if (idx[ir1] < 0.0f)
|
|
||||||
// continue;
|
|
||||||
// ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
|
|
||||||
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
|
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, wdata[ir1]);
|
||||||
}
|
}
|
||||||
if (ir0 + dr >= nr)
|
if (ir0 + dr >= nr)
|
||||||
|
@ -14482,6 +14479,8 @@ static void ggml_compute_forward_mul_mat_axpy(
|
||||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||||
|
|
||||||
|
const float threshold = sparse_pred_threshold;
|
||||||
|
|
||||||
// GGML_ASSERT(ne0 == ne01);
|
// GGML_ASSERT(ne0 == ne01);
|
||||||
// GGML_ASSERT(ne1 == ne11);
|
// GGML_ASSERT(ne1 == ne11);
|
||||||
// GGML_ASSERT(ne2 == ne12);
|
// GGML_ASSERT(ne2 == ne12);
|
||||||
|
@ -14569,7 +14568,7 @@ static void ggml_compute_forward_mul_mat_axpy(
|
||||||
if (gid[ir1] == 1) {
|
if (gid[ir1] == 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (idx[ir1] < -0.0f)
|
if (idx[ir1] < threshold)
|
||||||
continue;
|
continue;
|
||||||
// ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
|
// ggml_axpy_normal_f16(ne00, src0_row+nb01*ir1, vy, vy, wdata[ir1]);
|
||||||
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
|
ggml_axpy_avx_f16(ne00, (ggml_fp16_t *)(src0_row+nb01*ir1), (ggml_fp16_t *)vy, vy, src1_ptr[ir1]);
|
||||||
|
@ -14632,6 +14631,8 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
|
||||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||||
|
|
||||||
|
const float threshold = sparse_pred_threshold;
|
||||||
|
|
||||||
// GGML_ASSERT(ne0 == ne01);
|
// GGML_ASSERT(ne0 == ne01);
|
||||||
// GGML_ASSERT(ne1 == ne11);
|
// GGML_ASSERT(ne1 == ne11);
|
||||||
// GGML_ASSERT(ne2 == ne12);
|
// GGML_ASSERT(ne2 == ne12);
|
||||||
|
@ -14713,7 +14714,7 @@ static void ggml_compute_forward_mul_mat_axpy_q4_0(
|
||||||
break;
|
break;
|
||||||
if (gid[ir1] == 1)
|
if (gid[ir1] == 1)
|
||||||
continue;
|
continue;
|
||||||
if (idx[ir1] < 0.0f)
|
if (idx[ir1] < threshold)
|
||||||
continue;
|
continue;
|
||||||
int bid = ir1 / QK8_0;
|
int bid = ir1 / QK8_0;
|
||||||
int qsid = ir1 % QK8_0;
|
int qsid = ir1 % QK8_0;
|
||||||
|
|
6
ggml.h
6
ggml.h
|
@ -2196,6 +2196,12 @@ extern "C" {
|
||||||
GGML_API int ggml_cpu_has_ssse3 (void);
|
GGML_API int ggml_cpu_has_ssse3 (void);
|
||||||
GGML_API int ggml_cpu_has_vsx (void);
|
GGML_API int ggml_cpu_has_vsx (void);
|
||||||
|
|
||||||
|
//
|
||||||
|
// global variables
|
||||||
|
//
|
||||||
|
// TODO: these should be moved to the context
|
||||||
|
extern float sparse_pred_threshold;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Internal types and functions exposed for tests and benchmarks
|
// Internal types and functions exposed for tests and benchmarks
|
||||||
//
|
//
|
||||||
|
|
|
@ -71,6 +71,9 @@ class Keys:
|
||||||
HF_JSON = "tokenizer.huggingface.json"
|
HF_JSON = "tokenizer.huggingface.json"
|
||||||
RWKV = "tokenizer.rwkv.world"
|
RWKV = "tokenizer.rwkv.world"
|
||||||
|
|
||||||
|
class PowerInfer:
|
||||||
|
SPARSE_THRESHOLD = "powerinfer.sparse_threshold"
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# recommended mapping of model tensor names for storage in gguf
|
# recommended mapping of model tensor names for storage in gguf
|
||||||
|
|
|
@ -399,6 +399,9 @@ class GGUFWriter:
|
||||||
def add_add_eos_token(self, value: bool) -> None:
|
def add_add_eos_token(self, value: bool) -> None:
|
||||||
self.add_bool(Keys.Tokenizer.ADD_EOS, value)
|
self.add_bool(Keys.Tokenizer.ADD_EOS, value)
|
||||||
|
|
||||||
|
def add_sparse_threshold(self, value: float) -> None:
|
||||||
|
self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value)
|
||||||
|
|
||||||
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
||||||
pack_prefix = ''
|
pack_prefix = ''
|
||||||
if not skip_pack_prefix:
|
if not skip_pack_prefix:
|
||||||
|
|
50
llama.cpp
50
llama.cpp
|
@ -93,6 +93,13 @@
|
||||||
|
|
||||||
#define LLAMA_MAX_NODES 4096
|
#define LLAMA_MAX_NODES 4096
|
||||||
|
|
||||||
|
//
|
||||||
|
// global variables
|
||||||
|
//
|
||||||
|
|
||||||
|
// sparsity threshold for sparse matrix multiplication prediction
|
||||||
|
float sparse_pred_threshold = 0.;
|
||||||
|
|
||||||
//
|
//
|
||||||
// logging
|
// logging
|
||||||
//
|
//
|
||||||
|
@ -257,6 +264,8 @@ enum llm_kv {
|
||||||
LLM_KV_TOKENIZER_PAD_ID,
|
LLM_KV_TOKENIZER_PAD_ID,
|
||||||
LLM_KV_TOKENIZER_HF_JSON,
|
LLM_KV_TOKENIZER_HF_JSON,
|
||||||
LLM_KV_TOKENIZER_RWKV,
|
LLM_KV_TOKENIZER_RWKV,
|
||||||
|
|
||||||
|
LLM_KV_SPARSE_THRESHOLD,
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::map<llm_kv, std::string> LLM_KV_NAMES = {
|
static std::map<llm_kv, std::string> LLM_KV_NAMES = {
|
||||||
|
@ -305,6 +314,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
|
||||||
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
||||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||||
|
|
||||||
|
{ LLM_KV_SPARSE_THRESHOLD, "powerinfer.sparse_threshold" },
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LLM_KV {
|
struct LLM_KV {
|
||||||
|
@ -1151,6 +1162,9 @@ struct llama_hparams {
|
||||||
float f_clamp_kqv;
|
float f_clamp_kqv;
|
||||||
float f_max_alibi_bias;
|
float f_max_alibi_bias;
|
||||||
|
|
||||||
|
// sparse predictor threshold if sparse inference is enabled
|
||||||
|
float sparse_pred_threshold = atof(getenv("LLAMA_SPARSE_PRED_THRESHOLD") ?: "0.0");
|
||||||
|
|
||||||
bool operator!=(const llama_hparams & other) const {
|
bool operator!=(const llama_hparams & other) const {
|
||||||
if (this->vocab_only != other.vocab_only) return true;
|
if (this->vocab_only != other.vocab_only) return true;
|
||||||
if (this->n_vocab != other.n_vocab) return true;
|
if (this->n_vocab != other.n_vocab) return true;
|
||||||
|
@ -2220,6 +2234,11 @@ static void llm_load_hparams(
|
||||||
// gpt-j n_rot = rotary_dim
|
// gpt-j n_rot = rotary_dim
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (gguf_get_sparse_deriv(ctx)) {
|
||||||
|
// read sparse threshold override if sparse deriv is enabled
|
||||||
|
GGUF_GET_KEY(ctx, hparams.sparse_pred_threshold, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_SPARSE_THRESHOLD));
|
||||||
|
}
|
||||||
|
|
||||||
// arch-specific KVs
|
// arch-specific KVs
|
||||||
switch (model.arch) {
|
switch (model.arch) {
|
||||||
case LLM_ARCH_LLAMA:
|
case LLM_ARCH_LLAMA:
|
||||||
|
@ -2607,6 +2626,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||||
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
|
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
|
||||||
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
|
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
|
||||||
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
|
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
|
||||||
|
|
||||||
|
// sparse inference
|
||||||
|
LLAMA_LOG_INFO("%s: sparse_pred_threshold = %.2f\n", __func__, hparams.sparse_pred_threshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2808,7 +2830,7 @@ struct llama_augmentation_model_loader {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
// allocate and copy selected weights to gpu
|
// allocate and copy selected weights to gpu
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
int64_t row_len = src->ne[0];
|
int64_t row_len = src->ne[0];
|
||||||
int64_t gpu_rows = gpu_bucket->ne[0];
|
int64_t gpu_rows = gpu_bucket->ne[0];
|
||||||
if (gpu_rows == 0)
|
if (gpu_rows == 0)
|
||||||
|
@ -2841,10 +2863,9 @@ struct llama_augmentation_model_loader {
|
||||||
ggml_set_no_alloc(aux_ctx, false);
|
ggml_set_no_alloc(aux_ctx, false);
|
||||||
|
|
||||||
return gpu_dst;
|
return gpu_dst;
|
||||||
#else
|
#else
|
||||||
printf("As you do not support CUDA. Split to GPU is not allowed.\n");
|
|
||||||
return NULL;
|
return NULL;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void slice_ffn_mat_to_gpu(llama_layer & layer) {
|
void slice_ffn_mat_to_gpu(llama_layer & layer) {
|
||||||
|
@ -2882,22 +2903,11 @@ struct llama_augmentation_model_loader {
|
||||||
const int64_t t_start_aug_us = ggml_time_us();
|
const int64_t t_start_aug_us = ggml_time_us();
|
||||||
std::vector<uint8_t> work_buffer;
|
std::vector<uint8_t> work_buffer;
|
||||||
|
|
||||||
// transpose ffn_down to use axpy
|
// Set sparsity threshold via global virables
|
||||||
// ggml_cgraph * tmp_transpose_gf = ggml_new_graph(aux_ctx);
|
sparse_pred_threshold = model->hparams.sparse_pred_threshold;
|
||||||
// for (llama_layer &model_layer : model -> layers) {
|
#if defined (GGML_USE_CUBLAS)
|
||||||
// // gpu_w2 transpose load
|
ggml_cuda_set_device_constants(model->hparams.sparse_pred_threshold);
|
||||||
// ggml_tensor * ffn_down_t = ggml_cont(aux_ctx, ggml_transpose(aux_ctx, model_layer.ffn_down));
|
#endif
|
||||||
// ggml_build_forward_expand(tmp_transpose_gf, ffn_down_t);
|
|
||||||
// model_layer.ffn_down_t = ffn_down_t;
|
|
||||||
// LLAMA_LOG_INFO(".");
|
|
||||||
// }
|
|
||||||
// ggml_graph_compute_helper(work_buffer, tmp_transpose_gf, 2);
|
|
||||||
// for (llama_layer &model_layer : model -> layers) {
|
|
||||||
// model_layer.ffn_down_t->op = GGML_OP_NONE;
|
|
||||||
// model_layer.ffn_down_t->src[0] = NULL;
|
|
||||||
// model_layer.ffn_down_t->src[1] = NULL;
|
|
||||||
// model_layer.ffn_down_t->src[2] = NULL;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// load gpu_idx and slice mat to gpu
|
// load gpu_idx and slice mat to gpu
|
||||||
for (llama_layer &model_layer : model -> layers) {
|
for (llama_layer &model_layer : model -> layers) {
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
numpy==1.24.4
|
numpy==1.24.4
|
||||||
sentencepiece==0.1.98
|
sentencepiece==0.1.98
|
||||||
gguf>=0.1.0
|
-e ./gguf-py
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue