Merge branch 'master' into auto-model-support

This commit is contained in:
teleprint-me 2024-05-29 12:46:06 -04:00
commit de0f0d0016
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48
9 changed files with 186 additions and 53 deletions

View file

@ -55,8 +55,8 @@ It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS,
## OS ## OS
| OS | Status | Verified | | OS | Status | Verified |
|---------|---------|------------------------------------| |---------|---------|------------------------------------------------|
| Linux | Support | Ubuntu 22.04, Fedora Silverblue 39 | | Linux | Support | Ubuntu 22.04, Fedora Silverblue 39, Arch Linux |
| Windows | Support | Windows 11 | | Windows | Support | Windows 11 |
@ -70,7 +70,7 @@ It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS,
|-------------------------------|---------|---------------------------------------| |-------------------------------|---------|---------------------------------------|
| Intel Data Center Max Series | Support | Max 1550, 1100 | | Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 | | Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M | | Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake | | Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
| Intel iGPU | Support | iGPU in i5-1250P, i7-1260P, i7-1165G7 | | Intel iGPU | Support | iGPU in i5-1250P, i7-1260P, i7-1165G7 |

View file

@ -178,6 +178,7 @@ struct cmd_params {
std::vector<ggml_type> type_v; std::vector<ggml_type> type_v;
std::vector<int> n_threads; std::vector<int> n_threads;
std::vector<int> n_gpu_layers; std::vector<int> n_gpu_layers;
std::vector<std::string> rpc_servers;
std::vector<llama_split_mode> split_mode; std::vector<llama_split_mode> split_mode;
std::vector<int> main_gpu; std::vector<int> main_gpu;
std::vector<bool> no_kv_offload; std::vector<bool> no_kv_offload;
@ -202,6 +203,7 @@ static const cmd_params cmd_params_defaults = {
/* type_v */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16},
/* n_threads */ {cpu_get_num_math()}, /* n_threads */ {cpu_get_num_math()},
/* n_gpu_layers */ {99}, /* n_gpu_layers */ {99},
/* rpc_servers */ {""},
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
/* main_gpu */ {0}, /* main_gpu */ {0},
/* no_kv_offload */ {false}, /* no_kv_offload */ {false},
@ -230,6 +232,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -ctv, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -ctv, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
@ -384,6 +387,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
} }
auto p = split<int>(argv[i], split_delim); auto p = split<int>(argv[i], split_delim);
params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end()); params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
} else if (arg == "-rpc" || arg == "--rpc") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rpc_servers.push_back(argv[i]);
} else if (arg == "-sm" || arg == "--split-mode") { } else if (arg == "-sm" || arg == "--split-mode") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -519,6 +528,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
if (params.rpc_servers.empty()) { params.rpc_servers = cmd_params_defaults.rpc_servers; }
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
@ -541,6 +551,7 @@ struct cmd_params_instance {
ggml_type type_v; ggml_type type_v;
int n_threads; int n_threads;
int n_gpu_layers; int n_gpu_layers;
std::string rpc_servers;
llama_split_mode split_mode; llama_split_mode split_mode;
int main_gpu; int main_gpu;
bool no_kv_offload; bool no_kv_offload;
@ -553,6 +564,9 @@ struct cmd_params_instance {
llama_model_params mparams = llama_model_default_params(); llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = n_gpu_layers; mparams.n_gpu_layers = n_gpu_layers;
if (!rpc_servers.empty()) {
mparams.rpc_servers = rpc_servers.c_str();
}
mparams.split_mode = split_mode; mparams.split_mode = split_mode;
mparams.main_gpu = main_gpu; mparams.main_gpu = main_gpu;
mparams.tensor_split = tensor_split.data(); mparams.tensor_split = tensor_split.data();
@ -564,6 +578,7 @@ struct cmd_params_instance {
bool equal_mparams(const cmd_params_instance & other) const { bool equal_mparams(const cmd_params_instance & other) const {
return model == other.model && return model == other.model &&
n_gpu_layers == other.n_gpu_layers && n_gpu_layers == other.n_gpu_layers &&
rpc_servers == other.rpc_servers &&
split_mode == other.split_mode && split_mode == other.split_mode &&
main_gpu == other.main_gpu && main_gpu == other.main_gpu &&
use_mmap == other.use_mmap && use_mmap == other.use_mmap &&
@ -592,6 +607,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
// this ordering minimizes the number of times that each model needs to be reloaded // this ordering minimizes the number of times that each model needs to be reloaded
for (const auto & m : params.model) for (const auto & m : params.model)
for (const auto & nl : params.n_gpu_layers) for (const auto & nl : params.n_gpu_layers)
for (const auto & rpc : params.rpc_servers)
for (const auto & sm : params.split_mode) for (const auto & sm : params.split_mode)
for (const auto & mg : params.main_gpu) for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split) for (const auto & ts : params.tensor_split)
@ -618,6 +634,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .type_v = */ tv, /* .type_v = */ tv,
/* .n_threads = */ nt, /* .n_threads = */ nt,
/* .n_gpu_layers = */ nl, /* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm, /* .split_mode = */ sm,
/* .main_gpu = */ mg, /* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo, /* .no_kv_offload= */ nkvo,
@ -643,6 +660,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .type_v = */ tv, /* .type_v = */ tv,
/* .n_threads = */ nt, /* .n_threads = */ nt,
/* .n_gpu_layers = */ nl, /* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm, /* .split_mode = */ sm,
/* .main_gpu = */ mg, /* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo, /* .no_kv_offload= */ nkvo,
@ -668,6 +686,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .type_v = */ tv, /* .type_v = */ tv,
/* .n_threads = */ nt, /* .n_threads = */ nt,
/* .n_gpu_layers = */ nl, /* .n_gpu_layers = */ nl,
/* .rpc_servers = */ rpc,
/* .split_mode = */ sm, /* .split_mode = */ sm,
/* .main_gpu = */ mg, /* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo, /* .no_kv_offload= */ nkvo,
@ -692,6 +711,7 @@ struct test {
static const bool kompute; static const bool kompute;
static const bool metal; static const bool metal;
static const bool sycl; static const bool sycl;
static const bool rpc;
static const bool gpu_blas; static const bool gpu_blas;
static const bool blas; static const bool blas;
static const std::string cpu_info; static const std::string cpu_info;
@ -790,6 +810,9 @@ struct test {
if (sycl) { if (sycl) {
return GGML_SYCL_NAME; return GGML_SYCL_NAME;
} }
if (rpc) {
return "RPC";
}
if (gpu_blas) { if (gpu_blas) {
return "GPU BLAS"; return "GPU BLAS";
} }
@ -803,7 +826,7 @@ struct test {
static const std::vector<std::string> & get_fields() { static const std::vector<std::string> & get_fields() {
static const std::vector<std::string> fields = { static const std::vector<std::string> fields = {
"build_commit", "build_number", "build_commit", "build_number",
"cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas", "cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
"cpu_info", "gpu_info", "cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params", "model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_ubatch", "n_batch", "n_ubatch",
@ -859,7 +882,7 @@ struct test {
std::vector<std::string> values = { std::vector<std::string> values = {
build_commit, std::to_string(build_number), build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(vulkan), std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(vulkan),
std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas), std::to_string(metal), std::to_string(sycl), std::to_string(rpc), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info, cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_batch), std::to_string(n_ubatch),
@ -894,6 +917,7 @@ const bool test::metal = !!ggml_cpu_has_metal();
const bool test::gpu_blas = !!ggml_cpu_has_gpublas(); const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
const bool test::blas = !!ggml_cpu_has_blas(); const bool test::blas = !!ggml_cpu_has_blas();
const bool test::sycl = !!ggml_cpu_has_sycl(); const bool test::sycl = !!ggml_cpu_has_sycl();
const bool test::rpc = !!ggml_cpu_has_rpc();
const std::string test::cpu_info = get_cpu_info(); const std::string test::cpu_info = get_cpu_info();
const std::string test::gpu_info = get_gpu_info(); const std::string test::gpu_info = get_gpu_info();

View file

@ -1,5 +1,6 @@
#include "concat.cuh" #include "concat.cuh"
// contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) { if (nidx >= ne0) {
@ -92,25 +93,77 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02); concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
} }
// non-contiguous kernel (slow)
static __global__ void concat_f32_non_cont(
const char * src0,
const char * src1,
char * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne03,
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
uint64_t nb03,
int64_t /*ne10*/,
int64_t /*ne11*/,
int64_t /*ne12*/,
int64_t /*ne13*/,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
uint64_t nb13,
int64_t ne0,
int64_t /*ne1*/,
int64_t /*ne2*/,
int64_t /*ne3*/,
uint64_t nb0,
uint64_t nb1,
uint64_t nb2,
uint64_t nb3,
int32_t dim) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
int64_t o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
const float * x;
for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
}
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*y = *x;
}
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const int32_t dim = ((int32_t *) dst->op_params)[0]; const int32_t dim = ((int32_t *) dst->op_params)[0];
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
if (dim != 3) { if (dim != 3) {
for (int i3 = 0; i3 < dst->ne[3]; i3++) { for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda( concat_f32_cuda(
@ -127,4 +180,17 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
} }
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
(const char *)src0->data,
(const char *)src1->data,
( char *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
}
} }

41
ggml.c
View file

@ -60,6 +60,9 @@
typedef volatile LONG atomic_int; typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool; typedef atomic_int atomic_bool;
typedef atomic_int atomic_flag;
#define ATOMIC_FLAG_INIT 0
static void atomic_store(atomic_int * ptr, LONG val) { static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val); InterlockedExchange(ptr, val);
@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec)); return atomic_fetch_add(ptr, -(dec));
} }
static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
return InterlockedExchange(ptr, 1);
}
static void atomic_flag_clear(atomic_flag * ptr) {
InterlockedExchange(ptr, 0);
}
typedef HANDLE pthread_t; typedef HANDLE pthread_t;
@ -2883,24 +2892,20 @@ struct ggml_state {
// global state // global state
static struct ggml_state g_state; static struct ggml_state g_state;
static atomic_int g_state_barrier = 0; static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
// barrier via spin lock // barrier via spin lock
inline static void ggml_critical_section_start(void) { inline static void ggml_critical_section_start(void) {
int processing = atomic_fetch_add(&g_state_barrier, 1); while (atomic_flag_test_and_set(&g_state_critical)) {
// spin
while (processing > 0) { sched_yield();
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield(); // TODO: reconsider this
processing = atomic_fetch_add(&g_state_barrier, 1);
} }
} }
// TODO: make this somehow automatically executed // TODO: make this somehow automatically executed
// some sort of "sentry" mechanism // some sort of "sentry" mechanism
inline static void ggml_critical_section_end(void) { inline static void ggml_critical_section_end(void) {
atomic_fetch_sub(&g_state_barrier, 1); atomic_flag_clear(&g_state_critical);
} }
#if defined(__gnu_linux__) #if defined(__gnu_linux__)
@ -6392,6 +6397,16 @@ struct ggml_tensor * ggml_rope_custom_inplace(
); );
} }
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_rope_back // ggml_rope_back
struct ggml_tensor * ggml_rope_back( struct ggml_tensor * ggml_rope_back(
@ -22857,6 +22872,14 @@ int ggml_cpu_has_sycl(void) {
#endif #endif
} }
int ggml_cpu_has_rpc(void) {
#if defined(GGML_USE_RPC)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_gpublas(void) { int ggml_cpu_has_gpublas(void) {
return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
ggml_cpu_has_sycl(); ggml_cpu_has_sycl();

9
ggml.h
View file

@ -1548,6 +1548,14 @@ extern "C" {
float beta_slow), float beta_slow),
"use ggml_rope_ext_inplace instead"); "use ggml_rope_ext_inplace instead");
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down);
// compute correction dims for YaRN RoPE scaling // compute correction dims for YaRN RoPE scaling
GGML_CALL void ggml_rope_yarn_corr_dims( GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
@ -2420,6 +2428,7 @@ extern "C" {
GGML_API int ggml_cpu_has_sse3 (void); GGML_API int ggml_cpu_has_sse3 (void);
GGML_API int ggml_cpu_has_ssse3 (void); GGML_API int ggml_cpu_has_ssse3 (void);
GGML_API int ggml_cpu_has_sycl (void); GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_rpc (void);
GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_vsx (void);
GGML_API int ggml_cpu_has_matmul_int8(void); GGML_API int ggml_cpu_has_matmul_int8(void);

View file

@ -106,8 +106,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-kompute.h -> ggml-kompute.h # src/ggml-kompute.h -> ggml-kompute.h
# src/ggml-metal.h -> ggml-metal.h # src/ggml-metal.h -> ggml-metal.h
# src/ggml-metal.m -> ggml-metal.m # src/ggml-metal.m -> ggml-metal.m
# src/ggml-mpi.h -> ggml-mpi.h
# src/ggml-mpi.c -> ggml-mpi.c
# src/ggml-opencl.cpp -> ggml-opencl.cpp # src/ggml-opencl.cpp -> ggml-opencl.cpp
# src/ggml-opencl.h -> ggml-opencl.h # src/ggml-opencl.h -> ggml-opencl.h
# src/ggml-quants.c -> ggml-quants.c # src/ggml-quants.c -> ggml-quants.c
@ -145,8 +143,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/src\/ggml-kompute\.h/ggml-kompute.h/g' \ -e 's/src\/ggml-kompute\.h/ggml-kompute.h/g' \
-e 's/src\/ggml-metal\.h/ggml-metal.h/g' \ -e 's/src\/ggml-metal\.h/ggml-metal.h/g' \
-e 's/src\/ggml-metal\.m/ggml-metal.m/g' \ -e 's/src\/ggml-metal\.m/ggml-metal.m/g' \
-e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \
-e 's/src\/ggml-mpi\.c/ggml-mpi.c/g' \
-e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \ -e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \
-e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \ -e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \
-e 's/src\/ggml-quants\.c/ggml-quants.c/g' \ -e 's/src\/ggml-quants\.c/ggml-quants.c/g' \

View file

@ -1 +1 @@
126d34985705a5a2222723c145cb4e125ac689f3 2aae01fd9b8f9399f343cf18f46f38996ef52e2c

View file

@ -14,8 +14,6 @@ cp -rpv ../ggml/src/ggml-kompute.h ./ggml-kompute.h
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
cp -rpv ../ggml/src/ggml-mpi.h ./ggml-mpi.h
cp -rpv ../ggml/src/ggml-mpi.c ./ggml-mpi.c
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c

View file

@ -1262,22 +1262,37 @@ struct test_concat : public test_case {
const std::array<int64_t, 4> ne_a; const std::array<int64_t, 4> ne_a;
const int64_t ne_b_d; const int64_t ne_b_d;
const int dim; const int dim;
const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, ne_a, ne_b_d, dim); return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
} }
test_concat(ggml_type type = GGML_TYPE_F32, test_concat(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {10, 10, 10, 10}, std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
int64_t ne_b_d = 10, int64_t ne_b_d = 10,
int dim = 2) int dim = 2, int v = 0)
: type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {} : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
auto ne_b = ne_a; auto ne_b = ne_a;
ne_b[dim] = ne_b_d; ne_b[dim] = ne_b_d;
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_tensor * a;
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); if (v & 1) {
auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
}
ggml_tensor * b;
if (v & 2) {
auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
b = ggml_new_tensor(ctx, type, 4, ne.data());
b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
} else {
b = ggml_new_tensor(ctx, type, 4, ne_b.data());
}
ggml_tensor * out = ggml_concat(ctx, a, b, dim); ggml_tensor * out = ggml_concat(ctx, a, b, dim);
return out; return out;
} }
@ -2215,9 +2230,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
} }
for (int v : { 0, 1, 2, 3 }) {
for (int dim : { 0, 1, 2, 3, }) { for (int dim : { 0, 1, 2, 3, }) {
test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim)); test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim)); test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
}
} }
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {