diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 0a28a1111..367df07a7 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -515,6 +515,31 @@ jobs:
- name: Build Xcode project
run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build
+ # CI job that builds the Android example app under examples/llama.android with Gradle.
+ android-build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Clone
+ uses: actions/checkout@v3
+
+ # Gradle runs on the JDK installed here (Zulu distribution, Java 17).
+ - name: Set up JDK
+ uses: actions/setup-java@v3
+ with:
+ java-version: 17
+ distribution: zulu
+
+ # Installs the Android SDK; logging of accepted SDK licenses is switched off.
+ - name: Setup Android SDK
+ uses: android-actions/setup-android@v3
+ with:
+ log-accepted-android-sdk-licenses: false
+
+ # The -Pskip-armeabi-v7a property is read by app/build.gradle.kts to restrict the ABI list.
+ - name: Build
+ run: |
+ cd examples/llama.android
+
+ # Skip armeabi-v7a for now (https://github.com/llvm/llvm-project/issues/65820).
+ ./gradlew build --no-daemon -Pskip-armeabi-v7a
+
# freeBSD-latest:
# runs-on: macos-12
# steps:
diff --git a/examples/llama.android/.gitignore b/examples/llama.android/.gitignore
new file mode 100644
index 000000000..347e252ef
--- /dev/null
+++ b/examples/llama.android/.gitignore
@@ -0,0 +1,33 @@
+# Gradle files
+.gradle/
+build/
+
+# Local configuration file (sdk path, etc)
+local.properties
+
+# Log/OS Files
+*.log
+
+# Android Studio generated files and folders
+captures/
+.externalNativeBuild/
+.cxx/
+*.apk
+output.json
+
+# IntelliJ
+*.iml
+.idea/
+misc.xml
+deploymentTargetDropDown.xml
+render.experimental.xml
+
+# Keystore files
+*.jks
+*.keystore
+
+# Google Services (e.g. APIs or Firebase)
+google-services.json
+
+# Android Profiling
+*.hprof
diff --git a/examples/llama.android/README.md b/examples/llama.android/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/llama.android/app/.gitignore b/examples/llama.android/app/.gitignore
new file mode 100644
index 000000000..796b96d1c
--- /dev/null
+++ b/examples/llama.android/app/.gitignore
@@ -0,0 +1 @@
+/build
diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
new file mode 100644
index 000000000..7815a8025
--- /dev/null
+++ b/examples/llama.android/app/build.gradle.kts
@@ -0,0 +1,91 @@
+// Gradle build script for the llama.android example application module.
+// Applies the Android application and Kotlin Android plugins, configures the
+// native (CMake) build, and declares the Compose/AndroidX dependencies.
+plugins {
+ id("com.android.application")
+ id("org.jetbrains.kotlin.android")
+}
+
+android {
+ namespace = "com.example.llama"
+ compileSdk = 34
+
+ // NDK release used to compile the native code referenced by externalNativeBuild below.
+ ndkVersion = "26.1.10909125"
+
+ defaultConfig {
+ applicationId = "com.example.llama"
+ minSdk = 33
+ targetSdk = 34
+ versionCode = 1
+ versionName = "1.0"
+
+ testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
+ vectorDrawables {
+ useSupportLibrary = true
+ }
+ ndk {
+ // Workaround for https://github.com/llvm/llvm-project/issues/65820
+ // affecting armeabi-v7a. Skip armeabi-v7a when invoked with
+ // -Pskip-armeabi-v7a (e.g., ./gradlew build -Pskip-armeabi-v7a).
+ if (project.hasProperty("skip-armeabi-v7a")) {
+ abiFilters += listOf("arm64-v8a", "x86_64", "x86")
+ }
+ }
+ // Placeholders: no extra C++ flags or CMake arguments are passed at the moment.
+ externalNativeBuild {
+ cmake {
+ cppFlags += listOf()
+ arguments += listOf()
+ }
+ }
+ }
+
+ buildTypes {
+ release {
+ // Code shrinking is disabled; the ProGuard files are still registered for when it is enabled.
+ isMinifyEnabled = false
+ proguardFiles(
+ getDefaultProguardFile("proguard-android-optimize.txt"),
+ "proguard-rules.pro"
+ )
+ }
+ }
+ // Java sources and Kotlin bytecode both target Java 8.
+ compileOptions {
+ sourceCompatibility = JavaVersion.VERSION_1_8
+ targetCompatibility = JavaVersion.VERSION_1_8
+ }
+ kotlinOptions {
+ jvmTarget = "1.8"
+ }
+ buildFeatures {
+ compose = true
+ }
+ composeOptions {
+ kotlinCompilerExtensionVersion = "1.5.1"
+ }
+ packaging {
+ resources {
+ excludes += "/META-INF/{AL2.0,LGPL2.1}"
+ }
+ }
+ // Points Gradle at the CMake project that builds the native library.
+ externalNativeBuild {
+ cmake {
+ path = file("src/main/cpp/CMakeLists.txt")
+ version = "3.22.1"
+ }
+ }
+}
+
+dependencies {
+
+ implementation("androidx.core:core-ktx:1.12.0")
+ implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
+ implementation("androidx.activity:activity-compose:1.8.2")
+ // The Compose BOM supplies the versions for the unversioned androidx.compose artifacts below.
+ implementation(platform("androidx.compose:compose-bom:2023.08.00"))
+ implementation("androidx.compose.ui:ui")
+ implementation("androidx.compose.ui:ui-graphics")
+ implementation("androidx.compose.ui:ui-tooling-preview")
+ implementation("androidx.compose.material3:material3")
+ testImplementation("junit:junit:4.13.2")
+ androidTestImplementation("androidx.test.ext:junit:1.1.5")
+ androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
+ androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
+ androidTestImplementation("androidx.compose.ui:ui-test-junit4")
+ debugImplementation("androidx.compose.ui:ui-tooling")
+ debugImplementation("androidx.compose.ui:ui-test-manifest")
+}
diff --git a/examples/llama.android/app/proguard-rules.pro b/examples/llama.android/app/proguard-rules.pro
new file mode 100644
index 000000000..f1b424510
--- /dev/null
+++ b/examples/llama.android/app/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml
new file mode 100644
index 000000000..41a358a29
--- /dev/null
+++ b/examples/llama.android/app/src/main/AndroidManifest.xml
@@ -0,0 +1,30 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/cpp/CMakeLists.txt b/examples/llama.android/app/src/main/cpp/CMakeLists.txt
new file mode 100644
index 000000000..85139329a
--- /dev/null
+++ b/examples/llama.android/app/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,50 @@
+
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html.
+# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
+
+# Sets the minimum CMake version required for this project.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares the project name. The project name can be accessed via ${PROJECT_NAME}.
+# Since this is the top level CMakeLists.txt, the project name is also accessible
+# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
+# build script scope).
+project("llama-android")
+
+# Build llama.cpp from the checkout this example lives in rather than fetching
+# the moving `master` branch from GitHub at configure time. This keeps the app
+# in sync with the surrounding sources, makes builds reproducible and removes
+# the network dependency from configuration.
+#
+# This file is examples/llama.android/app/src/main/cpp/CMakeLists.txt, so the
+# repository root is six directories up. A binary directory must be given
+# explicitly because the source directory is outside of this project's tree.
+#
+# Also provides "common"
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../../../.. build-llama)
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+#
+# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
+# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
+# is preferred for the same purpose.
+#
+# In order to load a library into your app from Java/Kotlin, you must call
+# System.loadLibrary() and pass the name of the library defined here;
+# for GameActivity/NativeActivity derived applications, the same library name must be
+# used in the AndroidManifest.xml file.
+add_library(${CMAKE_PROJECT_NAME} SHARED
+    # List C/C++ source files with relative paths to this CMakeLists.txt.
+    llama-android.cpp)
+
+# Specifies libraries CMake should link to your target library. You
+# can link libraries from various origins, such as libraries defined in this
+# build script, prebuilt third-party libraries, or Android system libraries.
+target_link_libraries(${CMAKE_PROJECT_NAME}
+    # List libraries link to the target library
+    llama
+    common
+    android
+    log)
diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp
new file mode 100644
index 000000000..d5e705dce
--- /dev/null
+++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp
@@ -0,0 +1,394 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include "llama.h"
+#include "common/common.h"
+
+// Write C++ code here.
+//
+// Do not forget to dynamically load the C++ library into your application.
+//
+// For instance,
+//
+// In MainActivity.java:
+// static {
+// System.loadLibrary("llama-android");
+// }
+//
+// Or, in MainActivity.kt:
+// companion object {
+// init {
+// System.loadLibrary("llama-android")
+// }
+// }
+
+#define TAG "llama-android.cpp"
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+
+jclass la_int_var;
+jmethodID la_int_var_value;
+jmethodID la_int_var_inc;
+
+// Log sink installed through llama_log_set(): forwards llama.cpp log output to
+// logcat, mapping ggml log levels onto Android log priorities.
+//
+//   level - ggml severity of the message
+//   fmt   - the message text handed over by llama.cpp
+//   data  - opaque user pointer registered with llama_log_set() (unused)
+//
+// The message is always printed through a constant "%s" format. Passing it as
+// the format string itself (with `data` as a variadic argument) is undefined
+// behaviour as soon as the text contains a '%' character.
+static void log_callback(ggml_log_level level, const char * fmt, void * data) {
+    (void) data;
+
+    int priority = ANDROID_LOG_DEFAULT;
+    if (level == GGML_LOG_LEVEL_ERROR)     priority = ANDROID_LOG_ERROR;
+    else if (level == GGML_LOG_LEVEL_INFO) priority = ANDROID_LOG_INFO;
+    else if (level == GGML_LOG_LEVEL_WARN) priority = ANDROID_LOG_WARN;
+
+    __android_log_print(priority, TAG, "%s", fmt);
+}
+
+// JNI: Llm.load_model(filename).
+// Loads a model from the given file path using the default model parameters.
+// Returns the model pointer packed into a jlong. On failure an
+// IllegalStateException is raised on the Java side and 0 is returned.
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
+ llama_model_params model_params = llama_model_default_params();
+
+ auto path_to_model = env->GetStringUTFChars(filename, 0);
+ LOGi("Loading model from %s", path_to_model);
+
+ auto model = llama_load_model_from_file(path_to_model, model_params);
+ // The UTF chars are only needed for the load call; release them before any early return.
+ env->ReleaseStringUTFChars(filename, path_to_model);
+
+ if (!model) {
+ LOGe("load_model() failed");
+ env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
+ return 0;
+ }
+
+ return reinterpret_cast(model);
+}
+
+// JNI: Llm.free_model(model).
+// Passes the model handle (a pointer carried in a jlong) to llama_free_model().
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) {
+ llama_free_model(reinterpret_cast(model));
+}
+
+// JNI: Llm.new_context(model).
+// Creates a llama context for the given model handle and returns it as a jlong.
+// Throws IllegalArgumentException when the model handle is null and
+// IllegalStateException when context creation fails; 0 is returned in both cases.
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
+ auto model = reinterpret_cast(jmodel);
+
+ if (!model) {
+ LOGe("new_context(): model cannot be null");
+ env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
+ return 0;
+ }
+
+ // Use the number of online processors minus two, clamped to the range [1, 8].
+ int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+ LOGi("Using %d threads", n_threads);
+
+ // Start from the library defaults, then fix the seed, context size and thread counts.
+ llama_context_params ctx_params = llama_context_default_params();
+ ctx_params.seed = 1234;
+ ctx_params.n_ctx = 2048;
+ ctx_params.n_threads = n_threads;
+ ctx_params.n_threads_batch = n_threads;
+
+ llama_context * context = llama_new_context_with_model(model, ctx_params);
+
+ if (!context) {
+ LOGe("llama_new_context_with_model() returned null)");
+ env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
+ "llama_new_context_with_model() returned null)");
+ return 0;
+ }
+
+ return reinterpret_cast(context);
+}
+
+// JNI: Llm.free_context(context).
+// Passes the context handle (a pointer carried in a jlong) to llama_free().
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) {
+ llama_free(reinterpret_cast(context));
+}
+
+// JNI: Llm.backend_free().
+// Thin wrapper that calls llama_backend_free().
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) {
+ llama_backend_free();
+}
+
+// JNI: Llm.log_to_android().
+// Registers log_callback as the llama log handler, with no user data pointer.
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) {
+ llama_log_set(log_callback, NULL);
+}
+
+// JNI: Llm.bench_model(context, model, batch, pp, tg, pl, nr).
+// Runs a prompt-processing and a text-generation benchmark `nr` times and
+// returns the averaged throughput as a Markdown table.
+//
+//   context_pointer / model_pointer / batch_pointer - native handles carried in jlongs
+//   pp - number of tokens put into the single prompt-processing batch
+//   tg - number of decode iterations in the text-generation phase
+//   pl - number of entries added to the batch per generation iteration
+//        (each entry uses its loop index as the sequence id)
+//   nr - number of benchmark repetitions
+//
+// NOTE(review): llama_decode() failures are only logged (at info level); the
+// timings of a failed run are still folded into the reported averages.
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_bench_1model(
+ JNIEnv *env,
+ jobject,
+ jlong context_pointer,
+ jlong model_pointer,
+ jlong batch_pointer,
+ jint pp,
+ jint tg,
+ jint pl,
+ jint nr
+ ) {
+ // Running sums of the speeds (…_avg) and of the squared speeds (…_std);
+ // both are turned into mean / standard deviation after the loop.
+ auto pp_avg = 0.0;
+ auto tg_avg = 0.0;
+ auto pp_std = 0.0;
+ auto tg_std = 0.0;
+
+ const auto context = reinterpret_cast(context_pointer);
+ const auto model = reinterpret_cast(model_pointer);
+ const auto batch = reinterpret_cast(batch_pointer);
+
+ const int n_ctx = llama_n_ctx(context);
+
+ LOGi("n_ctx = %d", n_ctx);
+
+ int i, j;
+ int nri;
+ for (nri = 0; nri < nr; nri++) {
+ LOGi("Benchmark prompt processing (pp)");
+
+ llama_batch_clear(*batch);
+
+ // Fill the batch with `pp` entries of token 0 at consecutive positions.
+ const int n_tokens = pp;
+ for (i = 0; i < n_tokens; i++) {
+ llama_batch_add(*batch, 0, i, { 0 }, false);
+ }
+
+ // Request logits for the last entry only.
+ batch->logits[batch->n_tokens - 1] = true;
+ llama_kv_cache_clear(context);
+
+ const auto t_pp_start = ggml_time_us();
+ if (llama_decode(context, *batch) != 0) {
+ LOGi("llama_decode() failed during prompt processing");
+ }
+ const auto t_pp_end = ggml_time_us();
+
+ // bench text generation
+
+ LOGi("Benchmark text generation (tg)");
+
+ llama_kv_cache_clear(context);
+ const auto t_tg_start = ggml_time_us();
+ for (i = 0; i < tg; i++) {
+
+ llama_batch_clear(*batch);
+ for (j = 0; j < pl; j++) {
+ llama_batch_add(*batch, 0, i, { j }, true);
+ }
+
+ LOGi("llama_decode() text generation: %d", i);
+ if (llama_decode(context, *batch) != 0) {
+ LOGi("llama_decode() failed during text generation");
+ }
+ }
+
+ const auto t_tg_end = ggml_time_us();
+
+ llama_kv_cache_clear(context);
+
+ // ggml_time_us() values are microseconds; convert the differences to seconds.
+ const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+ const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+
+ const auto speed_pp = double(pp) / t_pp;
+ const auto speed_tg = double(pl * tg) / t_tg;
+
+ pp_avg += speed_pp;
+ tg_avg += speed_tg;
+
+ pp_std += speed_pp * speed_pp;
+ tg_std += speed_tg * speed_tg;
+
+ LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+ }
+
+ pp_avg /= double(nr);
+ tg_avg /= double(nr);
+
+ // Sample standard deviation from the sum of squares; defined only for more than one run.
+ if (nr > 1) {
+ pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+ tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+ } else {
+ pp_std = 0;
+ tg_std = 0;
+ }
+
+ char model_desc[128];
+ llama_model_desc(model, model_desc, sizeof(model_desc));
+
+ // Model size in GiB and parameter count in billions, for the report.
+ const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
+ const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
+
+ const auto backend = "(Android)"; // TODO: What should this be?
+
+ // Render the results as a two-row Markdown table (one row for pp, one for tg).
+ std::stringstream result;
+ result << std::setprecision(2);
+ result << "| model | size | params | backend | test | t/s |\n";
+ result << "| --- | --- | --- | --- | --- | --- |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+
+ return env->NewStringUTF(result.str().c_str());
+}
+
+// JNI: Llm.free_batch(batch).
+// Releases a batch created by new_batch(): first the buffers owned by the
+// batch (via llama_batch_free), then the heap-allocated llama_batch struct
+// itself, which new_batch() created with `new` and which would otherwise leak.
+// A null handle is ignored.
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    if (!batch) {
+        return;
+    }
+
+    llama_batch_free(*batch);
+    delete batch;
+}
+
+// JNI: Llm.new_batch(nTokens, embd, nSeqMax).
+// Allocates a llama_batch on the heap, sizes its buffers for `n_tokens`
+// entries and returns the address of the struct as a jlong.
+//
+//   n_tokens  - number of entries every per-token buffer is sized for
+//   embd      - when non-zero an embedding buffer of n_tokens * embd floats is
+//               allocated instead of the token buffer
+//   n_seq_max - number of sequence ids reserved per entry
+//
+// NOTE(review): none of the malloc() results are checked.
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
+
+ // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
+
+ // Positional initialisation: every field starts out as 0 / nullptr.
+ llama_batch *batch = new llama_batch {
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ 0,
+ 0,
+ };
+
+ // Exactly one of the embedding / token buffers is allocated.
+ if (embd) {
+ batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+ } else {
+ batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+ }
+
+ batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+ batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+ batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+ // One array of n_seq_max sequence ids per entry.
+ for (int i = 0; i < n_tokens; ++i) {
+ batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+ }
+ batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+
+ return reinterpret_cast(batch);
+}
+
+// JNI: Llm.backend_init(numa).
+// Thin wrapper that forwards the `numa` flag to llama_backend_init().
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject, jboolean numa) {
+ llama_backend_init(numa);
+}
+
+// JNI: Llm.system_info().
+// Returns the text produced by llama_print_system_info() as a Java string.
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) {
+ return env->NewStringUTF(llama_print_system_info());
+}
+
+// JNI: Llm.completion_init(context, batch, text, nLen).
+// Tokenizes the prompt, loads it into the batch and decodes it so that the
+// logits of the last prompt token are available to completion_loop().
+//
+//   context_pointer / batch_pointer - native handles carried in jlongs
+//   jtext - the prompt
+//   n_len - generation length limit, used here only for the KV cache size check
+//
+// Returns the number of entries placed in the batch (the prompt length in
+// tokens). Failures are logged but not reported to the caller.
+//
+// NOTE(review): the prompt is not checked against the capacity the batch was
+// allocated with; the capacity is not recorded in the batch, so the caller
+// must make sure the prompt fits.
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_example_llama_Llm_completion_1init(
+        JNIEnv *env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jstring jtext,
+        jint n_len
+    ) {
+
+    const auto text = env->GetStringUTFChars(jtext, 0);
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+
+    const auto tokens_list = llama_tokenize(context, text, 1);
+
+    auto n_ctx = llama_n_ctx(context);
+    auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    // n_kv_req is a size_t, so it must be printed with %zu; n_ctx is cast so
+    // that the argument matches %d regardless of its declared integer type.
+    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %zu", n_len, static_cast<int>(n_ctx), n_kv_req);
+
+    if (n_kv_req > n_ctx) {
+        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
+    }
+
+    for (auto id : tokens_list) {
+        LOGi("%s", llama_token_to_piece(context, id).c_str());
+    }
+
+    llama_batch_clear(*batch);
+
+    // evaluate the initial prompt
+    for (size_t i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(*batch, tokens_list[i], static_cast<llama_pos>(i), { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch->logits[batch->n_tokens - 1] = true;
+
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() failed");
+    }
+
+    env->ReleaseStringUTFChars(jtext, text);
+
+    return batch->n_tokens;
+}
+
+// JNI: Llm.completion_loop(context, batch, nLen, ncur).
+// Produces one generation step: greedily samples the next token from the
+// logits of the last decoded batch entry, queues it in the batch, increments
+// the Java-side counter and decodes it.
+//
+//   context_pointer / batch_pointer - native handles carried in jlongs
+//   n_len       - generation stops once the counter equals this value
+//   intvar_ncur - Java object exposing `getValue(): Int` and `inc()`; holds the
+//                 current position
+//
+// Returns the text of the new token, or an empty string when the end-of-stream
+// token was sampled or the length limit was reached.
+//
+// NOTE(review): a token piece may hold only part of a multi-byte character;
+// such bytes are handed to NewStringUTF() unvalidated.
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_example_llama_Llm_completion_1loop(
+        JNIEnv * env,
+        jobject,
+        jlong context_pointer,
+        jlong batch_pointer,
+        jint n_len,
+        jobject intvar_ncur
+) {
+    const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto model = llama_get_model(context);
+
+    // The class and method ids are cached in globals across calls. The class
+    // must therefore be held through a *global* reference: the local reference
+    // returned by GetObjectClass() becomes invalid when this call returns.
+    if (!la_int_var) {
+        jclass local_class = env->GetObjectClass(intvar_ncur);
+        la_int_var = static_cast<jclass>(env->NewGlobalRef(local_class));
+        env->DeleteLocalRef(local_class);
+    }
+    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
+    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
+
+    auto n_vocab = llama_n_vocab(model);
+    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+
+    // One candidate per vocabulary entry, carrying its logit.
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    // sample the most likely token
+    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+
+    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
+    if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+        return env->NewStringUTF("");
+    }
+
+    auto new_token_chars = llama_token_to_piece(context, new_token_id);
+    LOGi("new_token_chars: `%s`", new_token_chars.c_str());
+    auto new_token = env->NewStringUTF(new_token_chars.c_str());
+
+    // Queue the sampled token at the current position and request its logits
+    // so that the next call can sample from them.
+    llama_batch_clear(*batch);
+    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+
+    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
+
+    // llama_decode() reports failure through a non-zero status code.
+    if (llama_decode(context, *batch) != 0) {
+        LOGe("llama_decode() failed");
+    }
+
+    return new_token;
+}
+
+// JNI: Llm.kv_cache_clear(context).
+// Passes the context handle (a pointer carried in a jlong) to llama_kv_cache_clear().
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
+ llama_kv_cache_clear(reinterpret_cast(context));
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt
new file mode 100644
index 000000000..78c231ae5
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt
@@ -0,0 +1,119 @@
+package com.example.llama
+
+import android.app.DownloadManager
+import android.net.Uri
+import android.util.Log
+import androidx.compose.material3.Button
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableDoubleStateOf
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.runtime.setValue
+import androidx.core.database.getLongOrNull
+import androidx.core.net.toUri
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.launch
+import java.io.File
+
+/**
+ * A model file that can be fetched through the system [DownloadManager].
+ *
+ * @property name Label shown on the button and in log messages.
+ * @property source Location the file is downloaded from.
+ * @property destination Local file the download is written to.
+ */
+data class Downloadable(val name: String, val source: Uri, val destination: File) {
+    companion object {
+        @JvmStatic
+        private val tag: String? = this::class.qualifiedName
+
+        /** Download life cycle of a single [Downloadable] as shown by [Button]. */
+        sealed interface State
+        data object Ready: State
+        data class Downloading(val id: Long): State
+        data class Downloaded(val downloadable: Downloadable): State
+        data class Error(val message: String): State
+
+        /**
+         * Button that downloads [item] on the first click and, once the file
+         * is present, loads it through [viewModel] on the next click.
+         *
+         * @param viewModel Receives log messages and the load request.
+         * @param dm Download manager used to enqueue and poll the download.
+         * @param item The file this button manages.
+         */
+        @JvmStatic
+        @Composable
+        fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
+            var status: State by remember {
+                mutableStateOf(
+                    if (item.destination.exists()) Downloaded(item)
+                    else Ready
+                )
+            }
+            var progress by remember { mutableDoubleStateOf(0.0) }
+
+            val coroutineScope = rememberCoroutineScope()
+
+            /**
+             * Polls the download once per second until it reaches a final
+             * state, updating [progress] along the way.
+             *
+             * Returns [Downloaded] on success, [Ready] when the download is
+             * no longer known to the download manager, and [Error] when the
+             * query fails or the download manager reports a failure.
+             */
+            suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
+                while (true) {
+                    val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
+
+                    if (cursor == null) {
+                        Log.e(tag, "dm.query() returned null")
+                        return Error("dm.query() returned null")
+                    }
+
+                    // `use` closes the cursor on every exit path, including
+                    // the early returns and any exception thrown below.
+                    cursor.use { row ->
+                        if (!row.moveToFirst() || row.count < 1) {
+                            Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
+                            return Ready
+                        }
+
+                        // Without looking at the status column a failed
+                        // download would never terminate this loop, because
+                        // its byte counters never become equal.
+                        val six = row.getColumnIndex(DownloadManager.COLUMN_STATUS)
+                        when (row.getLongOrNull(six)?.toInt()) {
+                            DownloadManager.STATUS_SUCCESSFUL -> return Downloaded(item)
+                            DownloadManager.STATUS_FAILED -> {
+                                val rix = row.getColumnIndex(DownloadManager.COLUMN_REASON)
+                                val reason = row.getLongOrNull(rix)
+                                val message = "Download of ${item.name} failed (reason: $reason)"
+                                Log.e(tag, message)
+                                return Error(message)
+                            }
+                            else -> {}
+                        }
+
+                        val pix = row.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
+                        val tix = row.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
+                        val sofar = row.getLongOrNull(pix) ?: 0
+                        val total = row.getLongOrNull(tix) ?: 1
+
+                        if (sofar == total) {
+                            return Downloaded(item)
+                        }
+
+                        // The total size is reported as -1 while it is still
+                        // unknown; do not show a negative percentage then.
+                        progress = if (total > 0) (sofar * 1.0) / total else 0.0
+                    }
+
+                    delay(1000L)
+                }
+            }
+
+            fun onClick() {
+                when (val s = status) {
+                    is Downloaded -> {
+                        viewModel.load(item.destination.path)
+                    }
+
+                    is Downloading -> {
+                        coroutineScope.launch {
+                            status = waitForDownload(s, item)
+                        }
+                    }
+
+                    // Ready or Error: (re)start the download from scratch.
+                    else -> {
+                        item.destination.delete()
+
+                        val request = DownloadManager.Request(item.source).apply {
+                            setTitle("Downloading model")
+                            setDescription("Downloading model: ${item.name}")
+                            setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
+                            setDestinationUri(item.destination.toUri())
+                        }
+
+                        viewModel.log("Saving ${item.name} to ${item.destination.path}")
+                        Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
+
+                        val id = dm.enqueue(request)
+                        status = Downloading(id)
+                        // Re-enter with the new state to start polling.
+                        onClick()
+                    }
+                }
+            }
+
+            Button(onClick = { onClick() }, enabled = status !is Downloading) {
+                when (status) {
+                    is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
+                    is Downloaded -> Text("Load ${item.name}")
+                    is Ready -> Text("Download ${item.name}")
+                    is Error -> Text("Download ${item.name}")
+                }
+            }
+        }
+
+    }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
new file mode 100644
index 000000000..5f3270372
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt
@@ -0,0 +1,172 @@
+package com.example.llama
+
+import android.util.Log
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.asCoroutineDispatcher
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.withContext
+import java.util.concurrent.Executors
+import kotlin.concurrent.thread
+
+/**
+ * Kotlin front end for the native "llama-android" library.
+ *
+ * Every native call is dispatched onto one dedicated thread ("Llm-RunLoop"),
+ * which also loads the library and initialises the backend when it starts.
+ * Obtain the shared instance through [instance].
+ */
+class Llm {
+ private val tag: String? = this::class.simpleName
+
+ // Idle/Loaded state. It is thread-local and only read or written from code
+ // running on [runLoop], so it effectively belongs to that single thread.
+ private val threadLocalState: ThreadLocal = ThreadLocal.withInitial { State.Idle }
+
+ // Single-threaded dispatcher. The thread body performs the one-time native
+ // setup and then runs the executor's worker (`it.run()`).
+ private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+ thread(start = false, name = "Llm-RunLoop") {
+ Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+ // No-op if called more than once.
+ System.loadLibrary("llama-android")
+
+ // Set llama log handler to Android
+ log_to_android()
+ backend_init(false)
+
+ Log.d(tag, system_info())
+
+ it.run()
+ }.apply {
+ uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+ Log.e(tag, "Unhandled exception", exception)
+ }
+ }
+ }.asCoroutineDispatcher()
+
+ // Length limit passed as `nLen` to completion_init() and completion_loop().
+ private val nlen: Int = 64
+
+ // Native entry points implemented in llama-android.cpp. Models, contexts and
+ // batches are exchanged as Long handles.
+ private external fun log_to_android()
+ private external fun load_model(filename: String): Long
+ private external fun free_model(model: Long)
+ private external fun new_context(model: Long): Long
+ private external fun free_context(context: Long)
+ private external fun backend_init(numa: Boolean)
+ private external fun backend_free()
+ private external fun free_batch(batch: Long)
+ private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
+ private external fun bench_model(
+ context: Long,
+ model: Long,
+ batch: Long,
+ pp: Int,
+ tg: Int,
+ pl: Int,
+ nr: Int
+ ): String
+
+ private external fun system_info(): String
+
+ private external fun completion_init(
+ context: Long,
+ batch: Long,
+ text: String,
+ nLen: Int
+ ): Int
+
+ private external fun completion_loop(
+ context: Long,
+ batch: Long,
+ nLen: Int,
+ ncur: IntVar
+ ): String
+
+ private external fun kv_cache_clear(context: Long)
+
+ /**
+ * Runs the native benchmark on the loaded model and returns its report.
+ *
+ * @throws IllegalStateException if no model is loaded.
+ */
+ suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+ return withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ Log.d(tag, "bench(): $state")
+ bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
+ }
+
+ else -> throw IllegalStateException("No model loaded")
+ }
+ }
+ }
+
+ /**
+ * Loads the model at [pathToModel] and creates a context and a batch
+ * (512 entries, token mode, one sequence id per entry) for it.
+ *
+ * @throws IllegalStateException if a model is already loaded or any of the
+ * native calls returns a null handle.
+ */
+ suspend fun load(pathToModel: String) {
+ withContext(runLoop) {
+ when (threadLocalState.get()) {
+ is State.Idle -> {
+ val model = load_model(pathToModel)
+ if (model == 0L) throw IllegalStateException("load_model() failed")
+
+ val context = new_context(model)
+ if (context == 0L) throw IllegalStateException("new_context() failed")
+
+ val batch = new_batch(512, 0, 1)
+ if (batch == 0L) throw IllegalStateException("new_batch() failed")
+
+ Log.i(tag, "Loaded model $pathToModel")
+ threadLocalState.set(State.Loaded(model, context, batch))
+ }
+ else -> throw IllegalStateException("Model already loaded")
+ }
+ }
+ }
+
+ /**
+ * Returns a cold flow that feeds [message] to the model and emits the
+ * generated text piece by piece. The flow body runs on [runLoop].
+ *
+ * The flow completes without emitting anything when no model is loaded.
+ */
+ fun send(message: String): Flow = flow {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ // completion_init() returns the prompt's token count, which seeds the counter.
+ val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
+ while (ncur.value <= nlen) {
+ val str = completion_loop(state.context, state.batch, nlen, ncur)
+ // An empty string is the native side's end-of-generation signal.
+ if (str.isEmpty()) {
+ break
+ }
+ emit(str)
+ }
+ kv_cache_clear(state.context)
+ }
+ else -> {}
+ }
+ }.flowOn(runLoop)
+
+ /**
+ * Unloads the model and frees resources.
+ *
+ * This is a no-op if there's no model loaded.
+ */
+ suspend fun unload() {
+ withContext(runLoop) {
+ when (val state = threadLocalState.get()) {
+ is State.Loaded -> {
+ free_context(state.context)
+ free_model(state.model)
+ free_batch(state.batch)
+
+ threadLocalState.set(State.Idle)
+ }
+ else -> {}
+ }
+ }
+ }
+
+ companion object {
+ // Mutable counter shared with native code, which calls getValue() and
+ // inc() on it through JNI (see completion_loop).
+ private class IntVar(value: Int) {
+ @Volatile
+ var value: Int = value
+ private set
+
+ fun inc() {
+ synchronized(this) {
+ value += 1
+ }
+ }
+ }
+
+ // Idle: nothing loaded. Loaded: native handles of the model, its context and the batch.
+ private sealed interface State {
+ data object Idle: State
+ data class Loaded(val model: Long, val context: Long, val batch: Long): State
+ }
+
+ // Enforce only one instance of Llm.
+ private val _instance: Llm = Llm()
+
+ fun instance(): Llm = _instance
+ }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
new file mode 100644
index 000000000..9da04f7d3
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -0,0 +1,154 @@
+package com.example.llama
+
+import android.app.ActivityManager
+import android.app.DownloadManager
+import android.content.ClipData
+import android.content.ClipboardManager
+import android.net.Uri
+import android.os.Bundle
+import android.os.StrictMode
+import android.os.StrictMode.VmPolicy
+import android.text.format.Formatter
+import androidx.activity.ComponentActivity
+import androidx.activity.compose.setContent
+import androidx.activity.viewModels
+import androidx.compose.foundation.layout.Box
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.fillMaxSize
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.lazy.LazyColumn
+import androidx.compose.foundation.lazy.items
+import androidx.compose.foundation.lazy.rememberLazyListState
+import androidx.compose.material3.Button
+import androidx.compose.material3.LocalContentColor
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.OutlinedTextField
+import androidx.compose.material3.Surface
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.unit.dp
+import androidx.core.content.getSystemService
+import com.example.llama.ui.theme.LlamaAndroidTheme
+import java.io.File
+
+/**
+ * Entry activity of the example app: logs memory and storage information,
+ * declares the downloadable models and hosts the Compose UI.
+ *
+ * The optional constructor parameters allow the system services to be
+ * replaced; when left `null` they are resolved lazily from the activity.
+ */
+class MainActivity(
+    activityManager: ActivityManager? = null,
+    downloadManager: DownloadManager? = null,
+    clipboardManager: ClipboardManager? = null,
+): ComponentActivity() {
+    private val tag: String? = this::class.simpleName
+
+    // Resolved lazily because system services are not available before the
+    // activity has been attached to a context.
+    private val activityManager by lazy { activityManager ?: getSystemService<ActivityManager>()!! }
+    private val downloadManager by lazy { downloadManager ?: getSystemService<DownloadManager>()!! }
+    private val clipboardManager by lazy { clipboardManager ?: getSystemService<ClipboardManager>()!! }
+
+    private val viewModel: MainViewModel by viewModels()
+
+    // Get a MemoryInfo object for the device's current memory status.
+    private fun availableMemory(): ActivityManager.MemoryInfo {
+        return ActivityManager.MemoryInfo().also { memoryInfo ->
+            activityManager.getMemoryInfo(memoryInfo)
+        }
+    }
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+
+        // Report Closeable objects that are leaked without being closed.
+        StrictMode.setVmPolicy(
+            VmPolicy.Builder(StrictMode.getVmPolicy())
+                .detectLeakedClosableObjects()
+                .build()
+        )
+
+        // Take a single snapshot so that both figures describe the same moment.
+        val memoryInfo = availableMemory()
+        val free = Formatter.formatFileSize(this, memoryInfo.availMem)
+        val total = Formatter.formatFileSize(this, memoryInfo.totalMem)
+
+        // App-specific external storage directory the models are saved to.
+        val extFilesDir = getExternalFilesDir(null)
+
+        viewModel.log("Current memory: $free / $total")
+        viewModel.log("Downloads directory: $extFilesDir")
+
+        val models = listOf(
+            Downloadable(
+                // Phi-2 is a 2.7B parameter model.
+                "Phi-2 2.7B (Q4_0, 1.6 GiB)",
+                Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
+                File(extFilesDir, "phi-2-q4_0.gguf"),
+            ),
+            Downloadable(
+                "TinyLlama 1.1B (f16, 2.2 GiB)",
+                Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
+                File(extFilesDir, "tinyllama-1.1-f16.gguf"),
+            ),
+            Downloadable(
+                "Phi 2 DPO (Q3_K_M, 1.48 GiB)",
+                Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
+                File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
+            ),
+        )
+
+        setContent {
+            LlamaAndroidTheme {
+                // A surface container using the 'background' color from the theme
+                Surface(
+                    modifier = Modifier.fillMaxSize(),
+                    color = MaterialTheme.colorScheme.background
+                ) {
+                    MainCompose(
+                        viewModel,
+                        clipboardManager,
+                        downloadManager,
+                        models,
+                    )
+                }
+
+            }
+        }
+    }
+}
+
+/**
+ * Root composable of the app.
+ *
+ * Lays out, top to bottom: the scrolling message list, the message input
+ * field, a row of action buttons and one download/load button per model.
+ *
+ * @param viewModel Source of the messages and target of all button actions.
+ * @param clipboard Receives the message log when "Copy" is pressed.
+ * @param dm Download manager handed to each model button.
+ * @param models Models for which a button is shown.
+ */
+@Composable
+fun MainCompose(
+ viewModel: MainViewModel,
+ clipboard: ClipboardManager,
+ dm: DownloadManager,
+ models: List
+) {
+ Column {
+ val scrollState = rememberLazyListState()
+
+ // The message list takes all vertical space left over by the controls below.
+ Box(modifier = Modifier.weight(1f)) {
+ LazyColumn(state = scrollState) {
+ items(viewModel.messages) {
+ Text(
+ it,
+ style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
+ modifier = Modifier.padding(16.dp)
+ )
+ }
+ }
+ }
+ OutlinedTextField(
+ value = viewModel.message,
+ onValueChange = { viewModel.updateMessage(it) },
+ label = { Text("Message") },
+ )
+ Row {
+ Button({ viewModel.send() }) { Text("Send") }
+ Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
+ Button({ viewModel.clear() }) { Text("Clear") }
+ // Copies all messages, joined by newlines, to the clipboard.
+ Button({
+ viewModel.messages.joinToString("\n").let {
+ clipboard.setPrimaryClip(ClipData.newPlainText("", it))
+ }
+ }) { Text("Copy") }
+ }
+
+ Column {
+ for (model in models) {
+ Downloadable.Button(viewModel, dm, model)
+ }
+ }
+ }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
new file mode 100644
index 000000000..be95e2221
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -0,0 +1,104 @@
+package com.example.llama
+
+import android.util.Log
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.setValue
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.launch
+
+/** Holds the UI state (message log and input text) and runs [llm] calls on [viewModelScope]. */
+class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
+    companion object {
+        @JvmStatic
+        private val NanosPerSecond = 1_000_000_000.0
+    }
+    private val tag: String? = this::class.simpleName
+
+    /** Lines shown in the on-screen log; every update assigns a new list. */
+    var messages by mutableStateOf(listOf("Initializing..."))
+        private set
+    /** Current text of the input field. */
+    var message by mutableStateOf("")
+        private set
+
+    // NOTE(review): viewModelScope may already be cancelled by the time onCleared() runs,
+    // in which case unload() below is never reached -- confirm with the lifecycle version in use.
+    override fun onCleared() {
+        super.onCleared()
+        viewModelScope.launch {
+            try {
+                llm.unload()
+            } catch (exc: IllegalStateException) {
+                messages += exc.message ?: exc.toString()
+            }
+        }
+    }
+    /** Moves the input text into the log, clears the field, and appends the reply as it streams in. */
+    fun send() {
+        val text = message
+        message = ""
+        // Log the prompt, then an empty line that the streamed reply is appended to.
+        messages += text
+        messages += ""
+        viewModelScope.launch {
+            llm.send(text)
+                .catch {
+                    Log.e(tag, "send() failed", it)
+                    // message is nullable; fall back to toString() instead of crashing on `!!`.
+                    messages += it.message ?: it.toString()
+                }
+                // lastOrNull(): clear() may have emptied the log while the reply is still streaming.
+                .collect { messages = messages.dropLast(1) + (messages.lastOrNull().orEmpty() + it) }
+        }
+    }
+    /** Runs a warm-up benchmark with the given parameters, then -- unless the warm-up took
+     *  more than 5 seconds -- the fixed full benchmark (512, 128, 1, 3); results go to the log. */
+    fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
+        viewModelScope.launch {
+            try {
+                val start = System.nanoTime()
+                val warmupResult = llm.bench(pp, tg, pl, nr)
+                val end = System.nanoTime()
+                messages += warmupResult
+
+                val warmup = (end - start).toDouble() / NanosPerSecond
+                messages += "Warm up time: $warmup seconds, please wait..."
+                if (warmup > 5.0) {
+                    messages += "Warm up took too long, aborting benchmark"
+                    return@launch
+                }
+                messages += llm.bench(512, 128, 1, 3)
+            } catch (exc: IllegalStateException) {
+                Log.e(tag, "bench() failed", exc)
+                messages += exc.message ?: exc.toString()
+            }
+        }
+    }
+    /** Loads the model at [pathToModel] and logs either success or the failure message. */
+    fun load(pathToModel: String) {
+        viewModelScope.launch {
+            try {
+                llm.load(pathToModel)
+                messages += "Loaded $pathToModel"
+            } catch (exc: IllegalStateException) {
+                Log.e(tag, "load() failed", exc)
+                messages += exc.message ?: exc.toString()
+            }
+        }
+    }
+    /** Replaces the stored input text with [newMessage]. */
+    fun updateMessage(newMessage: String) {
+        message = newMessage
+    }
+    /** Empties the log. */
+    fun clear() {
+        messages = listOf()
+    }
+    /** Appends [message] to the log. */
+    fun log(message: String) {
+        messages += message
+    }
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt
new file mode 100644
index 000000000..40c30e8d9
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt
@@ -0,0 +1,11 @@
+package com.example.llama.ui.theme
+
+import androidx.compose.ui.graphics.Color
+
+val Purple80 = Color(0xFFD0BCFF) // primary of DarkColorScheme in Theme.kt
+val PurpleGrey80 = Color(0xFFCCC2DC) // secondary of DarkColorScheme
+val Pink80 = Color(0xFFEFB8C8) // tertiary of DarkColorScheme
+// The "40" values below fill the same three roles in LightColorScheme.
+val Purple40 = Color(0xFF6650a4)
+val PurpleGrey40 = Color(0xFF625b71)
+val Pink40 = Color(0xFF7D5260)
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt
new file mode 100644
index 000000000..e742220a8
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt
@@ -0,0 +1,70 @@
+package com.example.llama.ui.theme
+
+import android.app.Activity
+import android.os.Build
+import androidx.compose.foundation.isSystemInDarkTheme
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.darkColorScheme
+import androidx.compose.material3.dynamicDarkColorScheme
+import androidx.compose.material3.dynamicLightColorScheme
+import androidx.compose.material3.lightColorScheme
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.SideEffect
+import androidx.compose.ui.graphics.toArgb
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.platform.LocalView
+import androidx.core.view.WindowCompat
+
+private val DarkColorScheme = darkColorScheme( // static palette chosen when darkTheme is true and dynamic color is not used
+ primary = Purple80,
+ secondary = PurpleGrey80,
+ tertiary = Pink80
+)
+// Static palette for the remaining case: light theme without dynamic color.
+private val LightColorScheme = lightColorScheme(
+ primary = Purple40,
+ secondary = PurpleGrey40,
+ tertiary = Pink40
+
+ /* Other default colors to override
+ background = Color(0xFFFFFBFE),
+ surface = Color(0xFFFFFBFE),
+ onPrimary = Color.White,
+ onSecondary = Color.White,
+ onTertiary = Color.White,
+ onBackground = Color(0xFF1C1B1F),
+ onSurface = Color(0xFF1C1B1F),
+ */
+)
+
+/** Applies the app's Material 3 theme to [content] and tints the status bar with the primary color. */
+@Composable
+fun LlamaAndroidTheme(
+    darkTheme: Boolean = isSystemInDarkTheme(),
+    // Dynamic color is available on Android 12+
+    dynamicColor: Boolean = true,
+    content: @Composable () -> Unit
+) {
+    // Dynamic palette when requested and the SDK level allows it, else the static schemes.
+    val colorScheme = if (dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
+        val ctx = LocalContext.current
+        if (darkTheme) dynamicDarkColorScheme(ctx) else dynamicLightColorScheme(ctx)
+    } else if (darkTheme) {
+        DarkColorScheme
+    } else {
+        LightColorScheme
+    }
+
+    val hostView = LocalView.current
+    // Only touch the window when the view is not in edit mode.
+    if (!hostView.isInEditMode) {
+        SideEffect {
+            val activityWindow = (hostView.context as Activity).window
+            activityWindow.statusBarColor = colorScheme.primary.toArgb()
+            val insets = WindowCompat.getInsetsController(activityWindow, hostView)
+            insets.isAppearanceLightStatusBars = darkTheme
+        }
+    }
+
+    MaterialTheme(colorScheme = colorScheme, typography = Typography, content = content)
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt
new file mode 100644
index 000000000..0b87946ca
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt
@@ -0,0 +1,34 @@
+package com.example.llama.ui.theme
+
+import androidx.compose.material3.Typography
+import androidx.compose.ui.text.TextStyle
+import androidx.compose.ui.text.font.FontFamily
+import androidx.compose.ui.text.font.FontWeight
+import androidx.compose.ui.unit.sp
+
+// App typography: only bodyLarge is set explicitly; the commented block lists further styles that could be overridden.
+val Typography = Typography(
+ bodyLarge = TextStyle(
+ fontFamily = FontFamily.Default,
+ fontWeight = FontWeight.Normal,
+ fontSize = 16.sp,
+ lineHeight = 24.sp,
+ letterSpacing = 0.5.sp
+ )
+ /* Other default text styles to override
+ titleLarge = TextStyle(
+ fontFamily = FontFamily.Default,
+ fontWeight = FontWeight.Normal,
+ fontSize = 22.sp,
+ lineHeight = 28.sp,
+ letterSpacing = 0.sp
+ ),
+ labelSmall = TextStyle(
+ fontFamily = FontFamily.Default,
+ fontWeight = FontWeight.Medium,
+ fontSize = 11.sp,
+ lineHeight = 16.sp,
+ letterSpacing = 0.5.sp
+ )
+ */
+)
diff --git a/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml b/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml
new file mode 100644
index 000000000..07d5da9cb
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml
@@ -0,0 +1,170 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml
new file mode 100644
index 000000000..7706ab9e6
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml
@@ -0,0 +1,30 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml
new file mode 100644
index 000000000..b3e26b4c6
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml
new file mode 100644
index 000000000..b3e26b4c6
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp
new file mode 100644
index 000000000..c209e78ec
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
new file mode 100644
index 000000000..b2dfe3d1b
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp
new file mode 100644
index 000000000..4f0f1d64e
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
new file mode 100644
index 000000000..62b611da0
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
new file mode 100644
index 000000000..948a3070f
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
new file mode 100644
index 000000000..1b9a6956b
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
new file mode 100644
index 000000000..28d4b77f9
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
new file mode 100644
index 000000000..9287f5083
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
new file mode 100644
index 000000000..aa7d6427e
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp differ
diff --git a/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
new file mode 100644
index 000000000..9126ae37c
Binary files /dev/null and b/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp differ
diff --git a/examples/llama.android/app/src/main/res/values/colors.xml b/examples/llama.android/app/src/main/res/values/colors.xml
new file mode 100644
index 000000000..ca1931bca
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/values/colors.xml
@@ -0,0 +1,10 @@
+
+
+ #FFBB86FC
+ #FF6200EE
+ #FF3700B3
+ #FF03DAC5
+ #FF018786
+ #FF000000
+ #FFFFFFFF
+
diff --git a/examples/llama.android/app/src/main/res/values/strings.xml b/examples/llama.android/app/src/main/res/values/strings.xml
new file mode 100644
index 000000000..7a9d314e2
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/values/strings.xml
@@ -0,0 +1,3 @@
+
+ LlamaAndroid
+
diff --git a/examples/llama.android/app/src/main/res/values/themes.xml b/examples/llama.android/app/src/main/res/values/themes.xml
new file mode 100644
index 000000000..8a24fda56
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/values/themes.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/xml/backup_rules.xml b/examples/llama.android/app/src/main/res/xml/backup_rules.xml
new file mode 100644
index 000000000..148c18b65
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/xml/backup_rules.xml
@@ -0,0 +1,13 @@
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml b/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml
new file mode 100644
index 000000000..0c4f95cab
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
diff --git a/examples/llama.android/build.gradle.kts b/examples/llama.android/build.gradle.kts
new file mode 100644
index 000000000..50ebc8211
--- /dev/null
+++ b/examples/llama.android/build.gradle.kts
@@ -0,0 +1,5 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+plugins { // versions are declared here; `apply false` keeps them off the root project itself
+ id("com.android.application") version "8.2.0" apply false
+ id("org.jetbrains.kotlin.android") version "1.9.0" apply false
+}
diff --git a/examples/llama.android/gradle.properties b/examples/llama.android/gradle.properties
new file mode 100644
index 000000000..2cbd6d19d
--- /dev/null
+++ b/examples/llama.android/gradle.properties
@@ -0,0 +1,23 @@
+# Project-wide Gradle settings.
+# IDE (e.g. Android Studio) users:
+# Gradle settings configured through the IDE *will override*
+# any settings specified in this file.
+# For more details on how to configure your build environment visit
+# http://www.gradle.org/docs/current/userguide/build_environment.html
+# Specifies the JVM arguments used for the daemon process.
+# The setting is particularly useful for tweaking memory settings.
+org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
+# When configured, Gradle will run in incubating parallel mode.
+# This option should only be used with decoupled projects. More details, visit
+# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
+# org.gradle.parallel=true
+# AndroidX package structure to make it clearer which packages are bundled with the
+# Android operating system, and which are packaged with your app's APK
+# https://developer.android.com/topic/libraries/support-library/androidx-rn
+android.useAndroidX=true
+# Kotlin code style for this project: "official" or "obsolete":
+kotlin.code.style=official
+# Enables namespacing of each library's R class so that its R class includes only the
+# resources declared in the library itself and none from the library's dependencies,
+# thereby reducing the size of the R class for that library
+android.nonTransitiveRClass=true
diff --git a/examples/llama.android/gradle/wrapper/gradle-wrapper.jar b/examples/llama.android/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 000000000..e708b1c02
Binary files /dev/null and b/examples/llama.android/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/examples/llama.android/gradle/wrapper/gradle-wrapper.properties b/examples/llama.android/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 000000000..a3958c140
--- /dev/null
+++ b/examples/llama.android/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Thu Dec 21 14:31:09 AEDT 2023
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
diff --git a/examples/llama.android/gradlew b/examples/llama.android/gradlew
new file mode 100755
index 000000000..4f906e0c8
--- /dev/null
+++ b/examples/llama.android/gradlew
@@ -0,0 +1,185 @@
+#!/usr/bin/env sh
+
+#
+# Copyright 2015 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+##############################################################################
+##
+## Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+ ls=`ls -ld "$PRG"`
+ link=`expr "$ls" : '.*-> \(.*\)$'`
+ if expr "$link" : '/.*' > /dev/null; then
+ PRG="$link"
+ else
+ PRG=`dirname "$PRG"`"/$link"
+ fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () { # print all arguments as one line on stdout; execution continues
+ echo "$*"
+}
+
+die () { # print all arguments between two blank lines, then exit with status 1
+ echo
+ echo "$*"
+ echo
+ exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+ CYGWIN* )
+ cygwin=true
+ ;;
+ Darwin* )
+ darwin=true
+ ;;
+ MINGW* )
+ msys=true
+ ;;
+ NONSTOP* )
+ nonstop=true
+ ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+ if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+ # IBM's JDK on AIX uses strange locations for the executables
+ JAVACMD="$JAVA_HOME/jre/sh/java"
+ else
+ JAVACMD="$JAVA_HOME/bin/java"
+ fi
+ if [ ! -x "$JAVACMD" ] ; then
+ die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+ fi
+else
+ JAVACMD="java"
+ which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+ MAX_FD_LIMIT=`ulimit -H -n`
+ if [ $? -eq 0 ] ; then
+ if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+ MAX_FD="$MAX_FD_LIMIT"
+ fi
+ ulimit -n $MAX_FD
+ if [ $? -ne 0 ] ; then
+ warn "Could not set maximum file descriptor limit: $MAX_FD"
+ fi
+ else
+ warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+ fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+ GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
+ APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+ CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
+ JAVACMD=`cygpath --unix "$JAVACMD"`
+
+ # We build the pattern for arguments to be converted via cygpath
+ ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+ SEP=""
+ for dir in $ROOTDIRSRAW ; do
+ ROOTDIRS="$ROOTDIRS$SEP$dir"
+ SEP="|"
+ done
+ OURCYGPATTERN="(^($ROOTDIRS))"
+ # Add a user-defined pattern to the cygpath arguments
+ if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+ OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+ fi
+ # Now convert the arguments - kludge to limit ourselves to /bin/sh
+ i=0
+ for arg in "$@" ; do
+ CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+ CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
+
+ if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
+ eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+ else
+ eval `echo args$i`="\"$arg\""
+ fi
+ i=`expr $i + 1`
+ done
+ case $i in
+ 0) set -- ;;
+ 1) set -- "$args0" ;;
+ 2) set -- "$args0" "$args1" ;;
+ 3) set -- "$args0" "$args1" "$args2" ;;
+ 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ esac
+fi
+
+# Escape application args: print each one single-quoted (embedded ' becomes '\'') with a trailing backslash, for the eval below
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
+}
+APP_ARGS=`save "$@"`
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+exec "$JAVACMD" "$@"
diff --git a/examples/llama.android/settings.gradle.kts b/examples/llama.android/settings.gradle.kts
new file mode 100644
index 000000000..2ba32c4fa
--- /dev/null
+++ b/examples/llama.android/settings.gradle.kts
@@ -0,0 +1,17 @@
+pluginManagement { // repositories Gradle searches when resolving plugins
+ repositories {
+ google()
+ mavenCentral()
+ gradlePluginPortal()
+ }
+}
+dependencyResolutionManagement { // repositories used for dependencies across the whole build
+ repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) // a project declaring its own repositories fails the build
+ repositories {
+ google()
+ mavenCentral()
+ }
+}
+
+rootProject.name = "LlamaAndroid"
+include(":app")
diff --git a/ggml-metal.m b/ggml-metal.m
index 23175005c..30b00fa4a 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -366,8 +366,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-#if TARGET_OS_OSX
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+#elif TARGET_OS_OSX
if (ctx->device.maxTransferRate != 0) {
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
} else {
@@ -720,1478 +724,257 @@ static bool ggml_metal_graph_compute(
const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
- id command_buffers[n_cb];
+ id command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- command_buffers[cb_idx] = command_buffer;
+ command_buffer_builder[cb_idx] = command_buffer;
// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
+ }
+ const id *command_buffers = command_buffer_builder;
- dispatch_async(ctx->d_queue, ^{
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_dst = 0;
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
+ const int cb_idx = iter;
- id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_dst = 0;
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
+ id command_buffer = command_buffers[cb_idx];
+ id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- for (int i = node_start; i < node_end; ++i) {
- if (i == -1) {
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
- continue;
- }
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ for (int i = node_start; i < node_end; ++i) {
+ if (i == -1) {
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * dst = gf->nodes[i];
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } continue;
- default:
- {
- } break;
- }
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * dst = gf->nodes[i];
- if (!ggml_metal_supports_op(ctx, dst)) {
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(!"unsupported op");
- }
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop -> next node
+ } continue;
+ default:
+ {
+ } break;
+ }
+
+ if (!ggml_metal_supports_op(ctx, dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
#ifndef GGML_METAL_NDEBUG
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
#endif
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
- id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
+ id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+ id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+ id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- //if (src0) {
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
- // ggml_is_contiguous(src0), src0->name);
- //}
- //if (src1) {
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
- // ggml_is_contiguous(src1), src1->name);
- //}
- //if (dst) {
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
- // dst->name);
- //}
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //if (src0) {
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // ggml_is_contiguous(src0), src0->name);
+ //}
+ //if (src1) {
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // ggml_is_contiguous(src1), src1->name);
+ //}
+ //if (dst) {
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // dst->name);
+ //}
- switch (dst->op) {
- case GGML_OP_CONCAT:
- {
- const int64_t nb = ne00;
+ switch (dst->op) {
+ case GGML_OP_CONCAT:
+ {
+ const int64_t nb = ne00;
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
- const int nth = MIN(1024, ne0);
+ const int nth = MIN(1024, ne0);
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- {
- const size_t offs = 0;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ {
+ const size_t offs = 0;
- bool bcast_row = false;
+ bool bcast_row = false;
- int64_t nb = ne00;
+ int64_t nb = ne00;
- id pipeline = nil;
+ id pipeline = nil;
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
-
- nb = ne00 / 4;
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
- default: GGML_ASSERT(false);
- }
-
- bcast_row = true;
- } else {
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ASSERT(false);
- }
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
-
- if (bcast_row) {
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } else {
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
- } break;
- case GGML_OP_ACC:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- GGML_ASSERT(dstt == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
- const size_t offs = ((int32_t *) dst->op_params)[3];
-
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
- if (!inplace) {
- // run a separete kernel to cpy src->dst
- // not sure how to avoid this
- // TODO: make a simpler cpy_bytes kernel
-
- const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
-
- const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_SCALE:
- {
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
- const float scale = *(const float *) dst->op_params;
+ // src1 is a row
+ GGML_ASSERT(ne11 == 1);
- int64_t n = ggml_nelements(dst);
-
- id pipeline = nil;
-
- if (n % 4 == 0) {
- n /= 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ nb = ne00 / 4;
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+ default: GGML_ASSERT(false);
}
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+ bcast_row = true;
+ } else {
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
+
+ if (bcast_row) {
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(gf->nodes[i])) {
- case GGML_UNARY_OP_TANH:
- {
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_RELU:
- {
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU:
- {
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU_QUICK:
- {
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SILU:
- {
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- default:
- {
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
- }
- } break;
- case GGML_OP_SQR:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SUM_ROWS:
- {
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SOFT_MAX:
- {
- int nth = 32; // SIMD width
-
- id pipeline = nil;
-
- if (ne00%4 == 0) {
- while (nth < ne00/4 && nth < 256) {
- nth *= 2;
- }
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
- } else {
- while (nth < ne00 && nth < 1024) {
- nth *= 2;
- }
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
- }
-
- const float scale = ((float *) dst->op_params)[0];
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- const int n_past = ((int32_t *)(dst->op_params))[0];
-
- id pipeline = nil;
-
- if (ne00%8 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
-
- if (ne00%8 == 0) {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- else {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- } break;
- case GGML_OP_MUL_MAT:
- {
- GGML_ASSERT(ne00 == ne10);
-
- // TODO: assert that dim2 and dim3 are contiguous
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- const uint r2 = ne12/ne02;
- const uint r3 = ne13/ne03;
-
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- int ne11_mm_min = 1;
-
-#if 0
- // the numbers below are measured on M2 Ultra for 7B and 13B models
- // these numbers do not translate to other devices or model sizes
- // TODO: need to find a better approach
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
- }
-#endif
-
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- !ggml_is_transposed(src0) &&
- !ggml_is_transposed(src1) &&
- src1t == GGML_TYPE_F32 &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- id pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- id pipeline = nil;
-
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
- nrows = 4;
- } break;
- case GGML_TYPE_F16:
- {
- nth0 = 32;
- nth1 = 1;
- if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
- nrows = ne11;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
- nrows = 4;
- }
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
- nrows = 4;
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
- } break;
- default:
- {
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
- GGML_ASSERT(false && "not implemented");
- }
- };
-
- if (ggml_is_quantized(src0t)) {
- GGML_ASSERT(ne00 >= nth0*nth1);
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
-
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
- }
- else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- //GGML_ASSERT(ne00 == ne10);
- //GGML_ASSERT(ne03 == ne13);
-
- GGML_ASSERT(src0t == GGML_TYPE_I32);
-
- const int n_as = ((int32_t *) dst->op_params)[1];
-
- // TODO: make this more general
- GGML_ASSERT(n_as <= 8);
-
- // max size of the src1ids array in the kernel stack
- GGML_ASSERT(ne11 <= 512);
-
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
-
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
- GGML_ASSERT(!ggml_is_transposed(src2));
- GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(src1t == GGML_TYPE_F32);
-
- const uint r2 = ne12/ne22;
- const uint r3 = ne13/ne23;
-
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- int ne11_mm_min = n_as;
-
- const int idx = ((int32_t *) dst->op_params)[0];
-
- // batch size
- GGML_ASSERT(ne01 == ne11);
-
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- // !!!
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
- // indirect matrix multiplication
- // !!!
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne20 % 32 == 0 && ne20 >= 64 &&
- ne11 > ne11_mm_min) {
-
- id pipeline = nil;
-
- switch (src2->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
- // TODO: how to make this an array? read Metal docs
- for (int j = 0; j < 8; ++j) {
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
- size_t offs_src_cur = 0;
- id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
-
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
- }
-
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- id pipeline = nil;
-
- // use custom matrix x vector kernel
- switch (src2t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
- } break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
- } break;
- default:
- {
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
- GGML_ASSERT(false && "not implemented");
- }
- };
-
- if (ggml_is_quantized(src2t)) {
- GGML_ASSERT(ne20 >= nth0*nth1);
- }
-
- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
- // TODO: how to make this an array? read Metal docs
- for (int j = 0; j < 8; ++j) {
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
- size_t offs_src_cur = 0;
- id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
-
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
- }
-
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src2t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src2t == GGML_TYPE_Q3_K) {
-#ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#else
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-#endif
- }
- else if (src2t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src2t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_GET_ROWS:
- {
- id pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
- } break;
- case GGML_OP_RMS_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- int nth = 32; // SIMD width
-
- while (nth < ne00/4 && nth < 1024) {
- nth *= 2;
- }
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_GROUP_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
-
- //float eps;
- //memcpy(&eps, dst->op_params, sizeof(float));
-
- const float eps = 1e-6f; // TODO: temporarily hardcoded
-
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
-
- int nth = 32; // SIMD width
-
- //while (nth < ne00/4 && nth < 1024) {
- // nth *= 2;
- //}
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_NORM:
- {
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- const int nth = MIN(256, ne00);
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ALIBI:
- {
- GGML_ASSERT((src0t == GGML_TYPE_F32));
-
- const int nth = MIN(1024, ne00);
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ROPE:
- {
- GGML_ASSERT(ne10 == ne02);
-
- const int nth = MIN(1024, ne00);
-
- const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
-
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-
- id pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
- default: GGML_ASSERT(false);
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_IM2COL:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
- const int32_t N = src1->ne[is_2D ? 3 : 2];
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
- const int32_t IH = is_2D ? src1->ne[1] : 1;
- const int32_t IW = src1->ne[0];
-
- const int32_t KH = is_2D ? src0->ne[1] : 1;
- const int32_t KW = src0->ne[0];
-
- const int32_t OH = is_2D ? dst->ne[2] : 1;
- const int32_t OW = dst->ne[1];
-
- const int32_t CHW = IC * KH * KW;
-
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
-
- id pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
- default: GGML_ASSERT(false);
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
-
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
- } break;
- case GGML_OP_UPSCALE:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const int sf = dst->op_params[0];
-
- const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
-
+ } else {
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_PAD:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+ } break;
+ case GGML_OP_ACC:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
- const int nth = MIN(1024, ne0);
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ARGSORT:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
+ if (!inplace) {
+ // run a separate kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
- const int nrows = ggml_nrows(src0);
-
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
- id pipeline = nil;
-
- switch (order) {
- case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
- case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
- default: GGML_ASSERT(false);
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
- } break;
- case GGML_OP_LEAKY_RELU:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- float slope;
- memcpy(&slope, dst->op_params, sizeof(float));
-
- id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- {
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
- id pipeline = nil;
-
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
-
- switch (dstt) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
- //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
- //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
- };
- } break;
- case GGML_TYPE_F16:
- {
- switch (dstt) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
- };
- } break;
- default: GGML_ASSERT(false && "not implemented");
- }
+ const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -2213,28 +996,1253 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- default:
- {
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
- }
- }
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-#ifndef GGML_METAL_NDEBUG
- [encoder popDebugGroup];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+
+ const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_SCALE:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const float scale = *(const float *) dst->op_params;
+
+ int64_t n = ggml_nelements(dst);
+
+ id pipeline = nil;
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_TANH:
+ {
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_OP_SQR:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ int nth = 32; // SIMD width
+
+ id pipeline = nil;
+
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
+ } else {
+ while (nth < ne00 && nth < 1024) {
+ nth *= 2;
+ }
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
+ }
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ const int n_past = ((int32_t *)(dst->op_params))[0];
+
+ id pipeline = nil;
+
+ if (ne00%8 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ GGML_ASSERT(ne00 == ne10);
+
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+#if 0
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
+ }
+ }
#endif
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
+ src1t == GGML_TYPE_F32 &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id pipeline = nil;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+ nrows = 4;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ nth0 = 32;
+ nth1 = 1;
+ if (src1t == GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+ nrows = ne11;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
+ nrows = 4;
+ }
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
+ nrows = 4;
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ if (ggml_is_quantized(src0t)) {
+ GGML_ASSERT(ne00 >= nth0*nth1);
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ((int32_t *) dst->op_params)[1];
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ // max size of the src1ids array in the kernel stack
+ GGML_ASSERT(ne11 <= 512);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = n_as;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne20 % 32 == 0 && ne20 >= 64 &&
+ ne11 > ne11_mm_min) {
+
+ id pipeline = nil;
+
+ switch (src2->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < 8; ++j) {
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
+
+ size_t offs_src_cur = 0;
+ id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id pipeline = nil;
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ if (ggml_is_quantized(src2t)) {
+ GGML_ASSERT(ne20 >= nth0*nth1);
+ }
+
+ const int64_t _ne1 = 1; // kernels needs a reference in constant memory
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < 8; ++j) {
+ // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
+ struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
+
+ size_t offs_src_cur = 0;
+ id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ id pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_NORM:
+ {
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const int nth = MIN(256, ne00);
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+ const int nth = MIN(1024, ne00);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ id pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ id pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int sf = dst->op_params[0];
+
+ const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ id pipeline = nil;
+
+ switch (order) {
+ case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+ id pipeline = nil;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
}
+#ifndef GGML_METAL_NDEBUG
+ [encoder popDebugGroup];
+#endif
+ }
+
+ if (encoder != nil) {
[encoder endEncoding];
+ encoder = nil;
+ }
- [command_buffer commit];
- });
- }
-
- // Wait for all command buffers to be committed
- dispatch_barrier_sync(ctx->d_queue, ^{});
+ [command_buffer commit];
+ });
// Wait for completion and check status of each command buffer
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
@@ -2356,6 +2364,25 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
UNUSED(buft);
}
+// Appends the device's memory-usage figures (in MiB) to a log line that the
+// caller has already started, and terminates that line. The buffer allocation
+// path logs "allocated buffer, size = ..." without a trailing newline right
+// before calling this helper, so every code path here must end with "\n".
+//
+// device: Metal device whose allocation counters are reported.
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        // recommendedMaxWorkingSetSize is only read behind this availability check
+        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+                device.currentAllocatedSize / 1024.0 / 1024.0,
+                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+        if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+            // the warning's own trailing newline also ends the caller's line
+            GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+        } else {
+            GGML_METAL_LOG_INFO("\n");
+        }
+    } else {
+        // fallback when the availability check fails: report only the current allocation
+        GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+    }
+#else
+    // No size details are printed on this configuration, but the caller's
+    // partial log line still has to be terminated.
+    GGML_METAL_LOG_INFO("\n");
+#endif
+    // keeps the parameter referenced when the block above does not use it
+    UNUSED(device);
+}
+
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
@@ -2388,22 +2415,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
}
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-
-
-#if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
-#else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
-#endif
-
+ ggml_backend_metal_log_allocated_size(device);
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}
@@ -2511,19 +2523,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
}
}
-#if TARGET_OS_OSX
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
- device.currentAllocatedSize / 1024.0 / 1024.0,
- device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-
- if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
- GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
- } else {
- GGML_METAL_LOG_INFO("\n");
- }
-#else
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
-#endif
+ ggml_backend_metal_log_allocated_size(device);
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
}