diff --git a/Package.swift b/Package.swift index 18d610d69..e33a4ff46 100644 --- a/Package.swift +++ b/Package.swift @@ -13,21 +13,17 @@ let package = Package( products: [ .library(name: "llama", targets: ["llama"]), ], + dependencies: [ + .package(url: "https://github.com/ggerganov/ggml.git", .branch("master")) + ], targets: [ .target( name: "llama", + dependencies: ["ggml"], path: ".", exclude: [], sources: [ - "ggml.c", "llama.cpp", - "ggml-alloc.c", - "ggml-backend.c", - "ggml-quants.c", - "ggml-metal.m", - ], - resources: [ - .process("ggml-metal.metal") ], publicHeadersPath: "spm-headers", cSettings: [ diff --git a/common/train.cpp b/common/train.cpp index dcf9614e4..e6f2f7a2f 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -1107,7 +1107,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str()); fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); + fprintf(stderr, " --overlapping-samples Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); diff --git a/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj b/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj index fa6d4d3b2..a70750a22 100644 --- a/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj +++ b/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj @@ -9,7 +9,6 @@ /* Begin PBXBuildFile section */ 542376082B0D9BFB008E6A1C /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 542376072B0D9BFB008E6A1C /* ggml-quants.c */; settings = {COMPILER_FLAGS = "-O3"; }; }; 5423760B2B0D9C4B008E6A1C /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 5423760A2B0D9C4B008E6A1C /* ggml-backend.c */; settings = {COMPILER_FLAGS = "-O3"; }; }; - 542378792ACE3F3500834A7B /* ggml-metal.metal in Resources */ = {isa = PBXBuildFile; fileRef = 549479C82AC9E10B00E0F78B /* ggml-metal.metal */; }; 542EA09D2AC8723900A8AEE9 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 542EA09B2AC8723900A8AEE9 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_USE_K_QUANTS -O3"; }; }; 542EA0A02AC8725700A8AEE9 /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 542EA09F2AC8725700A8AEE9 /* ggml-alloc.c */; settings = {COMPILER_FLAGS = "-O3"; }; }; 542EA0A32AC8729100A8AEE9 /* llama.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 542EA0A12AC8729100A8AEE9 /* llama.cpp */; settings = {COMPILER_FLAGS = "-DGGML_USE_K_QUANTS -DGGML_USE_METAL -O3"; }; }; @@ -25,8 +24,25 @@ 8A907F332AC7138A006146EA /* LibLlama.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A907F322AC7134E006146EA /* LibLlama.swift */; }; 8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */; }; F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; }; + F1FE20DC2B465C4500B45541 /* ggml-metal.metal in Resources */ = {isa = PBXBuildFile; fileRef = 549479C82AC9E10B00E0F78B /* ggml-metal.metal */; }; /* End PBXBuildFile section */ +/* Begin PBXBuildRule section */ + F1FE20DB2B465C2100B45541 /* PBXBuildRule */ = { + isa = PBXBuildRule; + compilerSpec = com.apple.compilers.proxy.script; + fileType = sourcecode.metal; + inputFiles = ( + ); + isEditable = 1; + outputFiles = ( + "${DERIVED_FILES_DIR}/ggml-metal.air", + "${DERIVED_FILES_DIR}/ggml.metallib", + ); + script = "# metal\nxcrun metal -c \"${INPUT_FILE_PATH}\" -o \"${DERIVED_FILES_DIR}/${INPUT_FILE_BASE}.air\"\nxcrun metallib -o \"${DERIVED_FILES_DIR}/${INPUT_FILE_BASE%-metal}.metallib\" \"${DERIVED_FILES_DIR}/${INPUT_FILE_BASE}.air\"\n"; + }; +/* End PBXBuildRule section */ + /* Begin PBXFileReference section */ 542376062B0D9BEA008E6A1C /* ggml-quants.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-quants.h"; path = "../../ggml-quants.h"; sourceTree = ""; }; 542376072B0D9BFB008E6A1C /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-quants.c"; path = "../../ggml-quants.c"; sourceTree = ""; }; @@ -193,6 +209,7 @@ 8A1C83712AC328BD0096AF73 /* Resources */, ); buildRules = ( + F1FE20DB2B465C2100B45541 /* PBXBuildRule */, ); dependencies = ( ); @@ -244,7 +261,7 @@ isa = PBXResourcesBuildPhase; buildActionMask = 2147483647; files = ( - 542378792ACE3F3500834A7B /* ggml-metal.metal in Resources */, + F1FE20DC2B465C4500B45541 /* ggml-metal.metal in Resources */, 8A3F84242AC4C891005E2EE8 /* models in Resources */, 8A1C837E2AC328BE0096AF73 /* Preview Assets.xcassets in Resources */, 8A1C837B2AC328BE0096AF73 /* Assets.xcassets in Resources */, diff --git a/examples/server/public/completion.js b/examples/server/public/completion.js index 6e2b99565..baaec1d60 100644 --- a/examples/server/public/completion.js +++ b/examples/server/public/completion.js @@ -95,6 +95,15 @@ export async function* llama(prompt, params = {}, config = {}) { break; } } + if (result.error) { + result.error = JSON.parse(result.error); + if (result.error.content.includes('slot unavailable')) { + // Throw an error to be caught by upstream callers + throw new Error('slot unavailable'); + } else { + console.error(`llama.cpp error: ${result.error.content}`); + } + } if (result.error) { result.error = JSON.parse(result.error); console.error(`llama.cpp error: ${result.error.content}`); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8c2712308..52d3cc6a6 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10039,14 +10039,19 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten } return false; } break; + case GGML_OP_DUP: + case GGML_OP_REPEAT: + case GGML_OP_CONCAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_NORM: - case GGML_OP_REPEAT: - case GGML_OP_DUP: case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: @@ -10063,7 +10068,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: - case GGML_OP_CONCAT: case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: case GGML_OP_PAD: diff --git a/ggml-impl.h b/ggml-impl.h index 1f5610a86..2faced080 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -5,6 +5,7 @@ // GGML internal header #include +#include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ #include #include #include // memcpy diff --git a/ggml-metal.m b/ggml-metal.m index 7a369b55e..7aa92c14c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -87,6 +87,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q4_K); GGML_METAL_DECL_KERNEL(get_rows_q5_K); GGML_METAL_DECL_KERNEL(get_rows_q6_K); + GGML_METAL_DECL_KERNEL(get_rows_i32); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(group_norm); GGML_METAL_DECL_KERNEL(norm); @@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_q4_K); GGML_METAL_ADD_KERNEL(get_rows_q5_K); GGML_METAL_ADD_KERNEL(get_rows_q6_K); + GGML_METAL_ADD_KERNEL(get_rows_i32); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(group_norm); GGML_METAL_ADD_KERNEL(norm); @@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q4_K); GGML_METAL_DEL_KERNEL(get_rows_q5_K); GGML_METAL_DEL_KERNEL(get_rows_q6_K); + GGML_METAL_DEL_KERNEL(get_rows_i32); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(group_norm); GGML_METAL_DEL_KERNEL(norm); @@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; + case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 9aa7b502a..a7d3f9efa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16( } } +kernel void kernel_get_rows_i32( + device const void * src0, + device const char * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 diff --git a/ggml.c b/ggml.c index bcec200f6..b124f14cc 100644 --- a/ggml.c +++ b/ggml.c @@ -4766,8 +4766,11 @@ struct ggml_tensor * ggml_get_rows( } // TODO: implement non F32 return - //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); + enum ggml_type type = GGML_TYPE_F32; + if (a->type == GGML_TYPE_I32) { + type = a->type; + } + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); result->op = GGML_OP_GET_ROWS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6938,14 +6941,165 @@ static void ggml_compute_forward_dup_f32( } } +// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. +static void ggml_compute_forward_dup_bytes( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(src0->type == dst->type); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + const size_t type_size = ggml_type_size(src0->type); + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == type_size && nb0 == type_size) { + // copy by rows + const size_t rs = ne00 * type_size; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + size_t id = 0; + char * dst_ptr = (char *) dst->data; + const size_t rs = ne00 * type_size; + + if (nb00 == type_size) { + // src0 is contigous on first dimension, copy by rows + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, type_size); + + id += type_size; + } + } + id += rs * (ne01 - ir1); + } + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, type_size); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { - ggml_compute_forward_dup_same_cont(params, src0, dst); + if (src0->type == dst->type) { + ggml_compute_forward_dup_bytes(params, src0, dst); return; } + switch (src0->type) { case GGML_TYPE_F16: { @@ -8404,10 +8558,12 @@ static void ggml_compute_forward_repeat( struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: + case GGML_TYPE_I16: { ggml_compute_forward_repeat_f16(params, src0, dst); } break; case GGML_TYPE_F32: + case GGML_TYPE_I32: { ggml_compute_forward_repeat_f32(params, src0, dst); } break; @@ -8550,6 +8706,7 @@ static void ggml_compute_forward_concat( struct ggml_tensor* dst) { switch (src0->type) { case GGML_TYPE_F32: + case GGML_TYPE_I32: { ggml_compute_forward_concat_f32(params, src0, src1, dst); } break; @@ -10674,6 +10831,7 @@ static void ggml_compute_forward_get_rows( ggml_compute_forward_get_rows_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: + case GGML_TYPE_I32: { ggml_compute_forward_get_rows_f32(params, src0, src1, dst); } break; diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 91478f177..248cf1023 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -27,7 +27,7 @@ echo "Syncing ggml changes since commit $lc" cd $SRC_GGML git log --oneline $lc..HEAD -git log --oneline $lc..HEAD | grep -v "(llama/[0-9]*)" | cut -d' ' -f1 > $SRC_LLAMA/ggml-commits +git log --oneline $lc..HEAD --reverse | grep -v "(llama/[0-9]*)" | cut -d' ' -f1 > $SRC_LLAMA/ggml-commits if [ ! -s $SRC_LLAMA/ggml-commits ]; then rm -v $SRC_LLAMA/ggml-commits @@ -87,7 +87,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml-impl.h -> ggml-impl.h # src/ggml-metal.h -> ggml-metal.h # src/ggml-metal.m -> ggml-metal.m - # src/ggml-metal.metal -> ggml-metal.metal # src/ggml-mpi.h -> ggml-mpi.h # src/ggml-mpi.c -> ggml-mpi.c # src/ggml-opencl.cpp -> ggml-opencl.cpp @@ -114,7 +113,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/src\/ggml-impl\.h/ggml-impl.h/g' \ -e 's/src\/ggml-metal\.h/ggml-metal.h/g' \ -e 's/src\/ggml-metal\.m/ggml-metal.m/g' \ - -e 's/src\/ggml-metal\.metal/ggml-metal.metal/g' \ -e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \ -e 's/src\/ggml-mpi\.c/ggml-mpi.c/g' \ -e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5b6a440f7..354246a26 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -df098ea908764cba4a4889a1cbe7b026b2d31a14 +3fd01e00e40583ccd4b393a7c6502d6a4455a1d5 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index eff063b2d..44412cb94 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -58,6 +58,9 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m int64_t hist[16]; ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist); ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); + } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { + // This is going to create some weird integers though. + ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor)); } else { GGML_ASSERT(false); } @@ -87,8 +90,13 @@ static std::vector tensor_to_float(const ggml_tensor * t) { tv.push_back(*(float *) &buf[i]); } else if (t->type == GGML_TYPE_I32) { tv.push_back((float)*(int32_t *) &buf[i]); + } else if (t->type == GGML_TYPE_I16) { + tv.push_back((float)*(int16_t *) &buf[i]); + } else if (t->type == GGML_TYPE_I8) { + tv.push_back((float)*(int8_t *) &buf[i]); } else if (quantized) { - tt.to_float(&buf[i], vq.data(), bs); + std::vector vq(ggml_blck_size(t->type)); + tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type)); tv.insert(tv.end(), vq.begin(), vq.end()); } else { GGML_ASSERT(false); @@ -661,17 +669,26 @@ struct test_repeat : public test_case { struct test_dup : public test_case { const ggml_type type; const std::array ne; + const std::array permute; + bool _use_permute; std::string vars() override { - return VARS_TO_STR2(type, ne); + std::string v = VARS_TO_STR2(type, ne); + if (_use_permute) v += "," + VAR_TO_STR(permute); + return v; } test_dup(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 1}) - : type(type), ne(ne) {} + std::array ne = {10, 10, 10, 1}, + std::array permute = {0, 0, 0, 0}) + : type(type), ne(ne), permute(permute), + _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); + if (_use_permute) { + src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]); + } ggml_tensor * out = ggml_dup(ctx, src); return out; } @@ -1450,14 +1467,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } + for (int b : {1, 7}) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v)); + } + } test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 10, 10, 10}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 10, 10, 10}, {1, 1, 1, 2})); - test_cases.emplace_back(new test_dup()); + test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); + test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); + test_cases.emplace_back(new test_dup(GGML_TYPE_I32)); + test_cases.emplace_back(new test_dup(GGML_TYPE_I16)); + test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); for (ggml_type type : all_types) { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1})); @@ -1565,7 +1594,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_alibi()); test_cases.emplace_back(new test_im2col()); - test_cases.emplace_back(new test_concat()); + test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); + test_cases.emplace_back(new test_concat(GGML_TYPE_I32)); for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));