Merge branch 'ggerganov:master' into add-inferenceable

commit d3179662cd

18 changed files with 345 additions and 151 deletions
@@ -178,6 +178,7 @@ struct cmd_params {
     std::vector<ggml_type> type_v;
     std::vector<int> n_threads;
     std::vector<int> n_gpu_layers;
+    std::vector<std::string> rpc_servers;
     std::vector<llama_split_mode> split_mode;
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;

@@ -202,6 +203,7 @@ static const cmd_params cmd_params_defaults = {
     /* type_v        */ {GGML_TYPE_F16},
     /* n_threads     */ {cpu_get_num_math()},
     /* n_gpu_layers  */ {99},
+    /* rpc_servers   */ {""},
     /* split_mode    */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},

@@ -230,6 +232,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -ctv, --cache-type-v <t>            (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
     printf("  -t, --threads <n>                   (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -ngl, --n-gpu-layers <n>            (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
+    printf("  -rpc, --rpc <rpc_servers>           (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
     printf("  -sm, --split-mode <none|layer|row>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
     printf("  -mg, --main-gpu <i>                 (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());

@@ -384,6 +387,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<int>(argv[i], split_delim);
             params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
+        } else if (arg == "-rpc" || arg == "--rpc") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rpc_servers.push_back(argv[i]);
         } else if (arg == "-sm" || arg == "--split-mode") {
             if (++i >= argc) {
                 invalid_param = true;

@@ -519,6 +528,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.type_k.empty())        { params.type_k = cmd_params_defaults.type_k; }
     if (params.type_v.empty())        { params.type_v = cmd_params_defaults.type_v; }
     if (params.n_gpu_layers.empty())  { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
+    if (params.rpc_servers.empty())   { params.rpc_servers = cmd_params_defaults.rpc_servers; }
     if (params.split_mode.empty())    { params.split_mode = cmd_params_defaults.split_mode; }
    if (params.main_gpu.empty())      { params.main_gpu = cmd_params_defaults.main_gpu; }
    if (params.no_kv_offload.empty()) { params.no_kv_offload = cmd_params_defaults.no_kv_offload; }

@@ -541,6 +551,7 @@ struct cmd_params_instance {
     ggml_type type_v;
     int n_threads;
     int n_gpu_layers;
+    std::string rpc_servers;
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;

@@ -553,6 +564,9 @@ struct cmd_params_instance {
         llama_model_params mparams = llama_model_default_params();

         mparams.n_gpu_layers = n_gpu_layers;
+        if (!rpc_servers.empty()) {
+            mparams.rpc_servers = rpc_servers.c_str();
+        }
         mparams.split_mode = split_mode;
         mparams.main_gpu = main_gpu;
         mparams.tensor_split = tensor_split.data();
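Note on the hunk above: llama_model_params holds rpc_servers as a raw `const char *`, so the `std::string` member added to cmd_params_instance has to outlive model creation; to_llama_mparams only hands out a pointer into it. A minimal sketch of that lifetime rule (params_holder is a hypothetical stand-in for cmd_params_instance, assuming llama.h is on the include path):

```cpp
#include <string>
#include "llama.h"

// Hypothetical stand-in for cmd_params_instance, showing only the lifetime
// rule: the std::string owns the bytes, mparams.rpc_servers borrows them.
struct params_holder {
    std::string rpc_servers; // e.g. "host1:50052,host2:50052"

    llama_model_params to_mparams() const {
        llama_model_params mparams = llama_model_default_params();
        if (!rpc_servers.empty()) {
            mparams.rpc_servers = rpc_servers.c_str(); // valid only while *this lives
        }
        return mparams;
    }
};
```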
@@ -564,6 +578,7 @@ struct cmd_params_instance {
     bool equal_mparams(const cmd_params_instance & other) const {
         return model == other.model &&
                n_gpu_layers == other.n_gpu_layers &&
+               rpc_servers == other.rpc_servers &&
                split_mode == other.split_mode &&
                main_gpu == other.main_gpu &&
                use_mmap == other.use_mmap &&

@@ -592,6 +607,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     // this ordering minimizes the number of times that each model needs to be reloaded
     for (const auto & m : params.model)
     for (const auto & nl : params.n_gpu_layers)
+    for (const auto & rpc : params.rpc_servers)
     for (const auto & sm : params.split_mode)
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
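The rpc loop is placed with the other model-level loops, and rpc_servers is added to equal_mparams above, so that consecutive instances sharing the same model parameters reuse one loaded model. A hedged sketch of the reuse pattern llama-bench's main loop applies around these instances (the lmodel plumbing here is illustrative):

```cpp
#include <vector>
#include "llama.h"

void run_instances(const std::vector<cmd_params_instance> & params_instances) {
    llama_model * lmodel = nullptr;
    const cmd_params_instance * prev_inst = nullptr;

    for (const auto & inst : params_instances) {
        // reload only when the model parameters (now including rpc_servers) differ
        if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
            if (lmodel) {
                llama_free_model(lmodel);
            }
            lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams());
            prev_inst = &inst;
        }
        // ... run the benchmark for this instance against lmodel ...
    }
    llama_free_model(lmodel);
}
```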
@@ -618,6 +634,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .type_v       = */ tv,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
+            /* .rpc_servers  = */ rpc,
             /* .split_mode   = */ sm,
             /* .main_gpu     = */ mg,
             /* .no_kv_offload= */ nkvo,

@@ -643,6 +660,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .type_v       = */ tv,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
+            /* .rpc_servers  = */ rpc,
             /* .split_mode   = */ sm,
             /* .main_gpu     = */ mg,
             /* .no_kv_offload= */ nkvo,

@@ -668,6 +686,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .type_v       = */ tv,
             /* .n_threads    = */ nt,
             /* .n_gpu_layers = */ nl,
+            /* .rpc_servers  = */ rpc,
             /* .split_mode   = */ sm,
             /* .main_gpu     = */ mg,
             /* .no_kv_offload= */ nkvo,

@@ -692,6 +711,7 @@ struct test {
     static const bool kompute;
     static const bool metal;
     static const bool sycl;
+    static const bool rpc;
     static const bool gpu_blas;
     static const bool blas;
     static const std::string cpu_info;

@@ -790,6 +810,9 @@ struct test {
         if (sycl) {
             return GGML_SYCL_NAME;
         }
+        if (rpc) {
+            return "RPC";
+        }
         if (gpu_blas) {
             return "GPU BLAS";
         }

@@ -803,7 +826,7 @@ struct test {
     static const std::vector<std::string> & get_fields() {
         static const std::vector<std::string> fields = {
             "build_commit", "build_number",
-            "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
+            "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
             "cpu_info", "gpu_info",
             "model_filename", "model_type", "model_size", "model_n_params",
             "n_batch", "n_ubatch",

@@ -859,7 +882,7 @@ struct test {
         std::vector<std::string> values = {
             build_commit, std::to_string(build_number),
             std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(vulkan),
-            std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
+            std::to_string(metal), std::to_string(sycl), std::to_string(rpc), std::to_string(gpu_blas), std::to_string(blas),
             cpu_info, gpu_info,
             model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
             std::to_string(n_batch), std::to_string(n_ubatch),

@@ -894,6 +917,7 @@ const bool test::metal = !!ggml_cpu_has_metal();
 const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
 const bool test::blas = !!ggml_cpu_has_blas();
 const bool test::sycl = !!ggml_cpu_has_sycl();
+const bool test::rpc = !!ggml_cpu_has_rpc();
 const std::string test::cpu_info = get_cpu_info();
 const std::string test::gpu_info = get_gpu_info();
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
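The replaced stride comparison and ggml_is_contiguous_2 express the same property: dims 2 and 3 are packed, so consecutive 2D matrices sit a constant number of bytes apart and one cublasGemmStridedBatchedEx call can cover the whole batch. A small sketch of the predicate (plain arrays instead of ggml_tensor):

```cpp
#include <cstdint>

// True when matrices i and i+1 are always nb[2] bytes apart across the whole
// tensor, i.e. base + i*stride addresses every matrix in the batch.
static bool batch_has_constant_stride(const int64_t ne[4], const uint64_t nb[4]) {
    return nb[3] == nb[2] * (uint64_t) ne[2];
}
```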
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
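Read as control flow rather than a list, the hunk above closes the first case group with its own `return true` and gives GGML_OP_ROPE a dedicated contiguity gate; the cases after it keep falling through to the shared return further down. A flattened sketch (not the full switch; assumes ggml.h is available):

```cpp
#include "ggml.h"

static bool supports_op_sketch(const ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
            return true;                            // group closed off by the new return
        case GGML_OP_ROPE:
            return ggml_is_contiguous(op->src[0]);  // ROPE alone is gated now
        default:
            return false;                           // stand-in for the remaining cases
    }
}
```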
@@ -1,5 +1,6 @@
 #include "concat.cuh"

+// contiguous kernels
 static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {

@@ -92,39 +93,104 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
     concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }

+// non-contiguous kernel (slow)
+static __global__ void concat_f32_non_cont(
+        const char * src0,
+        const char * src1,
+        char       * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne03,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        uint64_t nb03,
+        int64_t /*ne10*/,
+        int64_t /*ne11*/,
+        int64_t /*ne12*/,
+        int64_t /*ne13*/,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        uint64_t nb13,
+        int64_t ne0,
+        int64_t /*ne1*/,
+        int64_t /*ne2*/,
+        int64_t /*ne3*/,
+        uint64_t nb0,
+        uint64_t nb1,
+        uint64_t nb2,
+        uint64_t nb3,
+        int32_t dim) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    const float * x;
+
+    for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
 void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     const int32_t dim = ((int32_t *) dst->op_params)[0];

-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    if (dim != 3) {
-        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-            concat_f32_cuda(
-                src0_d + i3 * (src0->nb[3] / 4),
-                src1_d + i3 * (src1->nb[3] / 4),
-                dst_d + i3 * ( dst->nb[3] / 4),
-                src0->ne[0], src0->ne[1], src0->ne[2],
-                dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        const float * src0_d = (const float *)src0->data;
+        const float * src1_d = (const float *)src1->data;
+
+        float * dst_d = (float *)dst->data;
+
+        if (dim != 3) {
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                        src0_d + i3 * (src0->nb[3] / 4),
+                        src1_d + i3 * (src1->nb[3] / 4),
+                        dst_d + i3 * ( dst->nb[3] / 4),
+                        src0->ne[0], src0->ne[1], src0->ne[2],
+                        dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+            }
+        } else {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+
+            CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
         }
     } else {
-        const size_t size0 = ggml_nbytes(src0);
-        const size_t size1 = ggml_nbytes(src1);
-
-        CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-        CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+        concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
     }
 }
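The non-contiguous kernel decides per output element which input it falls in: inside src0's extent it reads src0 directly, otherwise it reads src1 shifted back by src0's size along the concat dimension, which is all the `o[4]` array encodes. A CPU reference of the same indexing, useful for checking the rule (a sketch, not part of the patch):

```cpp
#include <cstdint>

// ne0/nb0: src0 shape and byte strides; nb1: src1 strides; ne/nb: dst.
static void concat_ref(const char * src0, const char * src1, char * dst,
                       const int64_t ne0[4], const uint64_t nb0[4],
                       const uint64_t nb1[4],
                       const int64_t ne[4], const uint64_t nb[4], int dim) {
    int64_t o[4] = {0, 0, 0, 0};
    o[dim] = ne0[dim]; // past src0's extent along `dim`, indices map into src1

    for (int64_t i3 = 0; i3 < ne[3]; ++i3)
    for (int64_t i2 = 0; i2 < ne[2]; ++i2)
    for (int64_t i1 = 0; i1 < ne[1]; ++i1)
    for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
        const float * x;
        if (i0 < ne0[0] && i1 < ne0[1] && i2 < ne0[2] && i3 < ne0[3]) {
            x = (const float *)(src0 + i3*nb0[3] + i2*nb0[2] + i1*nb0[1] + i0*nb0[0]);
        } else {
            x = (const float *)(src1 + (i3 - o[3])*nb1[3] + (i2 - o[2])*nb1[2]
                                     + (i1 - o[1])*nb1[1] + (i0 - o[0])*nb1[0]);
        }
        *(float *)(dst + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0]) = *x;
    }
}
```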
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -61,7 +61,7 @@ static __global__ void rope(
 template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
     const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;

-    float cur_rot = inv_ndims * ic - ib;
-
     const int p = has_pos ? pos[i2] : 0;
     const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
+    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];

@@ -174,30 +172,29 @@ static void rope_neox_cuda(
     const dim3 block_nums(nrows, num_blocks_x, 1);

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;

     if (pos == nullptr) {
         if (freq_factors == nullptr) {
             rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         } else {
             rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         }
     } else {
         if (freq_factors == nullptr) {
             rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         } else {
             rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         }
     }
@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);

@@ -1597,7 +1597,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                         {
                             GGML_ASSERT(ne00 == ne10);

-                            // TODO: assert that dim2 and dim3 are contiguous
                             GGML_ASSERT(ne12 % ne02 == 0);
                             GGML_ASSERT(ne13 % ne03 == 0);

@@ -1519,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(ne00 == ne10);

-                        // TODO: assert that dim2 and dim3 are contiguous
                         GGML_ASSERT(ne12 % ne02 == 0);
                         GGML_ASSERT(ne13 % ne03 == 0);

@@ -2187,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));

                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));

@@ -2214,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_GROUP_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous(src0));

                     //float eps;
                     //memcpy(&eps, dst->op_params, sizeof(float));

@@ -2247,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(

     const int64_t p = pos[i2];

-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;

     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;

-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                const float cur_rot = inv_ndims*ic - ib;
-                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

-                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);

                 float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

-                const int64_t i0 = ib*n_dims + ic/2;
-
                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
ggml.c: 115 changes
@@ -60,6 +60,9 @@

 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0

 static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);

@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
 static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}

 typedef HANDLE pthread_t;

@@ -2883,24 +2892,20 @@ struct ggml_state {

 // global state
 static struct ggml_state g_state;
-static atomic_int g_state_barrier = 0;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;

 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
-    int processing = atomic_fetch_add(&g_state_barrier, 1);
-
-    while (processing > 0) {
-        // wait for other threads to finish
-        atomic_fetch_sub(&g_state_barrier, 1);
-        sched_yield(); // TODO: reconsider this
-        processing = atomic_fetch_add(&g_state_barrier, 1);
+    while (atomic_flag_test_and_set(&g_state_critical)) {
+        // spin
+        sched_yield();
     }
 }

 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
-    atomic_fetch_sub(&g_state_barrier, 1);
+    atomic_flag_clear(&g_state_critical);
 }

 #if defined(__gnu_linux__)
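The old barrier spun by repeatedly incrementing and decrementing a shared counter, two read-modify-write operations per spin iteration; an atomic_flag test-and-set needs one and is the textbook spinlock. A portable C++ equivalent of the new critical section (a sketch; the patch itself uses C11-style atomic_flag, with the Win32 shim added earlier in the file):

```cpp
#include <atomic>
#include <thread>

static std::atomic_flag g_critical = ATOMIC_FLAG_INIT;

static void critical_start() {
    // test_and_set returns the previous value: true means another thread holds the lock
    while (g_critical.test_and_set(std::memory_order_acquire)) {
        std::this_thread::yield(); // spin, like sched_yield() in the patch
    }
}

static void critical_end() {
    g_critical.clear(std::memory_order_release);
}
```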
@@ -3216,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

     return

@@ -3225,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

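The three predicates form a ladder: each level frees one more stride from the packing requirement, so callers can assert exactly what they need. A sketch of the checks on plain arrays (f32 only, so the element size is a fixed 4 bytes; mirrors the definitions above):

```cpp
#include <cstdint>

struct dims { int64_t ne[4]; uint64_t nb[4]; };

static bool contiguous_0(const dims & t) { // fully packed (== ggml_is_contiguous)
    return t.nb[0] == 4 && t.nb[1] == t.nb[0]*t.ne[0] &&
           t.nb[2] == t.nb[1]*t.ne[1] && t.nb[3] == t.nb[2]*t.ne[2];
}
static bool contiguous_1(const dims & t) { // nb[1] free: padded rows are fine
    return t.nb[0] == 4 && t.nb[2] == t.nb[1]*t.ne[1] && t.nb[3] == t.nb[2]*t.ne[2];
}
static bool contiguous_2(const dims & t) { // nb[1], nb[2] free: only dims 2,3 chained
    return t.nb[0] == 4 && t.nb[3] == t.nb[2]*t.ne[2];
}
```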
@@ -6392,6 +6409,16 @@ struct ggml_tensor * ggml_rope_custom_inplace(
     );
 }

+struct ggml_tensor * ggml_rope_xpos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        float                 base,
+        bool                  down) {
+    return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+}
+
 // ggml_rope_back

 struct ggml_tensor * ggml_rope_back(
@@ -11405,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11468,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11531,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {

@@ -11643,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
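Renaming the file-local ggml_is_contiguous_except_dim_1 to the public ggml_is_contiguous_1 does not change what these unary ops require: they walk the tensor row by row through the dim-1 byte stride, so rows may be padded as long as each row itself is packed. A sketch of that access pattern:

```cpp
#include <cstdint>

// f(x) applied row-wise; src_nb1/dst_nb1 are the dim-1 byte strides and may
// include padding, which is exactly what ggml_is_contiguous_1 permits.
static void unary_rows_f32(float (*f)(float), const char * src, char * dst,
                           int64_t nrows, int64_t ne0,
                           uint64_t src_nb1, uint64_t dst_nb1) {
    for (int64_t r = 0; r < nrows; ++r) {
        const float * x = (const float *)(src + r*src_nb1);
        float       * y = (float       *)(dst + r*dst_nb1);
        for (int64_t i = 0; i < ne0; ++i) {
            y[i] = f(x[i]); // within a row, elements are packed (nb[0] == sizeof(float))
        }
    }
}
```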
@@ -14343,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

@@ -14392,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
     const float cos_block_theta = cosf(block_theta);
     const float sin_block_theta = sinf(block_theta) * sin_sign;

     theta_base *= theta_scale;
     block_theta *= theta_scale;

     const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

@@ -14427,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
             dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
         }
     } else {
-        // TODO: this might be wrong for ne0 != n_dims - need double check
-        // it seems we have to rope just the first n_dims elements and do nothing with the rest
-        // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-        theta_base *= freq_scale;
+        // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
         for (int64_t ic = 0; ic < ne0; ic += 2) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;

-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                float cur_rot = inv_ndims * ic - ib;
-
-                float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

                 float cos_theta, sin_theta;
                 rope_yarn(
-                    theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                    theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                     &cos_theta, &sin_theta
                 );
-                sin_theta *= sin_sign;

+                sin_theta *= sin_sign;
                 theta_base *= theta_scale;

-                const int64_t i0 = ib*n_dims + ic/2;
-
                 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -14528,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

@@ -14577,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
     const float cos_block_theta = cosf(block_theta);
     const float sin_block_theta = sinf(block_theta) * sin_sign;

     theta_base *= theta_scale;
     block_theta *= theta_scale;

     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

@@ -14608,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
             dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
         }
     } else {
-        // TODO: this might be wrong for ne0 != n_dims - need double check
-        // it seems we have to rope just the first n_dims elements and do nothing with the rest
-        // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-        theta_base *= freq_scale;
+        // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
         for (int64_t ic = 0; ic < ne0; ic += 2) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;

-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                float cur_rot = inv_ndims * ic - ib;
-
-                float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

                 float cos_theta, sin_theta;
                 rope_yarn(
-                    theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                    theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                     &cos_theta, &sin_theta
                 );
-                sin_theta *= sin_sign;

+                sin_theta *= sin_sign;
                 theta_base *= theta_scale;

-                const int64_t i0 = ib*n_dims + ic/2;
-
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -22857,6 +22870,14 @@ int ggml_cpu_has_sycl(void) {
 #endif
 }

+int ggml_cpu_has_rpc(void) {
+#if defined(GGML_USE_RPC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_gpublas(void) {
     return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
            ggml_cpu_has_sycl();
ggml.h: 15 changes
@@ -756,7 +756,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty   (const struct ggml_tensor * tensor);
     GGML_API           bool ggml_is_scalar  (const struct ggml_tensor * tensor);

@@ -765,6 +764,11 @@ extern "C" {
     GGML_API           bool ggml_is_3d        (const struct ggml_tensor * tensor);
     GGML_API           int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars

+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

@@ -1548,6 +1552,14 @@ extern "C" {
             float                 beta_slow),
         "use ggml_rope_ext_inplace instead");

+    struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            float                 base,
+            bool                  down);
+
     // compute correction dims for YaRN RoPE scaling
     GGML_CALL void ggml_rope_yarn_corr_dims(
         int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
@@ -2420,6 +2432,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_sse3       (void);
     GGML_API int ggml_cpu_has_ssse3      (void);
     GGML_API int ggml_cpu_has_sycl       (void);
+    GGML_API int ggml_cpu_has_rpc        (void);
     GGML_API int ggml_cpu_has_vsx        (void);
     GGML_API int ggml_cpu_has_matmul_int8(void);
@@ -2670,14 +2670,12 @@ void main() {
     const uint i  = row*p.ncols + ib*p.ndims + ic/2;
     const uint i2 = row/p.p_delta_rows;

-    const float cur_rot = p.inv_ndims * ic - ib;
-
     const int pos = data_b[i2];
     const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
     const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;

     float cos_theta, sin_theta;
-    rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
+    rope_yarn(theta_base, ic, cos_theta, sin_theta);

     const float x0 = float(data_a[i + 0]);
     const float x1 = float(data_a[i + p.ndims/2]);
@@ -144,6 +144,7 @@ def main() -> None:
     parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."')
     parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
     parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
+    parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"')
     parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
     parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
     parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))

@@ -172,6 +173,9 @@ def main() -> None:
     if template:
         new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

+    if args.pre_tokenizer:
+        new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer)
+
     if remove_metadata:
         logger.warning('*** Warning *** Warning *** Warning **')
         logger.warning('* Most metadata is required for a fully functional GGUF file,')
llama.cpp: 52 changes
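The hunk below rewrites the DeepSeek-V2 attention views to compute byte offsets with ggml_row_size(type, n) instead of ggml_element_size(t) * n. For block-quantized types n elements occupy n/block_size blocks of type_size bytes, so the element-size product over-counts; ggml_row_size accounts for the block layout, and for f32 the two agree. A sketch of the distinction (block_size and type_size stand in for the ggml per-type constants):

```cpp
#include <cstddef>
#include <cstdint>

// What ggml_row_size computes, in spirit: n must be a multiple of block_size.
static size_t row_size_sketch(size_t type_size, int64_t block_size, int64_t n) {
    return (size_t)(n / block_size) * type_size;
}
// f32: block_size == 1, type_size == 4  -> n*4, same as element_size*n.
// For a 32-element block type the results differ, and element_size-based
// strides would be wrong.
```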
@@ -11187,46 +11187,69 @@ struct llm_build_context {
         }
 
         // split into {n_head * n_embd_head_qk_nope, n_tokens}
-        struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
+        struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                ggml_row_size(q->type, hparams.n_embd_head_k),
+                ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                0);
         cb(q_nope, "q_nope", il);
 
         // and {n_head * n_embd_head_qk_rope, n_tokens}
-        struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope);
+        struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                ggml_row_size(q->type, hparams.n_embd_head_k),
+                ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                ggml_row_size(q->type, n_embd_head_qk_nope));
         cb(q_pe, "q_pe", il);
 
         // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-        struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-        cb(compressed_kv_pe, "compressed_kv_pe", il);
+        struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+        cb(kv_pe_compresseed, "kv_pe_compresseed", il);
 
         // split into {kv_lora_rank, n_tokens}
-        struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
-        cb(compressed_kv, "compressed_kv", il);
+        struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                kv_pe_compresseed->nb[1],
+                0);
+        cb(kv_compressed, "kv_compressed", il);
 
         // and {n_embd_head_qk_rope, n_tokens}
-        struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
+        struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                kv_pe_compresseed->nb[1],
+                kv_pe_compresseed->nb[1],
+                ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
         cb(k_pe, "k_pe", il);
 
-        compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
+        kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+        kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                 model.layers[il].attn_kv_a_norm, NULL,
                 LLM_NORM_RMS, cb, il);
-        cb(compressed_kv, "compressed_kv", il);
+        cb(kv_compressed, "kv_compressed", il);
 
         // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-        struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
+        struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
         cb(kv, "kv", il);
 
         // split into {n_head * n_embd_head_qk_nope, n_tokens}
-        struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+        struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                0);
         cb(k_nope, "k_nope", il);
 
         // and {n_head * n_embd_head_v, n_tokens}
-        struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope);
+        struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                ggml_row_size(kv->type, (n_embd_head_qk_nope)));
         cb(v_states, "v_states", il);
 
         v_states = ggml_cont(ctx0, v_states);
         cb(v_states, "v_states", il);
 
-        v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
+        v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                0);
         cb(v_states, "v_states", il);
 
+        q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
         q_pe = ggml_rope_ext(
                 ctx0, q_pe, inp_pos, nullptr,
                 n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -11235,8 +11258,9 @@ struct llm_build_context {
         cb(q_pe, "q_pe", il);
 
         // shared RoPE key
+        k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
         k_pe = ggml_rope_ext(
-                ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
+                ctx0, k_pe, inp_pos, nullptr,
                 n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                 ext_factor, attn_factor_scaled, beta_fast, beta_slow
         );
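The two ggml_cont insertions follow one pattern: q_pe and k_pe are strided views into larger tensors, and (per the TODO comments) the CUDA RoPE kernel expects contiguous input. A sketch of that pattern, mirroring the ggml_rope_ext argument order used in this diff; the helper name rope_contig and the trailing YaRN parameters are illustrative, not part of the change:

// Hypothetical helper: make a (possibly strided) view safe for backends whose
// RoPE kernel requires contiguous input, by materializing it with ggml_cont.
static struct ggml_tensor * rope_contig(
        struct ggml_context * ctx,
        struct ggml_tensor  * t,    // possibly a non-contiguous view
        struct ggml_tensor  * pos,  // per-token positions (I32)
        int n_rot, int rope_type, int n_ctx, int n_orig_ctx,
        float freq_base, float freq_scale) {
    if (!ggml_is_contiguous(t)) {
        t = ggml_cont(ctx, t); // adds a copy node; its output is contiguous
    }
    return ggml_rope_ext(ctx, t, pos, nullptr,
            n_rot, rope_type, n_ctx, n_orig_ctx, freq_base, freq_scale,
            0.0f, 1.0f, 32.0f, 1.0f); // ext_factor, attn_factor, beta_fast, beta_slow
}

The diff applies ggml_cont unconditionally, trading a copy for portability; a conditional version like the above would skip the copy whenever the view happens to be contiguous already.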
@@ -106,8 +106,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
     # src/ggml-kompute.h -> ggml-kompute.h
     # src/ggml-metal.h   -> ggml-metal.h
     # src/ggml-metal.m   -> ggml-metal.m
-    # src/ggml-mpi.h     -> ggml-mpi.h
-    # src/ggml-mpi.c     -> ggml-mpi.c
     # src/ggml-opencl.cpp -> ggml-opencl.cpp
     # src/ggml-opencl.h  -> ggml-opencl.h
     # src/ggml-quants.c  -> ggml-quants.c
@@ -145,8 +143,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
        -e 's/src\/ggml-kompute\.h/ggml-kompute.h/g' \
        -e 's/src\/ggml-metal\.h/ggml-metal.h/g' \
        -e 's/src\/ggml-metal\.m/ggml-metal.m/g' \
-       -e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \
-       -e 's/src\/ggml-mpi\.c/ggml-mpi.c/g' \
        -e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \
        -e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \
        -e 's/src\/ggml-quants\.c/ggml-quants.c/g' \
@@ -1 +1 @@
-126d34985705a5a2222723c145cb4e125ac689f3
+2aae01fd9b8f9399f343cf18f46f38996ef52e2c
@@ -14,8 +14,6 @@ cp -rpv ../ggml/src/ggml-kompute.h ./ggml-kompute.h
 cp -rpv ../ggml/src/ggml-metal.h     ./ggml-metal.h
 cp -rpv ../ggml/src/ggml-metal.m     ./ggml-metal.m
 cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
-cp -rpv ../ggml/src/ggml-mpi.h       ./ggml-mpi.h
-cp -rpv ../ggml/src/ggml-mpi.c       ./ggml-mpi.c
 cp -rpv ../ggml/src/ggml-opencl.cpp  ./ggml-opencl.cpp
 cp -rpv ../ggml/src/ggml-opencl.h    ./ggml-opencl.h
 cp -rpv ../ggml/src/ggml-quants.c    ./ggml-quants.c
@@ -1138,26 +1138,37 @@ struct test_soft_max : public test_case {
 // GGML_OP_ROPE
 struct test_rope : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_a;
     int n_dims;
     int mode;
     int n_ctx;
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
     bool ff;
+    int v; // view (1 : non-contiguous a)
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
+        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        }
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
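The v & 1 branch above is the trick the test uses to get a genuinely non-contiguous operand: over-allocate a parent tensor, then take a view with the target shape but the parent's (larger) strides. A standalone sketch of that pattern; the helper make_noncont is illustrative and not part of the test suite:

#include "ggml.h"
#include <array>

// The returned view has shape ne_a but inherits the row strides (nb[1..3]) of
// a 2x/4x/3x larger parent, so its elements are strided rather than packed.
static ggml_tensor * make_noncont(ggml_context * ctx, ggml_type type,
                                  std::array<int64_t, 4> ne_a) {
    auto ne = ne_a;
    ne[0] *= 2; ne[1] *= 4; ne[2] *= 3; // arbitrary >1 padding factors
    ggml_tensor * parent = ggml_new_tensor(ctx, type, 4, ne.data());
    return ggml_view_4d(ctx, parent,
            ne_a[0], ne_a[1], ne_a[2], ne_a[3],
            parent->nb[1], parent->nb[2], parent->nb[3], 0); // parent strides
}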
@@ -1165,11 +1176,11 @@ struct test_rope : public test_case {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // pos
-                std::vector<int> data(ne[2]);
-                for (int i = 0; i < ne[2]; i++) {
+                std::vector<int> data(ne_a[2]);
+                for (int i = 0; i < ne_a[2]; i++) {
                     data[i] = rand() % n_ctx;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
             } else {
                 if (t->ne[0] == n_dims/2) {
                     // frequency factors in the range [0.9f, 1.1f]
@@ -1262,22 +1273,37 @@ struct test_concat : public test_case {
     const std::array<int64_t, 4> ne_a;
     const int64_t ne_b_d;
     const int dim;
+    const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
+        return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
             int64_t ne_b_d = 10,
-            int dim = 2)
-        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
+            int dim = 2, int v = 0)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         auto ne_b = ne_a;
         ne_b[dim] = ne_b_d;
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        }
+        ggml_tensor * b;
+        if (v & 2) {
+            auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
+        } else {
+            b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        }
         ggml_tensor * out = ggml_concat(ctx, a, b, dim);
         return out;
     }
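test_concat generalizes the same idea to both operands via a bitmask: bit 0 requests a non-contiguous a, bit 1 a non-contiguous b, so v in {0, 1, 2, 3} covers all four contiguity combinations. A tiny sketch of the decoding:

#include <cstdio>

int main(void) {
    for (int v = 0; v < 4; ++v) {
        const bool noncont_a = (v & (1 << 0)) != 0;
        const bool noncont_b = (v & (1 << 1)) != 0;
        printf("v=%d: a %s, b %s\n", v,
               noncont_a ? "non-contiguous" : "contiguous",
               noncont_b ? "non-contiguous" : "contiguous");
    }
    return 0;
}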
@@ -2198,26 +2224,46 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        // TODO: ff not supported yet for !neox
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
-
-        for (bool ff : {false, true}) { // freq_factors
-            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+    {
+        bool all = true;
+
+        for (float v : { 0, 1 }) {
+            for (float fs : { 1.0f, 1.4245f }) {
+                for (float ef : { 0.0f, 0.7465f }) {
+                    for (float af : { 1.0f, 1.4245f }) {
+                        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                            // TODO: ff not supported yet for !neox
+                            test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
+                                test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
+                                test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
+                            }
+
+                            for (bool ff : {false, true}) { // freq_factors
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
+                                }
+
+                                test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                            }
+                        }
+
+                        all = false;
+                    }
+                }
+            }
         }
     }
 
-    for (int dim : { 0, 1, 2, 3, }) {
-        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14},  7, dim));
-        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14},  7, dim));
+    for (int v : { 0, 1, 2, 3 }) {
+        for (int dim : { 0, 1, 2, 3, }) {
+            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14},  7, dim, v));
+            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14},  7, dim, v));
+        }
    }
 
     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
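A note on the new all flag above: the rope cases now sit under 2 view modes x 2 freq_scale x 2 ext_factor x 2 attn_factor = 16 parameter combinations, each run for both F32 and F16. all is cleared once the first combination completes, so only that one enumerates the full list of model shapes; the remaining 15 keep a single representative shape per rope mode (the llama-7B shape plus the falcon-40B neox shape under both freq_factors settings), which caps the growth in total test cases while still exercising every parameter value at least once.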