Merge branch 'master' into concedo_experimental
# Conflicts:
#	.github/workflows/build.yml
#	scripts/build-info.sh
commit 6a054b80b0
12 changed files with 753 additions and 146 deletions
convert.py (37 changes)
@@ -234,13 +234,20 @@ class Params:
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+        self.vocabtype = vocabtype
+        if self.vocabtype == "bpe":
+          self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
+        else:
+          self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
         if fname_added_tokens is not None:
             added_tokens = json.load(open(fname_added_tokens))
         else:
             added_tokens = {}
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+        if self.vocabtype == "bpe":
+          vocab_size: int = len(self.sentencepiece_tokenizer)
+        else:
+          vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
         actual_ids = sorted(added_tokens.values())
@@ -255,6 +262,16 @@ class SentencePieceVocab:
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
-        for i in range(tokenizer.vocab_size()):
-            text: bytes
-            if tokenizer.is_unknown(i):
+        if self.vocabtype == "bpe":
+          from transformers.models.gpt2 import tokenization_gpt2
+          byte_encoder = tokenization_gpt2.bytes_to_unicode()
+          byte_decoder = {v: k for k, v in byte_encoder.items()}
+          for i, item in enumerate(tokenizer):
+            text: bytes
+            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+            score: float = -i
+            yield text, score
+        else:
+          for i in range(tokenizer.vocab_size()):
+            text: bytes
+            if tokenizer.is_unknown(i):
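Aside: the BPE branch above recovers raw bytes from GPT-2-style token strings, where every byte is mapped to a printable unicode character. A minimal Python sketch of that round trip (illustrative, not part of the patch; bytes_to_unicode() is reimplemented here so the example runs without transformers installed):

    def bytes_to_unicode():
        # printable bytes keep their own code point; the rest are remapped above 255
        bs = list(range(ord("!"), ord("~") + 1)) + \
             list(range(ord("\u00a1"), ord("\u00ac") + 1)) + \
             list(range(ord("\u00ae"), ord("\u00ff") + 1))
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, [chr(c) for c in cs]))

    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}

    token = "\u0120hello"  # 'Ġhello': BPE encodes a leading space (byte 0x20) as 'Ġ'
    raw = b''.join(byte_decoder[ch].to_bytes(1, byteorder='big') for ch in token)
    assert raw == b' hello'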
@@ -1196,14 +1213,18 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
     return {name: model[name] for name in TENSORS_LIST if name in model}
 
 
-def load_vocab(path: Path) -> SentencePieceVocab:
+def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
+    print(f"vocabtype: {vocabtype}")
     # Be extra-friendly and accept either a file or a directory.  Also, if it's
     # a directory, it might be the model directory, and tokenizer.model might
     # be in the parent of that.
     if path.is_dir():
-        path2 = path / "tokenizer.model"
+        vocab_file = "tokenizer.model"
+        if vocabtype == 'bpe':
+          vocab_file = "vocab.json"
+        path2 = path / vocab_file
         # Use `.parent` instead of /.. to handle the symlink case better.
-        path3 = path.parent / "tokenizer.model"
+        path3 = path.parent / vocab_file
         if path2.exists():
             path = path2
         elif path3.exists():
@@ -1214,7 +1235,8 @@ def load_vocab(path: Path) -> SentencePieceVocab:
                 "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+                              vocabtype)
 
 
 def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
@@ -1252,6 +1274,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path,
                         help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
+    parser.add_argument("--vocabtype", default='spm', choices=["spm", "bpe"], help="vocab format (default: spm)")
     args = parser.parse_args(args_in)
 
     vocab: Vocab
@@ -1259,7 +1282,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
     elif args.vocab_only:
-        vocab = load_vocab(args.vocab_dir or args.model)
+        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
         assert args.outfile, "need --outfile if using --vocab-only"
         outfile = args.outfile
         OutputFile.write_vocab_only(outfile, vocab)
@@ -1273,7 +1296,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
             vocab = model_plus.vocab
         else:
             vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-            vocab = load_vocab(vocab_dir)
+            vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
         model = model_plus.model
         model = do_necessary_conversions(model, params)
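Aside: with the pieces above in place, a BPE vocab can be exported on its own via the new flag. A hedged usage sketch (the model directory is hypothetical; equivalent to running convert.py from the command line):

    # python convert.py models/my-bpe-model --vocab-only --vocabtype bpe --outfile ggml-vocab.bin
    from convert import main
    main(["models/my-bpe-model", "--vocab-only", "--vocabtype", "bpe",
          "--outfile", "ggml-vocab.bin"])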

examples/common.cpp
@@ -432,6 +432,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             exit(0);
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
+        } else if (arg == "--in-prefix-bos") {
+            params.input_prefix_bos = true;
         } else if (arg == "--in-prefix") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -517,6 +519,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        not supported with --interactive or other interactive options\n");
     fprintf(stdout, "  --prompt-cache-ro     if specified, uses the prompt cache but does not update it.\n");
     fprintf(stdout, "  --random-prompt       start with a randomized prompt.\n");
+    fprintf(stdout, "  --in-prefix-bos       prefix BOS to user inputs, preceding the `--in-prefix` string\n");
     fprintf(stdout, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stdout, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");
     fprintf(stdout, "  -f FNAME, --file FNAME\n");

examples/common.h
@@ -82,6 +82,7 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
     bool multiline_input = false; // reverse the usage of `\`
 
+    bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool penalize_nl = true; // consider newlines as a repeatable token
     bool perplexity = false; // compute perplexity over the prompt
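Aside: the new flag simply prepends the BOS token to each user turn, before the optional --in-prefix string. A minimal Python model of the resulting token stream (token ids illustrative; not part of the patch):

    BOS = 1  # llama's BOS token id

    def build_user_turn(user_tokens, prefix_tokens=(), prefix_bos=True):
        out = []
        if prefix_bos:             # mirrors: embd_inp.push_back(llama_token_bos())
            out.append(BOS)
        out.extend(prefix_tokens)  # tokens from --in-prefix
        out.extend(user_tokens)
        return out

    assert build_user_turn([42, 43], prefix_tokens=[7]) == [1, 7, 42, 43]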

examples/main/main.cpp
@@ -325,6 +325,10 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (params.input_prefix_bos) {
+        fprintf(stderr, "Input prefix with BOS\n");
+    }
+
     if (!params.input_prefix.empty()) {
         fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
     }
@@ -633,16 +637,6 @@ int main(int argc, char ** argv) {
                 last_n_tokens.push_back(id);
             }
 
-            // replace end of text token with newline token when in interactive mode
-            if (id == llama_token_eos() && params.interactive && !params.instruct) {
-                id = llama_token_newline.front();
-                if (params.antiprompt.size() != 0) {
-                    // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
-                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
-                }
-            }
-
             // add it to the context
             embd.push_back(id);
 
@@ -708,11 +702,34 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            // deal with end of text token in interactive mode
+            if (last_n_tokens.back() == llama_token_eos()) {
+                if (params.interactive) {
+                    if (params.antiprompt.size() != 0) {
+                        // tokenize and inject first reverse prompt
+                        const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                        is_antiprompt = true;
+                    }
+
+                    is_interacting = true;
+                    printf("\n");
+                    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                    fflush(stdout);
+                } else if (params.instruct) {
+                    is_interacting = true;
+                }
+            }
+
             if (n_past > 0 && is_interacting) {
                 if (params.instruct) {
                     printf("\n> ");
                 }
 
+                if (params.input_prefix_bos) {
+                    embd_inp.push_back(llama_token_bos());
+                }
+
                 std::string buffer;
                 if (!params.input_prefix.empty()) {
                     buffer += params.input_prefix;
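Aside: the net behavioral change of the two hunks above and the one below is that EOS is no longer rewritten into a newline token; interactive mode now hands control back to the user, optionally injecting the first reverse prompt. A condensed Python model of the new decision (not part of the patch):

    def on_eos(interactive, instruct, antiprompts):
        # what the block above decides when the last sampled token is EOS
        if interactive:
            inject = antiprompts[0] if antiprompts else None
            return {"is_interacting": True, "inject": inject}
        if instruct:
            return {"is_interacting": True, "inject": None}
        return {"stop": True}  # non-interactive: print " [end of text]" and break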
|
@@ -776,14 +793,10 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos()) {
-            if (params.instruct) {
-                is_interacting = true;
-            } else {
-                fprintf(stderr, " [end of text]\n");
-                break;
-            }
+        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+            fprintf(stderr, " [end of text]\n");
+            break;
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
         if (params.interactive && n_remain <= 0 && params.n_predict != -1) {

ggml-cuda.cu (83 changes)
@@ -1564,12 +1564,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
-
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
 
+#ifndef GGML_QKK_64
+
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
     const float d = bq4_K->d;
     const float dmin = bq4_K->dmin;
 
@@ -1614,6 +1616,43 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     }
 
     return d*sumf_d - dmin*sumf_m;
 
+#else
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
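Aside: __dp4a(a, b, c) is a per-lane dot product of four packed 8-bit values accumulated into c; the 0x0f0f0f0f masks above select the low nibble of every byte so four quants are multiplied per instruction, and dotting against 0x01010101 sums the four q8 bytes. A Python model of the signed-lane semantics (illustrative, not part of the patch):

    def dp4a(a, b, c):
        def lanes(x):
            out = []
            for i in range(4):
                v = (x >> (8 * i)) & 0xFF
                out.append(v - 256 if v >= 128 else v)  # signed 8-bit lanes
            return out
        return c + sum(p * q for p, q in zip(lanes(a), lanes(b)))

    # the dot3/dot4 trick: dp4a against 0x01010101 sums the four bytes
    assert dp4a(0x01010101, 0x04030201, 0) == 1 + 2 + 3 + 4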
@@ -1625,6 +1664,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
 
+#ifndef GGML_QKK_64
+
     const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
     const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
     const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
@@ -1680,6 +1721,42 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     }
 
     return d*sumf_d - dmin*sumf_m;
 
+#else
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * iqs; // 0, 4, 8, 12
+    const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int in = step%8; // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
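Aside: the v1..v4 expressions pack the fifth quant bit from qh on top of each 4-bit lane from ql; per byte lane, XOR-ing with 0x10 means 16 is added exactly when the qh bit is clear. A one-lane Python model of that bit-twiddling (illustrative, not part of the patch):

    def q5_lane(vl_nibble, qh_bit):
        # mirrors (((vh << 4) & 0x10) ^ 0x10) | (vl & 0x0f) for a single byte lane
        high = 0x10 if qh_bit == 0 else 0x00
        return high | (vl_nibble & 0x0F)

    assert q5_lane(0x3, qh_bit=1) == 0x03
    assert q5_lane(0x3, qh_bit=0) == 0x13  # +16 when the high bit is clear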

ggml-metal.h
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);

ggml-metal.m (145 changes)
@@ -36,6 +36,9 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
@@ -98,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes;   i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
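Aside: the reordering-safety test in the function above reduces to interval disjointness on output buffers — a pending node j may be hoisted past node i only if their [data, data + nbytes) ranges do not overlap. A minimal Python model (illustrative, not part of the patch):

    def overlaps(start_a, len_a, start_b, len_b):
        # mirrors the branch above: disjoint half-open intervals => safe to reorder
        return not (start_b >= start_a + len_a or start_b + len_b <= start_a)

    assert overlaps(0, 16, 8, 16)       # shared bytes: must keep program order
    assert not overlaps(0, 16, 16, 16)  # back-to-back buffers are safe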
@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -390,9 +488,20 @@ void ggml_metal_graph_compute(
             id<MTLComputeCommandEncoder> encoder = nil;
 
             const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
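Aside: the encoding loop above now walks ctx->concur_list instead of raw node order; a -1 entry is a layer barrier, turned into either a fresh concurrent encoder or a buffer-scope memory barrier. A sketch of how such a list decodes into concurrently-dispatchable layers (indices illustrative, not part of the patch):

    concur_list = [0, 2, 1, -1, 3, 4, -1]  # two layers separated by barriers
    layers, cur = [], []
    for idx in concur_list:
        if idx == -1:       # barrier between layers
            layers.append(cur)
            cur = []
        else:
            cur.append(idx)
    assert layers == [[0, 2, 1], [3, 4]]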
@@ -463,7 +572,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ADD:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         if (ggml_nelements(src1) == ne10) {

@@ -484,7 +593,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         if (ggml_nelements(src1) == ne10) {

@@ -505,7 +614,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_SCALE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float scale = *(const float *) src1->data;

@@ -524,7 +633,7 @@ void ggml_metal_graph_compute(
                         case GGML_UNARY_OP_SILU:
                             {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 [encoder setComputePipelineState:ctx->pipeline_silu];

@@ -538,7 +647,7 @@ void ggml_metal_graph_compute(
                         case GGML_UNARY_OP_RELU:
                             {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 [encoder setComputePipelineState:ctx->pipeline_relu];

@@ -552,7 +661,7 @@ void ggml_metal_graph_compute(
                         case GGML_UNARY_OP_GELU:
                             {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 [encoder setComputePipelineState:ctx->pipeline_gelu];

@@ -572,7 +681,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_SOFT_MAX:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int nth = 32;

@@ -590,7 +699,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_DIAG_MASK_INF:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int n_past = ((int32_t *)(dst->op_params))[0];

@@ -653,7 +762,7 @@ void ggml_metal_graph_compute(
                             }
                         } else {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             int nth0 = 32;

@@ -780,7 +889,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_GET_ROWS:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         switch (src0->type) {

@@ -809,7 +918,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_RMS_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         float eps;

@@ -832,7 +941,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float eps = 1e-5f;

@@ -854,7 +963,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ALIBI:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         GGML_ASSERT((src0t == GGML_TYPE_F32));

@@ -897,7 +1006,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ROPE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int n_past = ((int32_t *) dst->op_params)[0];

@@ -941,7 +1050,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_CONT:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int nth = 32;

ggml-metal.metal (113 changes)
@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
     }
 }
 
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }
 
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    return d * (acc[0] + acc[1]) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      guard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                      int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                      uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float * y = (device const float *) src1 + r1*ne10;
-    float4 y_curr[8]; // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16]; // src1 vector cache
+    float sumf[nr]={0.f};
 
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
+
+    device const float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }
 
-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
 
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-
-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
-        }
+        yb += QK4_0 * 16;
     }
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
     }
 }
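Aside: the rewritten kernel folds the nibble bit-shifts into the cached y values — dividing them by 16, 256 and 4096 up front lets the inner loop multiply against masked 16-bit words directly, with no per-element shifts. A numeric check of the identity (values illustrative, not part of the patch):

    qs = 0xABCD                  # one 16-bit word = four packed 4-bit quants
    y = [1.5, -2.0, 0.25, 3.0]   # activations paired with the four nibbles

    direct = (y[0] * ((qs >>  0) & 0xF) + y[1] * ((qs >>  8) & 0xF) +
              y[2] * ((qs >>  4) & 0xF) + y[3] * ((qs >> 12) & 0xF))

    folded = (y[0]          * (qs & 0x000F) + (y[1] / 256.0)  * (qs & 0x0F00) +
              (y[2] / 16.0) * (qs & 0x00F0) + (y[3] / 4096.0) * (qs & 0xF000))

    assert abs(direct - folded) < 1e-9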
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_q4_1_f32(

@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(

ggml.c (86 changes)
@@ -4230,6 +4230,15 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -7022,14 +7031,16 @@ struct ggml_tensor * ggml_flash_attn(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
+
+    int32_t t = masked ? 1 : 0;
+    ggml_set_op_params(result, &t, sizeof(t));
 
     result->op = GGML_OP_FLASH_ATTN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
-    result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
@@ -7053,7 +7064,7 @@ struct ggml_tensor * ggml_flash_ff(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);
 
     result->op = GGML_OP_FLASH_FF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7119,13 +7130,15 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
+    int32_t masked_i = masked ? 1 : 0;
+    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
     result->op = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
     result->src[3] = d;
-    result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
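Aside: the pattern in these hunks is that a scalar flag which used to ride along as an extra i32 source tensor now lives in the op's fixed op_params bytes, so no tensor allocation is needed for it. A stand-in Python model (not ggml's real structs; illustrative only):

    from dataclasses import dataclass, field

    @dataclass
    class Tensor:                           # minimal stand-in for struct ggml_tensor
        op_params: list = field(default_factory=lambda: [0] * 8)

    def set_op_params_i32(t, i, v): t.op_params[i] = v
    def get_op_params_i32(t, i):    return t.op_params[i]

    result = Tensor()
    set_op_params_i32(result, 0, 1)         # masked ? 1 : 0
    masked = get_op_params_i32(result, 0) != 0
    assert masked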
@@ -9815,8 +9828,8 @@ static void ggml_compute_forward_gelu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {

@@ -9874,8 +9887,8 @@ static void ggml_compute_forward_gelu_quick_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {

@@ -9933,8 +9946,8 @@ static void ggml_compute_forward_silu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {

@@ -9993,9 +10006,9 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * grad,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(grad));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
@@ -14760,7 +14773,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                const int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@@ -14771,7 +14784,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
            {
-                int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
@@ -15389,7 +15402,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         {
             struct ggml_tensor * flash_grad = NULL;
             if (src0->grad || src1->grad || tensor->src[2]->grad) {
-                int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 bool masked = t != 0;
                 flash_grad =
@@ -15661,6 +15674,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }
 
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
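Aside: this replaces a linear scan over nodes and leafs per visit (quadratic over the whole graph) with an open-addressed set keyed on the tensor pointer. A compact Python model of hash_insert (illustrative, not part of the patch; table size as declared in ggml.h below):

    GGML_GRAPH_HASHTABLE_SIZE = 8273  # prime > GGML_MAX_NODES * 2, keeps probe chains short

    def hash_insert(table, p):
        i = h = id(p) % GGML_GRAPH_HASHTABLE_SIZE    # the C code hashes the pointer value
        while table[i] is not None and table[i] is not p:
            i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE  # linear probing
            assert i != h, "hash table is full"
        if table[i] is p:
            return True    # already visited
        table[i] = p
        return False       # inserted now

    table = [None] * GGML_GRAPH_HASHTABLE_SIZE
    node = object()
    assert hash_insert(table, node) is False
    assert hash_insert(table, node) is True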
@@ -15671,17 +15712,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
-    }
+    if (hash_insert(cgraph->visited_hash_table, node)) {
+        return;
+    }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         if (node->src[i]) {
@@ -15743,6 +15776,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
        /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -15784,7 +15818,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_impl(&result, node->grad, true);
+            ggml_build_forward_expand(&result, node->grad);
         }
     }
 
9
ggml.h
@@ -441,7 +441,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[8];
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -462,6 +462,11 @@ extern "C" {
         void * abort_callback_data;
     };
 
+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
@@ -471,6 +476,8 @@ extern "C" {
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];
 
+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
        // performance
        int perf_runs;
        int64_t perf_cycles;
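One practical consequence of the new field, worth noting: the table is stored inline, so every struct ggml_cgraph grows by 8273 pointer slots. A quick back-of-the-envelope check (sketch, assuming a 64-bit target):

    #include <stdio.h>

    #define GGML_GRAPH_HASHTABLE_SIZE 8273

    int main(void) {
        // 8273 * 8 = 66184 bytes, roughly 65 KiB per graph on 64-bit
        printf("visited_hash_table: %zu bytes\n",
               GGML_GRAPH_HASHTABLE_SIZE * sizeof(void *));
        return 0;
    }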
325
k_quants.c
@@ -1666,6 +1666,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
 #else
 
     float sumf = 0;
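The shift-and-mask ladder above (q2_0 .. q2_3) is the SIMD form of a simple byte-wise unpack: each byte of q2 carries four 2-bit quants at bit offsets 0, 2, 4 and 6. A scalar sketch of the same decode:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t byte = 0xE4;                     // bits 11 10 01 00 -> quants 3,2,1,0
        for (int shift = 0; shift <= 6; shift += 2) {
            printf("%d ", (byte >> shift) & 3);  // the m3 = 0x03 mask
        }
        printf("\n");                            // prints: 0 1 2 3
        return 0;
    }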
@@ -2295,6 +2351,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = _mm256_set_m128i(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
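The andnot/subtract dance above is easier to see in scalar form: q3h is 4 when the high mask bit is clear and 0 when it is set, so subtracting q8s from p16 recenters the unsigned low bits into the signed range [-4, 3]. A tiny runnable check of the identity (sketch):

    #include <assert.h>

    int main(void) {
        for (int ql = 0; ql < 4; ++ql)           // low 2 bits
            for (int h = 0; h <= 1; ++h)         // high mask bit
                for (int q8 = -128; q8 < 128; ++q8)
                    // maddubs(q3l,q8) - maddubs(q3h,q8) == signed quant * q8
                    assert(ql*q8 - 4*(1 - h)*q8 == (ql + 4*h - 4)*q8);
        return 0;
    }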
@@ -2781,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) - summs;
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
 #else
 
     uint8_t aux8[QK_K];
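The summs term above relies on the usual K-quant shortcut: subtracting a per-block minimum from every weight folds into min * sum(q8), and those partial sums are precomputed in y[i].bsums. A scalar sketch of why the correction is exact:

    #include <assert.h>

    int main(void) {
        const int q8[8] = { 3, -1, 4, 1, -5, 9, -2, 6 };
        const int q4[8] = { 0, 15, 7, 2, 11, 1, 8, 5 };
        const int m = 3;                    // per-block minimum
        int direct = 0, fused = 0, bsum = 0;
        for (int j = 0; j < 8; ++j) {
            direct += (q4[j] - m) * q8[j];  // what the dot product means
            fused  += q4[j] * q8[j];        // what maddubs accumulates
            bsum   += q8[j];                // what bsums precomputes
        }
        assert(direct == fused - m * bsum); // the summs correction
        return 0;
    }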
@@ -3295,6 +3492,63 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
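Same centering trick as in the q3_K path, one nibble higher: q5h is 16 when the packed high bit is clear, so p16 - s16 shifts the quant into [-16, 15]. A runnable identity check (sketch):

    #include <assert.h>

    int main(void) {
        for (int ql = 0; ql < 16; ++ql)          // low 4 bits
            for (int h = 0; h <= 1; ++h)         // packed high bit
                for (int q8 = -128; q8 < 128; ++q8)
                    assert(ql*q8 - 16*(1 - h)*q8 == (ql + 16*h - 16)*q8);
        return 0;
    }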
@@ -3857,6 +4111,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
llama.cpp
@@ -1721,12 +1721,17 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);
 
+    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
+
 #if GGML_USE_MPI
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
 #endif
 
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
+            ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
+        }
         ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor (lctx.ctx_metal, cur);
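The Metal change is a do-once guard: the concurrency analysis runs only until the context reports itself optimized, so later decode calls skip it. A minimal sketch of the pattern with hypothetical names (the real checks are ggml_metal_if_optimized and ggml_metal_graph_find_concurrency):

    #include <stdbool.h>
    #include <stdio.h>

    struct metal_ctx_sketch { bool optimized; };

    static void graph_find_concurrency(struct metal_ctx_sketch * ctx) {
        printf("analyzing graph concurrency (expensive, done once)\n");
        ctx->optimized = true;
    }

    static void graph_compute(struct metal_ctx_sketch * ctx) {
        if (!ctx->optimized) {     // mirrors !ggml_metal_if_optimized(...)
            graph_find_concurrency(ctx);
        }
        // ... encode and dispatch the graph ...
    }

    int main(void) {
        struct metal_ctx_sketch ctx = { false };
        graph_compute(&ctx);       // first eval pays for the analysis
        graph_compute(&ctx);       // subsequent evals skip it
        return 0;
    }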