Merge branch 'master' into sycl_fix_non_intel_fp16

OuadiElfarouki 2024-04-03 16:43:32 +01:00
commit a7c6758214
17 changed files with 827 additions and 890 deletions


@@ -18,6 +18,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 ### Hot topics
+- **MoE memory layout has been updated - reconvert models for `mmap` support and regenerate `imatrix` https://github.com/ggerganov/llama.cpp/pull/6387**
 - Model sharding instructions using `gguf-split` https://github.com/ggerganov/llama.cpp/discussions/6404
 - Fix major bug in Metal batched inference https://github.com/ggerganov/llama.cpp/pull/6225
 - Multi-GPU pipeline parallelizm support https://github.com/ggerganov/llama.cpp/pull/6017


@ -323,8 +323,7 @@ class Model(ABC):
toktypes: list[int] = [] toktypes: list[int] = []
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr) raise FileNotFoundError(f"File not found: {tokenizer_path}")
sys.exit(1)
tokenizer = SentencePieceProcessor(str(tokenizer_path)) tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
@ -1216,6 +1215,8 @@ class LlamaModel(Model):
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
n_head = self.hparams.get("num_attention_heads") n_head = self.hparams.get("num_attention_heads")
n_kv_head = self.hparams.get("num_key_value_heads") n_kv_head = self.hparams.get("num_key_value_heads")
n_experts = self.hparams.get("num_local_experts")
experts = dict()
for name, data_torch in self.get_tensors(): for name, data_torch in self.get_tensors():
# we don't need these # we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@ -1236,6 +1237,153 @@ class LlamaModel(Model):
data = data.squeeze() data = data.squeeze()
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
experts[name] = data
if len(experts) >= n_experts:
# merge the experts into a single 3d tensor
for bid in range(block_count):
for wid in range(1, 4):
full = True
for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
if ename not in experts:
full = False
break
if not full:
continue
datas = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
datas.append(experts[ename])
del experts[ename]
data = np.stack(datas, axis=0)
data_dtype = data.dtype
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
if self.ftype == 1 and data_dtype == np.float32:
data = data.astype(np.float16)
merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"
new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
continue
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# 1d tensors need to be converted to float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts.keys()}")
@Model.register("GrokForCausalLM")
class GrokModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
def set_vocab(self):
self._set_vocab_sentencepiece()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_name("Grok")
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
n_experts = self.hparams.get("num_local_experts")
experts = dict()
for name, data_torch in self.get_tensors():
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
continue
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
# process the experts separately
if name.find(".moe.") != -1:
experts[name] = data
if len(experts) >= n_experts:
# merge the experts into a single 3d tensor
for bid in range(block_count):
for wid in ["linear", "linear_1", "linear_v"]:
full = True
for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
if ename not in experts:
full = False
break
if not full:
continue
datas = []
for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
datas.append(experts[ename])
del experts[ename]
data = np.stack(datas, axis=0)
data_dtype = data.dtype
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
if self.ftype == 1 and data_dtype == np.float32:
data = data.astype(np.float16)
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
continue
# map tensor names # map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None: if new_name is None:
@ -1262,21 +1410,6 @@ class LlamaModel(Model):
self.gguf_writer.add_tensor(new_name, data) self.gguf_writer.add_tensor(new_name, data)
@Model.register("GrokForCausalLM")
class GrokModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
def set_vocab(self):
self._set_vocab_sentencepiece()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_name("Grok")
@Model.register("MiniCPMForCausalLM") @Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model): class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM model_arch = gguf.MODEL_ARCH.MINICPM
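The expert-merging loop above is the heart of the new MoE layout: the per-expert 2D weight matrices from the HF checkpoint are stacked along a new leading dimension and written once under a merged name. A minimal sketch of that step with NumPy only and toy, hypothetical sizes (the real script operates on torch tensors and the gguf writer):

```python
import numpy as np

# toy sizes, for illustration only
n_experts, n_ff, n_embd = 8, 6, 4
bid, wid = 0, 3  # block 0, w3 projection in Mixtral naming

# one 2D weight per expert, as stored in the HF checkpoint
experts = {
    f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight":
        np.zeros((n_ff, n_embd), dtype=np.float16)
    for xid in range(n_experts)
}

# stack them into a single 3D tensor: (n_experts, n_ff, n_embd)
datas = [experts[f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"]
         for xid in range(n_experts)]
merged = np.stack(datas, axis=0)

# one merged name per (block, projection) pair instead of one per expert
merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"
print(merged_name, merged.shape)  # -> (8, 6, 4)
```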


@ -828,6 +828,15 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor:
def load() -> Tensor:
tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors]
return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors]))
s = lazy_tensors[0].shape.copy()
s.insert(0, len(lazy_tensors))
return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors))
# Functionality that simulates `torch.load` but where individual tensors are # Functionality that simulates `torch.load` but where individual tensors are
# only loaded into memory on demand, not all at once. # only loaded into memory on demand, not all at once.
# PyTorch can't do this natively as of time of writing: # PyTorch can't do this natively as of time of writing:
@ -1246,6 +1255,22 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
tmp = model tmp = model
# merge experts into one tensor
if params.n_experts and params.n_experts > 0:
for i_l in range(params.n_layer):
for w in range(1, 4):
experts = []
for e in range(params.n_experts):
if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model:
experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"])
del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]
elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model:
experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
else:
raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
# HF models permut or pack some of the tensors, so we need to undo that # HF models permut or pack some of the tensors, so we need to undo that
for i in itertools.count(): for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" in model: if f"model.layers.{i}.self_attn.q_proj.weight" in model:


@ -98,35 +98,38 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
const float * data = is_host ? (const float *) src1->data : m_src1_data.data(); const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
// this has been adapted to the new format of storing merged experts in a single 3d tensor
// ref: https://github.com/ggerganov/llama.cpp/pull/6387
if (t->op == GGML_OP_MUL_MAT_ID) { if (t->op == GGML_OP_MUL_MAT_ID) {
const int idx = ((int32_t *) t->op_params)[0]; const int idx = ((int32_t *) t->op_params)[0];
const int n_as = ((int32_t *) t->op_params)[1]; const ggml_tensor * ids = t->src[2];
const int n_as = src0->ne[2];
// the top-k selected expert ids are stored in the src0 tensor // the top-k selected expert ids are stored in the ids tensor
// for simplicity, always copy src0 to host, because it is small // for simplicity, always copy ids to host, because it is small
// take into account that src0 is not contiguous! // take into account that ids is not contiguous!
GGML_ASSERT(src0->ne[1] == src1->ne[1]); GGML_ASSERT(ids->ne[1] == src1->ne[1]);
GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int))); GGML_ASSERT(n_as*ggml_nrows(ids)*sizeof(int) == GGML_PAD(ggml_nbytes(ids), n_as*sizeof(int)));
m_ids.resize(ggml_nbytes(src0)/sizeof(int)); m_ids.resize(ggml_nbytes(ids)/sizeof(int));
ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0)); ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
auto & e = m_stats[wname];
++e.ncall;
// NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
// using the following line, we can correct for that if needed by replacing the line above with:
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
// loop over all possible experts, regardless if they are used or not in the batch // loop over all possible experts, regardless if they are used or not in the batch
// this is necessary to guarantee equal number of "ncall" for each tensor
for (int ex = 0; ex < n_as; ++ex) { for (int ex = 0; ex < n_as; ++ex) {
src0 = t->src[2 + ex]; size_t e_start = ex*src1->ne[0];
wname = filter_tensor_name(src0->name);
auto& e = m_stats[wname];
if (e.values.empty()) { if (e.values.empty()) {
e.values.resize(src1->ne[0], 0); e.values.resize(src1->ne[0]*n_as, 0);
} }
else if (e.values.size() != (size_t)src1->ne[0]) { else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
exit(1); //GGML_ASSERT(false); exit(1); //GGML_ASSERT(false);
} }
// NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
// using the following line, we can correct for that if needed
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
++e.ncall;
if (m_params.verbosity > 1) { if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
} }
@ -136,7 +139,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
if (excur != ex) continue; if (excur != ex) continue;
const float * x = data + row * src1->ne[0]; const float * x = data + row * src1->ne[0];
for (int j = 0; j < (int)src1->ne[0]; ++j) { for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[j] += x[j]*x[j]; e.values[e_start + j] += x[j]*x[j];
} }
} }
if (e.ncall > m_last_call) { if (e.ncall > m_last_call) {
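With the merged layout there is a single statistics entry per expert tensor, so the collector sizes e.values to src1->ne[0]*n_as and accumulates each expert's contribution at offset e_start = ex*src1->ne[0]. A small NumPy sketch of that flattened bookkeeping, with toy sizes and made-up activation rows:

```python
import numpy as np

ne0, n_as = 4, 3               # row width of src1 and number of experts (toy sizes)
values = np.zeros(ne0 * n_as)  # one flat buffer covering all experts of this tensor

def accumulate(ex: int, x: np.ndarray) -> None:
    # expert `ex` owns the slice [ex*ne0, (ex+1)*ne0) of the flat buffer
    e_start = ex * ne0
    values[e_start:e_start + ne0] += x * x

# a few fake activation rows routed to experts 0, 2, 0
for ex, row in [(0, np.ones(ne0)), (2, 2 * np.ones(ne0)), (0, 3 * np.ones(ne0))]:
    accumulate(ex, row)

print(values.reshape(n_as, ne0))  # expert 0 row: 1 + 9 = 10, expert 1: 0, expert 2: 4
```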


@ -116,13 +116,13 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
std::ifstream in(imatrix_file.c_str(), std::ios::binary); std::ifstream in(imatrix_file.c_str(), std::ios::binary);
if (!in) { if (!in) {
printf("%s: failed to open %s\n",__func__, imatrix_file.c_str()); printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
return; exit(1);
} }
int n_entries; int n_entries;
in.read((char *)&n_entries, sizeof(n_entries)); in.read((char *)&n_entries, sizeof(n_entries));
if (in.fail() || n_entries < 1) { if (in.fail() || n_entries < 1) {
printf("%s: no data in file %s\n", __func__, imatrix_file.c_str()); printf("%s: no data in file %s\n", __func__, imatrix_file.c_str());
return; exit(1);
} }
for (int i = 0; i < n_entries; ++i) { for (int i = 0; i < n_entries; ++i) {
int len; in.read((char *)&len, sizeof(len)); int len; in.read((char *)&len, sizeof(len));
@ -130,11 +130,11 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
in.read((char *)name_as_vec.data(), len); in.read((char *)name_as_vec.data(), len);
if (in.fail()) { if (in.fail()) {
printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str()); printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str());
return; exit(1);
} }
name_as_vec[len] = 0; name_as_vec[len] = 0;
std::string name{name_as_vec.data()}; std::string name{name_as_vec.data()};
auto & e = imatrix_data[std::move(name)]; auto & e = imatrix_data[name];
int ncall; int ncall;
in.read((char *)&ncall, sizeof(ncall)); in.read((char *)&ncall, sizeof(ncall));
int nval; int nval;
@ -142,18 +142,22 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
if (in.fail() || nval < 1) { if (in.fail() || nval < 1) {
printf("%s: failed reading number of values for entry %d\n", __func__, i); printf("%s: failed reading number of values for entry %d\n", __func__, i);
imatrix_data = {}; imatrix_data = {};
return; exit(1);
} }
e.resize(nval); e.resize(nval);
in.read((char *)e.data(), nval*sizeof(float)); in.read((char *)e.data(), nval*sizeof(float));
if (in.fail()) { if (in.fail()) {
printf("%s: failed reading data for entry %d\n", __func__, i); printf("%s: failed reading data for entry %d\n", __func__, i);
imatrix_data = {}; imatrix_data = {};
return; exit(1);
} }
if (ncall > 0) { if (ncall > 0) {
for (auto& v : e) v /= ncall; for (auto& v : e) v /= ncall;
} }
if (getenv("LLAMA_TRACE")) {
printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str());
}
} }
printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str()); printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str());
} }
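For reference, the file that load_imatrix reads is a simple binary stream: an entry count, then for each entry a length-prefixed name, a call counter, a value count and the raw float values, with the values divided by ncall on load. A hedged Python reader mirroring the C++ above, assuming little-endian 4-byte ints and 4-byte floats (as the sizeof(int)/float usage implies) and omitting the error handling:

```python
import struct

def load_imatrix(path: str) -> dict:
    data = {}
    with open(path, "rb") as f:
        (n_entries,) = struct.unpack("<i", f.read(4))
        for _ in range(n_entries):
            (name_len,) = struct.unpack("<i", f.read(4))
            name = f.read(name_len).decode("utf-8")
            (ncall,) = struct.unpack("<i", f.read(4))
            (nval,) = struct.unpack("<i", f.read(4))
            vals = list(struct.unpack(f"<{nval}f", f.read(4 * nval)))
            if ncall > 0:
                vals = [v / ncall for v in vals]  # average over calls, as in quantize.cpp
            data[name] = vals
    return data
```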


@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
if (tensor->view_src != NULL && tensor->view_offs == 0) { if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft); assert(tensor->view_src->buffer->buft == buffer->buft);
tensor->backend = tensor->view_src->backend;
tensor->extra = tensor->view_src->extra;
return; return;
} }
@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} }
} }
#if 0
template<typename ... Srcs>
static __global__ void k_compute_batched_ptrs_id(
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3,
ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
const half * src1_f16, half * dst_f16,
const int32_t * ids, const int id,
Srcs... src0s) {
int i = ids[id];
half * src0_f16;
const void * srcs_ar[] = { (const half *) src0s... };
if (src0_type == GGML_TYPE_F16) {
src0_f16 = (half *) srcs_ar[i];
} else {
src0_f16 = src0_as_f16;
if (threadIdx.x == 0 && threadIdx.y == 0) {
const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
}
}
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
if (i13 >= ne13 || i12 >= ne12) {
return;
}
int i03 = i13 / r3;
int i02 = i12 / r2;
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
}
static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
const struct ggml_tensor * ids = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src00 = dst->src[2];
const int id = dst->op_params[0];
GGML_ASSERT(!ggml_is_transposed(src00));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
const int64_t ne01 = src00->ne[1];
const int64_t ne02 = src00->ne[2];
const int64_t ne03 = src00->ne[3];
//const int64_t nb01 = src00->nb[1];
const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
//const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst);
ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
//ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
//void * src0_ddq = src0_extra->data_device[g_main_device];
//half * src0_as_f16 = (half *) src0_ddq;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
// convert src1 to fp16
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t src1_as = 0;
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;
const void ** ptrs_src = nullptr;
void ** ptrs_dst = nullptr;
size_t ptrs_src_s = 0;
size_t ptrs_dst_s = 0;
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
int64_t src0_ne = ggml_nelements(src00);
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
if (src00->type != GGML_TYPE_F16) {
src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
}
static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
nb12, nb13,
dst->nb[2], dst->nb[3],
r2, r3,
src00->type, src0_as_f16, src0_ne,
src1_as_f16, dst_f16,
(const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
ne23,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}
#endif
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
#if 0
ggml_cuda_mul_mat_id_cublas(dst);
// TODO: mmq/mmv support
#endif
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const size_t nb11 = src1->nb[1]; const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1]; const size_t nb1 = dst->nb[1];
const struct ggml_tensor * ids = src0;
const int32_t id = ((int32_t *) dst->op_params)[0]; const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = ((int32_t *) dst->op_params)[1]; const int32_t n_as = src0->ne[2];
std::vector<char> ids_host(ggml_nbytes(ids)); std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data; const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream)); CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1; ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst; ggml_tensor dst_row = *dst;
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data; char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data; char * dst_original = (char *) dst->data;
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];
if (src1->ne[1] == 1) { if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as); GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0_row = dst->src[row_id + 2]; src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1]; src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1]; dst_row.data = dst_original + i01*dst->nb[1];
ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row); ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
} }
} else { } else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.data = dst_contiguous.get(); dst_row.data = dst_contiguous.get();
for (int32_t row_id = 0; row_id < n_as; ++row_id) { for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
int64_t num_src1_rows = 0; int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
continue; continue;
} }
src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.ne[1] = num_src1_rows; src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows; dst_row.ne[1] = num_src1_rows;
@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row); ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
num_src1_rows = 0; num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) { if (err != cudaSuccess) {
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst)); fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(false); CUDA_CHECK(err);
} }
return true; return true;


@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
} }
template<ggml_sort_order order> template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort // bitonic sort
int col = threadIdx.x; int col = threadIdx.x;
int row = blockIdx.y; int row = blockIdx.y;
if (col >= ncols) return; if (col >= ncols_pad) {
return;
}
const float * x_row = x + row * ncols; const float * x_row = x + row * ncols;
int * dst_row = dst + row * ncols; extern __shared__ int dst_row[];
// initialize indices // initialize indices
if (col < ncols) {
dst_row[col] = col; dst_row[col] = col;
}
__syncthreads(); __syncthreads();
for (int k = 2; k <= ncols; k *= 2) { for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) { for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j; int ixj = col ^ j;
if (ixj > col) { if (ixj > col) {
if ((col & k) == 0) { if ((col & k) == 0) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]); ggml_cuda_swap(dst_row[col], dst_row[ixj]);
} }
} else { } else {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]); ggml_cuda_swap(dst_row[col], dst_row[ixj]);
} }
} }
@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
__syncthreads(); __syncthreads();
} }
} }
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
} }
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2 // bitonic sort requires ncols to be power of 2
GGML_ASSERT((ncols & (ncols - 1)) == 0); const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols, 1, 1); const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1); const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) { } else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
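The updated kernel pads the column count to the next power of two, runs the bitonic network in shared memory, and treats padding indices as if they compared larger (ASC) or smaller (DESC) than any real value, so they sink to the back and only the first ncols indices are copied out. Below is a serial Python emulation of that network; it relies on the fact that within each (k, j) stage the compare-and-swaps touch disjoint pairs, so replaying them sequentially matches the parallel kernel. This is a sketch for understanding, not the CUDA code:

```python
def next_power_of_2(x: int) -> int:
    n = 1
    while n < x:
        n *= 2
    return n

def argsort_bitonic(row, descending=False):
    ncols = len(row)
    ncols_pad = next_power_of_2(ncols)
    idx = list(range(ncols_pad))          # indices >= ncols are padding sentinels

    def before(a, b):
        # padding always loses, i.e. it sorts to the back of the padded array
        if a >= ncols:
            return False
        if b >= ncols:
            return True
        return row[a] > row[b] if descending else row[a] < row[b]

    k = 2
    while k <= ncols_pad:                 # bitonic merge stages
        j = k // 2
        while j > 0:
            for col in range(ncols_pad):  # one compare-exchange per disjoint pair
                ixj = col ^ j
                if ixj > col:
                    ascending_block = (col & k) == 0
                    if ascending_block != before(idx[col], idx[ixj]):
                        idx[col], idx[ixj] = idx[ixj], idx[col]
            j //= 2
        k *= 2
    return idx[:ncols]                    # drop the padding, as the kernel does

vals = [0.3, -1.0, 2.5, 0.0, 0.3]
order = argsort_bitonic(vals)
assert [vals[i] for i in order] == sorted(vals)
print(order)
```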


@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
{ {
//GGML_ASSERT(ne00 == ne10); //GGML_ASSERT(ne00 == ne10);
//GGML_ASSERT(ne03 == ne13); //GGML_ASSERT(ne03 == ne13);
const int n_as = src0->ne[2];
GGML_ASSERT(src0t == GGML_TYPE_I32);
const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// max size of the src1ids array in the kernel shared buffer // max size of the src1ids array in the kernel shared buffer
GGML_ASSERT(ne11 <= 4096); GGML_ASSERT(ne11 <= 4096);
const int64_t ne20 = src2 ? src2->ne[0] : 0; // src2 = ids
const int64_t ne21 = src2 ? src2->ne[1] : 0; const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
const int64_t ne22 = src2 ? src2->ne[2] : 0; const int64_t ne21 = src2->ne[1];
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0; const uint64_t nb21 = src2->nb[1];
const uint64_t nb22 = src2 ? src2->nb[2] : 0; const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
GGML_ASSERT(!ggml_is_transposed(src2)); GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
const uint r2 = ne12/ne22;
const uint r3 = ne13/ne23;
// find the break-even point where the matrix-matrix kernel becomes more efficient compared // find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel // to the matrix-vector kernel
int ne11_mm_min = n_as; int ne11_mm_min = n_as;
@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
const int idx = ((int32_t *) dst->op_params)[0]; const int idx = ((int32_t *) dst->op_params)[0];
// batch size // batch size
GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne21 == ne11); // ?
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
const uint r2 = 1;
const uint r3 = 1;
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
// indirect matrix multiplication // indirect matrix multiplication
// !!! // !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne20 % 32 == 0 && ne20 >= 64 && ne00 % 32 == 0 && ne00 >= 64 &&
ne11 > ne11_mm_min) { ne11 > ne11_mm_min) {
// some Metal matrix data types require aligned pointers // some Metal matrix data types require aligned pointers
@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (src2->type) { switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
[encoder setBytes:&idx length:sizeof(idx) atIndex:18]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
// TODO: how to make this an array? read Metal docs [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
for (int j = 0; j < 8; ++j) {
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else {
int nth0 = 32; int nth0 = 32;
int nth1 = 1; int nth1 = 1;
@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
// use custom matrix x vector kernel // use custom matrix x vector kernel
switch (src2t) { switch (src0t) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
} }
}; };
if (ggml_is_quantized(src2t)) { if (ggml_is_quantized(src0t)) {
GGML_ASSERT(ne20 >= nth0*nth1); GGML_ASSERT(ne00 >= nth0*nth1);
} }
const int64_t _ne1 = 1; // kernels needs a reference in constant memory const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
[encoder setBytes:&idx length:sizeof(idx) atIndex:22]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
// TODO: how to make this an array? read Metal docs [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
for (int j = 0; j < 8; ++j) {
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
size_t offs_src_cur = 0; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur); src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 || const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) { else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) { else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
const int mem_size = 32*sizeof(float); const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_Q4_K) { else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_Q3_K) { else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64 #ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else #else
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif #endif
} }
else if (src2t == GGML_TYPE_Q5_K) { else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
const int64_t ny = (_ne1 + nrows - 1)/nrows; const int64_t ny = (_ne1 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
} }
} break; } break;
@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
// bitonic sort requires the number of elements to be power of 2
int64_t ne00_padded = 1;
while (ne00_padded < ne00) {
ne00_padded *= 2;
}
// Metal kernels require the buffer size to be multiple of 16 bytes
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
switch (order) { switch (order) {
@ -2444,8 +2433,10 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
} break; } break;
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
{ {
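On the Metal side the same padding shows up twice: the sort width is rounded up to a power of two for the threadgroup size, and the threadgroup memory holding the index row is additionally padded to a multiple of 16 bytes, per the linked setThreadgroupMemoryLength requirement. A quick sketch of the two size computations, with a toy ne00 and GGML_PAD reproduced as an equivalent round-up helper:

```python
def ggml_pad(x: int, n: int) -> int:
    # round x up to a multiple of n (GGML_PAD does this with bit masking for power-of-two n)
    return ((x + n - 1) // n) * n

def next_power_of_2(x: int) -> int:
    n = 1
    while n < x:
        n *= 2
    return n

ne00 = 100                                # toy column count
ne00_padded = next_power_of_2(ne00)       # threads per threadgroup
mem_size = ggml_pad(ne00_padded * 4, 16)  # int32 indices, padded to 16 bytes
print(ne00_padded, mem_size)              # 128 512
```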

File diff suppressed because it is too large.

ggml.c

@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
// ggml_mul_mat_id // ggml_mul_mat_id
// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
// this will allow computing all the used experts in a single matrix multiplication
struct ggml_tensor * ggml_mul_mat_id( struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * const as[], struct ggml_tensor * as,
int n_as,
struct ggml_tensor * ids, struct ggml_tensor * ids,
int id, int id,
struct ggml_tensor * b) { struct ggml_tensor * b) {
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[1]); GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
GGML_ASSERT(id >= 0 && id < ids->ne[0]); GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
bool is_node = false; bool is_node = false;
if (as[0]->grad || b->grad) { if (as->grad || b->grad) {
is_node = true; is_node = true;
} }
const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, id); ggml_set_op_params_i32(result, 0, id);
ggml_set_op_params_i32(result, 1, n_as);
result->op = GGML_OP_MUL_MAT_ID; result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = ids; result->src[0] = as;
result->src[1] = b; result->src[1] = b;
result->src[2] = ids;
for (int i = 0; i < n_as; i++) {
struct ggml_tensor * a = as[i];
GGML_ASSERT(ggml_are_same_shape(as[0], a));
GGML_ASSERT(ggml_can_mul_mat(a, b));
GGML_ASSERT(!ggml_is_transposed(a));
result->src[i + 2] = a;
}
return result; return result;
} }
@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const struct ggml_tensor * ids = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * ids = dst->src[2];
const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(nb2 <= nb3);
// broadcast factors // broadcast is not supported with mmid
const int64_t r2 = ne12/ne02; assert(ne12 == 1);
const int64_t r3 = ne13/ne03; assert(ne13 == 1);
// row groups // row groups
const int id = ggml_get_op_params_i32(dst, 0); const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = ggml_get_op_params_i32(dst, 1); const int n_as = src0->ne[2];
char * wdata_src1_end = (src1->type == vec_dot_type) ? char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata : (char *) params->wdata :
@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
continue; continue;
} }
const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; size_t src0_offset = cur_a*src0->nb[2];
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10); const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
continue; continue;
} }
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
// block-tiling attempt // block-tiling attempt
const int64_t blck_0 = 16; const int64_t blck_0 = 16;
const int64_t blck_1 = 16; const int64_t blck_1 = 16;
@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
// broadcast src0 into src1 // broadcast src0 into src1
const int64_t i03 = i13/r3; //const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2; //const int64_t i02 = i12/r2;
const int64_t i1 = i11; const int64_t i1 = i11;
const int64_t i2 = i12; const int64_t i2 = i12;
const int64_t i3 = i13; const int64_t i3 = i13;
const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03); const char * src0_row = (const char *) src0->data + src0_offset;
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{ {
cur = 0; cur = 0;
const struct ggml_tensor * src0 = node->src[2]; const struct ggml_tensor * src0 = node->src[0];
const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) { if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
} }
const int n_as = ggml_get_op_params_i32(node, 1); const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
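The CPU path keeps the existing "row group" strategy: src1 rows are bucketed by the expert selected in ids, and each expert then does one matrix multiplication against its 2D slice of the merged 3D src0. A NumPy reference of that idea, with shapes transposed into NumPy order and toy data; it is an illustration of the technique, not the library code:

```python
import numpy as np

def mul_mat_id_rows(as_, ids, id_, b):
    """Row-grouped reference for mul_mat_id with a merged 3D expert tensor.

    as_ : (n_as, n, k) merged expert matrices
    b   : (n_tokens, k) one input row per token
    ids : (n_tokens, n_choices) per-token expert choices; column `id_` is used
    """
    n_as = as_.shape[0]
    out = np.empty((b.shape[0], as_.shape[1]), dtype=np.float32)
    for cur_a in range(n_as):                     # loop over experts ("row groups")
        rows = np.where(ids[:, id_] == cur_a)[0]  # src1 rows routed to this expert
        if rows.size == 0:
            continue
        out[rows] = b[rows] @ as_[cur_a].T        # one matmul per used expert
    return out

# toy check against the per-token definition out[t] = as[ids[t, id]] @ b[t]
rng = np.random.default_rng(0)
as_ = rng.standard_normal((3, 4, 5)).astype(np.float32)
b   = rng.standard_normal((6, 5)).astype(np.float32)
ids = rng.integers(0, 3, size=(6, 2))
ref = np.stack([b[t] @ as_[ids[t, 0]].T for t in range(6)])
assert np.allclose(mul_mat_id_rows(as_, ids, 0, b), ref, atol=1e-5)
```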

ggml.h

@@ -1164,8 +1164,7 @@ extern "C" {
     // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor * const as[],
-            int                  n_as,
+            struct ggml_tensor * as,
             struct ggml_tensor * ids,
             int                  id,
             struct ggml_tensor * b);


@@ -221,9 +221,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_DOWN:       "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:         "blk.{bid}.ffn_up",
     MODEL_TENSOR.FFN_ACT:        "blk.{bid}.ffn",
-    MODEL_TENSOR.FFN_GATE_EXP:   "blk.{bid}.ffn_gate.{xid}",
-    MODEL_TENSOR.FFN_DOWN_EXP:   "blk.{bid}.ffn_down.{xid}",
-    MODEL_TENSOR.FFN_UP_EXP:     "blk.{bid}.ffn_up.{xid}",
+    MODEL_TENSOR.FFN_GATE_EXP:   "blk.{bid}.ffn_gate_exps",
+    MODEL_TENSOR.FFN_DOWN_EXP:   "blk.{bid}.ffn_down_exps",
+    MODEL_TENSOR.FFN_UP_EXP:     "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN:         "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:     "blk.{bid}.ssm_conv1d",
@ -231,9 +231,8 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok
), ),
# AWQ-activation gate # AWQ-activation gate
@ -252,9 +251,8 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral "transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
"transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok
), ),
# Feed-forward down # Feed-forward down
@ -280,10 +278,8 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok
), ),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "gguf" name = "gguf"
version = "0.8.0" version = "0.9.0"
description = "Read and write ML models in GGUF for GGML" description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [
llama.cpp
@ -426,9 +426,12 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_ACT, LLM_TENSOR_FFN_ACT,
LLM_TENSOR_FFN_DOWN_EXP, LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility
LLM_TENSOR_FFN_GATE_EXP, LLM_TENSOR_FFN_GATE_EXP,
LLM_TENSOR_FFN_UP_EXP, LLM_TENSOR_FFN_UP_EXP,
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_LAYER_OUT_NORM,
@ -463,6 +466,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
}, },
}, },
{ {
@ -516,6 +522,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
}, },
@ -1864,9 +1873,9 @@ struct llama_layer {
// ff MoE // ff MoE
struct ggml_tensor * ffn_gate_inp; struct ggml_tensor * ffn_gate_inp;
struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS]; struct ggml_tensor * ffn_gate_exps;
struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS]; struct ggml_tensor * ffn_down_exps;
struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS]; struct ggml_tensor * ffn_up_exps ;
// ff bias // ff bias
struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_down_b; // b2
@ -2868,19 +2877,19 @@ struct llama_model_loader {
llama_mmaps mappings; llama_mmaps mappings;
// Holds information on a model weights // Holds information on a model weight
struct llama_tensor_weights { struct llama_tensor_weight {
uint16_t idx; // source file index uint16_t idx; // source file index
size_t offs; // tensor data offset in the original file size_t offs; // tensor data offset in the original file
ggml_tensor * tensor; ggml_tensor * tensor;
llama_tensor_weights(uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { llama_tensor_weight(uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
const int tensor_idx = gguf_find_tensor(gguf_ctx, name); const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
} }
}; };
std::vector<llama_tensor_weights> weights; std::vector<llama_tensor_weight> weights;
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides; std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
@ -2920,7 +2929,7 @@ struct llama_model_loader {
// For subsidiary files, `meta` tensor data offset must not be used, // For subsidiary files, `meta` tensor data offset must not be used,
// so we build a unified tensors index for weights. // so we build a unified tensors index for weights.
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
weights.emplace_back(llama_tensor_weights(0, cur->name, meta, cur)); weights.emplace_back(0, cur->name, meta, cur);
} }
files.emplace_back(new llama_file(fname.c_str(), "rb")); files.emplace_back(new llama_file(fname.c_str(), "rb"));
contexts.emplace_back(ctx); contexts.emplace_back(ctx);
@ -2960,7 +2969,7 @@ struct llama_model_loader {
// Save tensors data offset info of the shard. // Save tensors data offset info of the shard.
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
weights.emplace_back(llama_tensor_weights(idx, cur->name, ctx_gguf, cur)); weights.emplace_back(idx, cur->name, ctx_gguf, cur);
} }
files.emplace_back(new llama_file(split_path, "rb")); files.emplace_back(new llama_file(split_path, "rb"));
contexts.emplace_back(ctx); contexts.emplace_back(ctx);
@ -3164,21 +3173,37 @@ struct llama_model_loader {
return weights.at(i).tensor->name; return weights.at(i).tensor->name;
} }
const llama_tensor_weights & get_weights(const char * name) const { const llama_tensor_weight * get_weight(const char * name) const {
for (const auto & weight : weights) { for (const auto & weight : weights) {
if (strcmp(name, weight.tensor->name) == 0) { if (strcmp(name, weight.tensor->name) == 0) {
return weight; return &weight;
} }
} }
throw std::runtime_error(format("tensor %s not found", name)); return nullptr;
}
const llama_tensor_weight & require_weight(const char * name) const {
const llama_tensor_weight * weight = get_weight(name);
if (!weight) {
throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
}
return *weight;
} }
struct ggml_tensor * get_tensor_meta(const char * name) const { struct ggml_tensor * get_tensor_meta(const char * name) const {
try { const auto * weight = get_weight(name);
return get_weights(name).tensor; if (!weight) {
} catch (const std::runtime_error & e) { return nullptr;
return NULL;
} }
return weight->tensor;
}
struct ggml_tensor * require_tensor_meta(const char * name) const {
struct ggml_tensor * tensor = get_tensor_meta(name);
if (!tensor) {
throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
}
return tensor;
} }
struct ggml_tensor * get_tensor_meta(int i) const { struct ggml_tensor * get_tensor_meta(int i) const {
@ -3194,7 +3219,7 @@ struct llama_model_loader {
return tensor; return tensor;
} }
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, bool required = true) { const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
const struct ggml_tensor * cur = get_tensor_meta(name.c_str()); const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
if (cur == NULL) { if (cur == NULL) {
@ -3206,8 +3231,8 @@ struct llama_model_loader {
{ {
bool is_ok = true; bool is_ok = true;
for (size_t i = 0; i < ne.size(); ++i) { for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
if (ne[i] != cur->ne[i]) { if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
is_ok = false; is_ok = false;
break; break;
} }
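The relaxed shape check above reads as: dimensions the caller does not list must be 1 in the stored tensor. A standalone restatement, purely as a sketch with a hypothetical helper:

#include "ggml.h"

// Sketch: `want` may list fewer than GGML_MAX_DIMS dims; any trailing dims of
// the stored tensor must then be 1 for the shapes to be considered equal.
static bool dims_match(const int64_t * want, size_t n_want, const struct ggml_tensor * cur) {
    for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
        if ((i < n_want && want[i] != cur->ne[i]) || (i >= n_want && cur->ne[i] != 1)) {
            return false;
        }
    }
    return true;
}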
@ -3221,9 +3246,47 @@ struct llama_model_loader {
} }
} }
return cur;
}
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, bool required = true) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
if (cur == NULL) {
return NULL;
}
return create_tensor_for(ctx, cur); return create_tensor_for(ctx, cur);
} }
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
if (cur == NULL) {
return NULL;
}
if (cur->type != base->type) {
throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
}
std::array<int64_t, GGML_MAX_DIMS> dims;
for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
dims[i] = i < ne.size() ? ne[i] : 1;
}
struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
dims[0], dims[1], dims[2], dims[3],
cur->nb[1], cur->nb[2], cur->nb[3],
offset);
ggml_set_name(tensor, name.c_str());
n_created++;
return tensor;
}
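To make the view arithmetic concrete: create_tensor_as_view above uses ggml_view_4d with the split tensor's strides; the same placement can be illustrated with a plain 2-D view at byte offset nb[2]*x of the merged tensor (a sketch, the helper name is hypothetical):

#include "ggml.h"

// Sketch: 2-D view of expert x inside a merged [ne0, ne1, n_expert] tensor.
static struct ggml_tensor * expert_view(struct ggml_context * ctx, struct ggml_tensor * merged, int64_t x) {
    return ggml_view_2d(ctx, merged,
            merged->ne[0], merged->ne[1], // one expert's matrix
            merged->nb[1],                // keep the merged row stride
            merged->nb[2]*x);             // byte offset of expert x
}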
void done_getting_tensors() const { void done_getting_tensors() const {
if (n_created != n_tensors) { if (n_created != n_tensors) {
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
@ -3236,7 +3299,7 @@ struct llama_model_loader {
mmaps_used.reserve(files.size()); mmaps_used.reserve(files.size());
for (const auto & file : files) { for (const auto & file : files) {
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa())); std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa()));
mmaps_used.emplace_back(std::make_pair(mapping->size, 0)); mmaps_used.emplace_back(mapping->size, 0);
if (mlock_mmaps) { if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock()); std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
mlock_mmap->init(mapping->addr); mlock_mmap->init(mapping->addr);
@ -3260,18 +3323,25 @@ struct llama_model_loader {
*last = 0; *last = 0;
*addr = mapping->addr; *addr = mapping->addr;
for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) { for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
const auto & w = get_weights(ggml_get_name(tensor)); try {
if (w.idx != idx) { const auto * weight = get_weight(ggml_get_name(tensor));
if (!weight) {
continue; continue;
} }
*first = std::min(*first, w.offs); if (weight->idx != idx) {
*last = std::max(*last, w.offs + ggml_nbytes(tensor)); continue;
}
*first = std::min(*first, weight->offs);
*last = std::max(*last, weight->offs + ggml_nbytes(tensor));
} catch(...) {
// the tensor is not in the model
}
} }
} }
// for backwards compatibility, does not support ggml-backend // for backwards compatibility, does not support ggml-backend
void load_data_for(struct ggml_tensor * cur) const { void load_data_for(struct ggml_tensor * cur) const {
const auto & w = get_weights(ggml_get_name(cur)); const auto & w = require_weight(ggml_get_name(cur));
if (use_mmap) { if (use_mmap) {
const auto & mapping = mappings.at(w.idx); const auto & mapping = mappings.at(w.idx);
@ -3304,44 +3374,49 @@ struct llama_model_loader {
std::vector<no_init<uint8_t>> read_buf; std::vector<no_init<uint8_t>> read_buf;
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
const auto * weight = get_weight(ggml_get_name(cur));
if (weight == nullptr) {
// this can happen with split experts models
continue;
}
if (progress_callback) { if (progress_callback) {
if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
return false; return false;
} }
} }
const auto & w = get_weights(ggml_get_name(cur));
size_t n_size = ggml_nbytes(cur); size_t n_size = ggml_nbytes(cur);
if (use_mmap) { if (use_mmap) {
const auto & mapping = mappings.at(w.idx); const auto & mapping = mappings.at(weight->idx);
ggml_backend_buffer_t buf_mmap = nullptr; ggml_backend_buffer_t buf_mmap = nullptr;
if (bufs_mmap.count(w.idx)) { if (bufs_mmap.count(weight->idx)) {
buf_mmap = bufs_mmap.at(w.idx); buf_mmap = bufs_mmap.at(weight->idx);
} }
GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
if (buf_mmap && cur->data == nullptr) { if (buf_mmap && cur->data == nullptr) {
ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + w.offs); ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs);
if (lmlocks) { if (lmlocks) {
const auto & lmlock = lmlocks->at(w.idx); const auto & lmlock = lmlocks->at(weight->idx);
lmlock->grow_to(w.offs + ggml_nbytes(cur)); lmlock->grow_to(weight->offs + ggml_nbytes(cur));
} }
auto & mmap_used = mmaps_used[w.idx]; auto & mmap_used = mmaps_used[weight->idx];
mmap_used.first = std::min(mmap_used.first, w.offs); mmap_used.first = std::min(mmap_used.first, weight->offs);
mmap_used.second = std::max(mmap_used.second, w.offs + n_size); mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
} else { } else {
ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + w.offs, 0, n_size); ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size);
} }
} else { } else {
GGML_ASSERT(w.idx < files.size()); GGML_ASSERT(weight->idx < files.size());
const auto & file = files.at(w.idx); const auto & file = files.at(weight->idx);
if (ggml_backend_buffer_is_host(cur->buffer)) { if (ggml_backend_buffer_is_host(cur->buffer)) {
file->seek(w.offs, SEEK_SET); file->seek(weight->offs, SEEK_SET);
file->read_raw(cur->data, ggml_nbytes(cur)); file->read_raw(cur->data, ggml_nbytes(cur));
} else { } else {
read_buf.resize(ggml_nbytes(cur)); read_buf.resize(ggml_nbytes(cur));
file->seek(w.offs, SEEK_SET); file->seek(weight->offs, SEEK_SET);
file->read_raw(read_buf.data(), ggml_nbytes(cur)); file->read_raw(read_buf.data(), ggml_nbytes(cur));
ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
} }
@ -4270,6 +4345,7 @@ static bool llm_load_tensors(
const int64_t n_layer = hparams.n_layer; const int64_t n_layer = hparams.n_layer;
const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0); const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
bool use_mmap_buffer = true;
// there is very little benefit to offloading the input layer, so always keep it on the CPU // there is very little benefit to offloading the input layer, so always keep it on the CPU
model.buft_input = llama_default_buffer_type_cpu(true); model.buft_input = llama_default_buffer_type_cpu(true);
@ -4358,6 +4434,10 @@ static bool llm_load_tensors(
// create one context per buffer type // create one context per buffer type
size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
// for moe merged tensors
ctx_size += ggml_tensor_overhead()*hparams.n_expert*n_layer;
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
for (auto & it : buft_layer_count) { for (auto & it : buft_layer_count) {
struct ggml_init_params params = { struct ggml_init_params params = {
@ -4384,6 +4464,11 @@ static bool llm_load_tensors(
const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab = hparams.n_vocab;
const int64_t n_vocab_type = hparams.n_vocab_type; const int64_t n_vocab_type = hparams.n_vocab_type;
const int64_t n_ff = hparams.n_ff; const int64_t n_ff = hparams.n_ff;
const int64_t n_expert = hparams.n_expert;
if (n_expert > 0 && hparams.n_expert_used == 0) {
throw std::runtime_error("model has expert layers but no expert layers are used");
}
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
@ -4438,30 +4523,50 @@ static bool llm_load_tensors(
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, false); if (n_expert == 0) {
if (layer.ffn_gate_inp == nullptr) {
GGML_ASSERT(hparams.n_expert == 0);
GGML_ASSERT(hparams.n_expert_used == 0);
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
} else { } else {
GGML_ASSERT(hparams.n_expert > 0); layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
GGML_ASSERT(hparams.n_expert_used > 0);
// MoE branch layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
for (uint32_t x = 0; x < hparams.n_expert; ++x) { if (layer.ffn_gate_exps) {
layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}); layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert});
layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
layer.ffn_up_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}); } else {
// merge split expert into a single tensor for compatibility with older models
// requires disabling mmap
use_mmap_buffer = false;
ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
ggml_type type_up = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type;
layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, n_expert);
layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, n_expert);
layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, n_expert);
ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
ggml_set_name(layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i).c_str());
for (uint32_t x = 0; x < n_expert; ++x) {
// the individual experts are loaded into a view of the merged tensor
ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
}
} }
} }
} }
} break; } break;
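Because the split-to-merged fallback above has to materialize whole 3-D tensors (and therefore disables the mmap buffer), the sizes involved are worth keeping in mind; a sketch with assumed Mixtral-like dimensions (n_embd = 4096, n_ff = 14336, n_expert = 8), not values taken from the diff:

#include "ggml.h"
#include <stdio.h>

// Sketch: element count and F16 byte size of one merged expert projection
// (gate, down or up) under the assumed dimensions above.
static void merged_expert_size_sketch(void) {
    const int64_t n_embd = 4096, n_ff = 14336, n_expert = 8;
    const int64_t n_elem = n_embd*n_ff*n_expert;                 // 469,762,048 elements
    const size_t  n_byte = ggml_row_size(GGML_TYPE_F16, n_elem); // ~0.875 GiB per projection
    printf("merged expert tensor: %lld elements, %zu bytes\n", (long long) n_elem, n_byte);
}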
case LLM_ARCH_GROK: case LLM_ARCH_GROK:
{ {
if (n_expert == 0) {
throw std::runtime_error("Grok model cannot have zero experts");
}
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output // output
@ -4493,16 +4598,35 @@ static bool llm_load_tensors(
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}); layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
GGML_ASSERT(hparams.n_expert > 0); layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
GGML_ASSERT(hparams.n_expert_used > 0); if (layer.ffn_gate_exps) {
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
} else {
// merge split expert into a single tensor for compatibility with older models
// requires disabling mmap
use_mmap_buffer = false;
// MoE branch ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
for (uint32_t x = 0; x < hparams.n_expert; ++x) { ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}); ggml_type type_up = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type;
layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd});
layer.ffn_up_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}); layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, n_expert);
layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, n_expert);
layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, n_expert);
ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
ggml_set_name(layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i).c_str());
for (uint32_t x = 0; x < n_expert; ++x) {
// the individual experts are loaded into a view of the merged tensor
ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
}
} }
layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
@ -5308,7 +5432,7 @@ static bool llm_load_tensors(
// only the mmap region containing the tensors in the model is mapped to the backend buffer // only the mmap region containing the tensors in the model is mapped to the backend buffer
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
if (ml.use_mmap && buft == llama_default_buffer_type_cpu(true)) { if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
void * addr = nullptr; void * addr = nullptr;
size_t first, last; size_t first, last;
@ -5332,7 +5456,7 @@ static bool llm_load_tensors(
} }
} }
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
else if (ml.use_mmap && buft == ggml_backend_metal_buffer_type()) { else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_metal_buffer_type()) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
const size_t max_size = ggml_get_max_tensor_size(ctx); const size_t max_size = ggml_get_max_tensor_size(ctx);
void * addr = nullptr; void * addr = nullptr;
@ -5415,9 +5539,11 @@ static bool llm_load_tensors(
} }
} }
if (use_mmap_buffer) {
for (auto & mapping : ml.mappings) { for (auto & mapping : ml.mappings) {
model.mappings.emplace_back(std::move(mapping)); model.mappings.emplace_back(std::move(mapping));
} }
}
// loading time will be recalculate after the first eval, so // loading time will be recalculate after the first eval, so
// we take page faults deferred by mmap() into consideration // we take page faults deferred by mmap() into consideration
@ -6284,19 +6410,19 @@ struct llm_build_context {
for (int i = 0; i < n_expert_used; ++i) { for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert; ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il); cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il); cb(cur_gate, "ffn_moe_gate", il);
cur_gate = ggml_silu(ctx0, cur_gate); cur_gate = ggml_silu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_silu", il); cb(cur_gate, "ffn_moe_silu", il);
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
cb(cur_expert, "ffn_moe_gate_par", il); cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il); cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert, cur_expert = ggml_mul(ctx0, cur_expert,
@ -6818,20 +6944,20 @@ struct llm_build_context {
for (int i = 0; i < n_expert_used; ++i) { for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert; ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il); cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il); cb(cur_gate, "ffn_moe_gate", il);
//GeLU //GeLU
cur_gate = ggml_gelu(ctx0, cur_gate); cur_gate = ggml_gelu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_gelu", il); cb(cur_gate, "ffn_moe_gelu", il);
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
cb(cur_expert, "ffn_moe_gate_par", il); cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il); cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert, cur_expert = ggml_mul(ctx0, cur_expert,
@ -12902,7 +13028,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
// sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
// for getting the current layer as I initially thought, and we need to resort to parsing the // for getting the current layer as I initially thought, and we need to resort to parsing the
// tensor name. // tensor name.
n_layer /= n_expert;
if (sscanf(name, "blk.%d.", &i_layer) != 1) { if (sscanf(name, "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name)); throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
} }
@ -13264,7 +13389,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
kv_overrides = v->data(); kv_overrides = v->data();
} }
llama_model_loader ml(fname_inp, use_mmap, kv_overrides); llama_model_loader ml(fname_inp, use_mmap, kv_overrides);
ml.init_mappings(false); // no prefetching? ml.init_mappings(false); // no prefetching
llama_model model; llama_model model;
llm_load_arch(ml, model); llm_load_arch(ml, model);
@ -13316,20 +13441,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// TODO: avoid hardcoded tensor names - use the TN_* constants // TODO: avoid hardcoded tensor names - use the TN_* constants
if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) {
++qs.n_attention_wv; ++qs.n_attention_wv;
} else if (name.find("ffn_down") != std::string::npos) {
++qs.n_ffn_down;
} else if (name.find("ffn_gate") != std::string::npos) {
++qs.n_ffn_gate;
} else if (name.find("ffn_up") != std::string::npos) {
++qs.n_ffn_up;
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
qs.has_output = true; qs.has_output = true;
} }
} }
if (qs.n_attention_wv != qs.n_ffn_down || (uint32_t) qs.n_attention_wv != model.hparams.n_layer) {
LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_ffn_down = %d, hparams.n_layer = %d\n", qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
__func__, qs.n_attention_wv, qs.n_ffn_down, model.hparams.n_layer);
} // sanity checks
GGML_ASSERT(qs.n_attention_wv == (int)model.hparams.n_layer && "n_attention_wv != n_layer is unexpected");
size_t total_size_org = 0; size_t total_size_org = 0;
size_t total_size_new = 0; size_t total_size_new = 0;
@ -13359,6 +13479,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// placeholder for the meta data // placeholder for the meta data
::zeros(fout, meta_size); ::zeros(fout, meta_size);
const auto tn = LLM_TN(model.arch);
for (int i = 0; i < ml.n_tensors; ++i) { for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * tensor = ml.get_tensor_meta(i); struct ggml_tensor * tensor = ml.get_tensor_meta(i);
@ -13381,8 +13503,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// This used to be a regex, but <regex> has an extreme cost to compile times. // This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// quantize only 2D tensors // quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) == 2); quantize &= (ggml_n_dims(tensor) >= 2);
quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= params->quantize_output_tensor || name != "output.weight";
quantize &= !params->only_copy; quantize &= !params->only_copy;
@ -13437,11 +13559,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (it == imatrix_data->end()) { if (it == imatrix_data->end()) {
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
} else { } else {
if (it->second.size() == (size_t)tensor->ne[0]) { if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
imatrix = it->second.data(); imatrix = it->second.data();
} else { } else {
LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
int(it->second.size()), int(tensor->ne[0]), tensor->name); int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
// this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
// this is a significant error and it may be good idea to abort the process if this happens,
// since many people will miss the error and not realize that most of the model is being quantized without an imatrix
// tok_embd should be ignored in this case, since it always causes this warning
if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
}
} }
} }
} }
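A worked example of the new size check (sizes assumed, not taken from the diff): for a merged ffn_down tensor with ne[0] = 14336 and ne[2] = 8 experts, the imatrix must supply 14336*8 = 114688 values, one row of activation statistics per expert; an imatrix produced for the old split layout carries only 14336 and now aborts the quantization instead of silently degrading it.

#include "ggml.h"

// Sketch: expected imatrix length for a (possibly merged) tensor.
static int64_t expected_imatrix_size(const struct ggml_tensor * t) {
    return t->ne[0]*t->ne[2]; // e.g. 14336 * 8 = 114688 for a merged ffn_down_exps
}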
@ -13478,15 +13609,24 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
new_data = work.data(); new_data = work.data();
const int n_per_row = tensor->ne[0]; const int n_per_row = tensor->ne[0];
const int nrows = nelements / n_per_row; const int nrows = tensor->ne[1];
static const int min_chunk_size = 32 * 512; static const int min_chunk_size = 32 * 512;
const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row); const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
const int nchunk = (nelements + chunk_size - 1)/chunk_size; const int nelements_matrix = tensor->ne[0] * tensor->ne[1];
const int nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
new_size = llama_tensor_quantize_internal(new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix, workers, nthread_use);
// quantize each expert separately since they have different importance matrices
new_size = 0;
for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
const float * f32_data_03 = f32_data + i03 * nelements_matrix;
void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
}
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
} }
total_size_org += ggml_nbytes(tensor); total_size_org += ggml_nbytes(tensor);
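The per-expert loop above advances three parallel buffers; a sketch of the offset arithmetic it performs (the helper and its parameters are hypothetical):

#include "ggml.h"

// Sketch: offsets for quantizing expert i03 of a merged [n_per_row, nrows, n_expert] tensor.
static void expert_quant_offsets(enum ggml_type new_type, int64_t n_per_row, int64_t nrows, int64_t i03,
                                 int64_t * f32_off, size_t * out_off, int64_t * imat_off) {
    *f32_off  = i03*n_per_row*nrows;                          // floats into the dequantized source
    *out_off  = ggml_row_size(new_type, n_per_row)*i03*nrows; // bytes into the quantized output
    *imat_off = i03*n_per_row;                                // floats into the importance matrix
}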
@ -15697,6 +15837,55 @@ static int32_t llama_chat_apply_template_internal(
ss << message->content << "</s>"; ss << message->content << "</s>";
} }
} }
} else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
// openchat/openchat-3.5-0106,
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "<|end_of_turn|>";
} else {
role[0] = toupper(role[0]);
ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
}
}
if (add_ass) {
ss << "GPT4 Correct Assistant:";
}
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
// eachadea/vicuna-13b-1.1 (and Orca variant)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// Orca-Vicuna variant uses a system prefix
if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
ss << "SYSTEM: " << message->content << "\n";
} else {
ss << message->content << "\n\n";
}
} else if (role == "user") {
ss << "USER: " << message->content << "\n";
} else if (role == "assistant") {
ss << "ASSISTANT: " << message->content << "</s>\n";
}
}
if (add_ass) {
ss << "ASSISTANT:";
}
} else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
// deepseek-ai/deepseek-coder-33b-instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content;
} else if (role == "user") {
ss << "### Instruction:\n" << message->content << "\n";
} else if (role == "assistant") {
ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
}
}
if (add_ass) {
ss << "### Response:\n";
}
} else { } else {
// template not supported // template not supported
return -1; return -1;
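For reference, a minimal sketch of exercising one of the newly supported templates through the public API (the conversation is made up; passing a NULL model with an explicit template string or name is how the bundled chat-template test drives this code path):

#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch: format a short conversation with the "openchat" template added above.
int main() {
    std::vector<llama_chat_message> chat = {
        { "system",    "You are a helpful assistant" },
        { "user",      "Hello"                       },
        { "assistant", "Hi there"                    },
        { "user",      "Another question"            },
    };
    std::vector<char> buf(1024);
    const int32_t n = llama_chat_apply_template(nullptr, "openchat", chat.data(), chat.size(),
                                                /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (n > 0 && n <= (int32_t) buf.size()) {
        printf("%.*s\n", n, buf.data());
    }
    return 0;
}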
@ -979,17 +979,13 @@ struct test_mul_mat_id : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n) // C^T = A * B^T: (k, m) * (k, n) => (m, n)
std::vector<ggml_tensor *> mats; ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
for (int i = 0; i < n_mats; i++) {
ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
mats.push_back(a);
}
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
if (v) { if (v) {
ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0); ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
} }
ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n); ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b); ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
return out; return out;
} }
@ -1477,91 +1473,6 @@ struct test_leaky_relu : public test_case {
} }
}; };
// Mixtral MOE
struct test_moe : public test_case {
const int n_experts;
const int n_experts_per_tok;
const int n_tokens;
const int n_embd;
const int n_ff;
std::string op_desc(ggml_tensor * t) override {
return "MOE";
GGML_UNUSED(t);
}
std::string vars() override {
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
}
test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
: n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
std::vector<ggml_tensor *> ffn_up_exp(n_experts);
std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
std::vector<ggml_tensor *> ffn_down_exp(n_experts);
for (int i = 0; i < n_experts; ++i) {
ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
}
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
weights = ggml_div(ctx, weights, weights_sum);
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_experts_per_tok; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
cur_gate = ggml_silu(ctx, cur_gate);
cur_expert = ggml_mul(ctx, cur_up, cur_gate);
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
cur_expert = ggml_mul(ctx, cur_expert,
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx, moe_out, cur_expert);
}
}
cur = moe_out;
return cur;
}
};
enum llm_norm_type { enum llm_norm_type {
LLM_NORM, LLM_NORM,
LLM_NORM_RMS, LLM_NORM_RMS,
@ -2169,6 +2080,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
} }
test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows());
@ -2182,11 +2094,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
// these tests are disabled to save execution time, but they can be handy for debugging // these tests are disabled to save execution time, but they can be handy for debugging
#if 0 #if 0
#if !defined(__SANITIZE_THREAD__)
// FIXME: these tests use too much memory with thread sanitizer
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
//test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
#endif
test_cases.emplace_back(new test_llama(1)); test_cases.emplace_back(new test_llama(1));
test_cases.emplace_back(new test_llama(2)); test_cases.emplace_back(new test_llama(2));
test_cases.emplace_back(new test_falcon(1)); test_cases.emplace_back(new test_falcon(1));
@ -33,6 +33,18 @@ int main(void) {
"{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}", "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
// OrionStarAI/Orion-14B-Chat // OrionStarAI/Orion-14B-Chat
"{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
// openchat/openchat-3.5-0106
// The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
// So we match against the included template but implement the suggested version.
"{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
// deepseek-ai/deepseek-coder-33b-instruct
"{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
// eachadea/vicuna-13b-1.1
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
"{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
// Orca-Vicuna
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
"{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
}; };
std::vector<std::string> expected_output = { std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B // teknium/OpenHermes-2.5-Mistral-7B
@ -49,6 +61,14 @@ int main(void) {
"<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
// OrionStarAI/Orion-14B-Chat // OrionStarAI/Orion-14B-Chat
"Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
// openchat/openchat-3.5-0106
"You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
// deepseek-ai/deepseek-coder-33b-instruct
"You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
// eachadea/vicuna-13b-1.1
"You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
// Orca-Vicuna
"SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
}; };
std::vector<char> formatted_chat(1024); std::vector<char> formatted_chat(1024);
int32_t res; int32_t res;