metal : do not build bfloat kernels when not supported
ggml-ci
This commit is contained in:
parent
a408f51906
commit
ad1226982f
2 changed files with 29 additions and 3 deletions
|
@ -496,7 +496,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
// dictionary of preprocessor macros
|
// dictionary of preprocessor macros
|
||||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||||
|
|
||||||
MTLCompileOptions* options = [MTLCompileOptions new];
|
if (!ctx_dev->has_bfloat) {
|
||||||
|
[prep setObject:@"GGML_METAL_NO_BFLOAT" forKey:@"GGML_METAL_NO_BFLOAT"];
|
||||||
|
}
|
||||||
|
|
||||||
|
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||||
options.preprocessorMacros = prep;
|
options.preprocessorMacros = prep;
|
||||||
|
|
||||||
//[options setFastMathEnabled:false];
|
//[options setFastMathEnabled:false];
|
||||||
|
|
|
@ -12,12 +12,14 @@ using namespace metal;
|
||||||
|
|
||||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||||
|
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
|
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
|
||||||
|
|
||||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||||
|
@ -29,10 +31,12 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
||||||
reg = (type4x4)(*src);
|
reg = (type4x4)(*src);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||||
reg = (type4x4)(*src);
|
reg = (type4x4)(*src);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||||
|
@ -2048,8 +2052,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
||||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
|
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float, float4>;
|
||||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
||||||
|
#endif
|
||||||
|
|
||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
kernel void kernel_mul_mv_1row(
|
kernel void kernel_mul_mv_1row(
|
||||||
|
@ -2119,7 +2125,9 @@ kernel void kernel_mul_mv_1row(
|
||||||
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
|
template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Assumes row size (ne00) is a multiple of 4
|
// Assumes row size (ne00) is a multiple of 4
|
||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
|
@ -2179,7 +2187,9 @@ kernel void kernel_mul_mv_l4(
|
||||||
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
|
template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
|
||||||
|
#endif
|
||||||
|
|
||||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||||
|
@ -3578,11 +3588,15 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
||||||
|
|
||||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
||||||
|
#endif
|
||||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
||||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
||||||
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
||||||
|
#endif
|
||||||
|
|
||||||
kernel void kernel_cpy_f32_q8_0(
|
kernel void kernel_cpy_f32_q8_0(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
@ -6487,7 +6501,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
||||||
|
|
||||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
||||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
|
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
|
||||||
|
#endif
|
||||||
|
|
||||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||||
|
|
||||||
|
@ -6519,7 +6535,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||||
|
#endif
|
||||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||||
|
@ -6548,7 +6566,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
||||||
|
#endif
|
||||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
||||||
|
@ -6772,7 +6792,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
||||||
|
#if !defined(GGML_METAL_NO_BFLOAT)
|
||||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
||||||
|
#endif
|
||||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue