ggml : full ALiBi support
This commit is contained in:
parent
d11afd6652
commit
7fdca3348c
10 changed files with 82 additions and 680 deletions
|
@ -4,7 +4,6 @@
|
||||||
|
|
||||||
#include "ggml-cuda/common.cuh"
|
#include "ggml-cuda/common.cuh"
|
||||||
#include "ggml-cuda/acc.cuh"
|
#include "ggml-cuda/acc.cuh"
|
||||||
#include "ggml-cuda/alibi.cuh"
|
|
||||||
#include "ggml-cuda/arange.cuh"
|
#include "ggml-cuda/arange.cuh"
|
||||||
#include "ggml-cuda/argsort.cuh"
|
#include "ggml-cuda/argsort.cuh"
|
||||||
#include "ggml-cuda/binbcast.cuh"
|
#include "ggml-cuda/binbcast.cuh"
|
||||||
|
@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
ggml_cuda_op_rope(ctx, dst);
|
ggml_cuda_op_rope(ctx, dst);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
ggml_cuda_op_alibi(ctx, dst);
|
|
||||||
break;
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
ggml_cuda_op_im2col(ctx, dst);
|
ggml_cuda_op_im2col(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
|
|
|
@ -1,63 +0,0 @@
|
||||||
#include "alibi.cuh"
|
|
||||||
|
|
||||||
static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
|
|
||||||
const int n_heads_log2_floor, const float m0, const float m1) {
|
|
||||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
|
||||||
|
|
||||||
if (col >= ncols) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int row = blockDim.y*blockIdx.y + threadIdx.y;
|
|
||||||
const int i = row*ncols + col;
|
|
||||||
|
|
||||||
const int k = row/k_rows;
|
|
||||||
|
|
||||||
float m_k;
|
|
||||||
if (k < n_heads_log2_floor) {
|
|
||||||
m_k = powf(m0, k + 1);
|
|
||||||
} else {
|
|
||||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[i] = col * m_k + x[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
|
||||||
const int k_rows, const int n_heads_log2_floor, const float m0,
|
|
||||||
const float m1, cudaStream_t stream) {
|
|
||||||
const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
|
|
||||||
const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
|
|
||||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
|
||||||
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
|
||||||
const float * src0_d = (const float *)src0->data;
|
|
||||||
float * dst_d = (float *)dst->data;
|
|
||||||
cudaStream_t stream = ctx.stream();
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
|
||||||
const int64_t ne01 = src0->ne[1];
|
|
||||||
const int64_t ne02 = src0->ne[2];
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
||||||
float max_bias;
|
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
||||||
|
|
||||||
//GGML_ASSERT(ne01 + n_past == ne00);
|
|
||||||
GGML_ASSERT(n_head == ne02);
|
|
||||||
|
|
||||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
||||||
|
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
||||||
|
|
||||||
alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
#include "common.cuh"
|
|
||||||
|
|
||||||
#define CUDA_ALIBI_BLOCK_SIZE 32
|
|
||||||
|
|
||||||
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
72
ggml-metal.m
72
ggml-metal.m
|
@ -169,7 +169,6 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
||||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||||
|
@ -623,7 +622,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||||
|
@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
return ctx->support_simdgroup_reduction;
|
return ctx->support_simdgroup_reduction;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
return true;
|
return true;
|
||||||
|
@ -1357,13 +1354,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
if (ne00%4 == 0) {
|
if (ne00%4 == 0) {
|
||||||
while (nth < ne00/4 && nth < 256) {
|
while (nth < ne00/4 && nth < 256) {
|
||||||
|
@ -1407,20 +1403,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
}
|
}
|
||||||
if (id_src2) {
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
} else {
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
}
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
|
||||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
|
||||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
|
||||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
|
||||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
|
||||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
@ -2225,49 +2216,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
{
|
|
||||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
|
||||||
|
|
||||||
const int nth = MIN(1024, ne00);
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
||||||
|
|
||||||
float max_bias;
|
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
||||||
|
|
||||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
|
||||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
|
||||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
|
||||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
|
||||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
|
||||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
|
||||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
|
||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
|
||||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
|
||||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
|
||||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
|
||||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
|
||||||
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
|
||||||
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
|
||||||
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
||||||
} break;
|
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne10 == ne02);
|
GGML_ASSERT(ne10 == ne02);
|
||||||
|
|
|
@ -356,7 +356,6 @@ template<typename T>
|
||||||
kernel void kernel_soft_max(
|
kernel void kernel_soft_max(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device const char * src2,
|
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
|
@ -378,10 +377,9 @@ kernel void kernel_soft_max(
|
||||||
|
|
||||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
|
||||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
float slope = 0.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
|
@ -397,7 +395,7 @@ kernel void kernel_soft_max(
|
||||||
float lmax = -INFINITY;
|
float lmax = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
// find the max value in the block
|
// find the max value in the block
|
||||||
|
@ -422,7 +420,7 @@ kernel void kernel_soft_max(
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float lsum = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||||
lsum += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
@ -461,7 +459,6 @@ template<typename T>
|
||||||
kernel void kernel_soft_max_4(
|
kernel void kernel_soft_max_4(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device const char * src2,
|
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
|
@ -483,10 +480,9 @@ kernel void kernel_soft_max_4(
|
||||||
|
|
||||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
|
||||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||||
|
|
||||||
float slope = 0.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
const int64_t h = i02;
|
const int64_t h = i02;
|
||||||
|
@ -501,7 +497,7 @@ kernel void kernel_soft_max_4(
|
||||||
float4 lmax4 = -INFINITY;
|
float4 lmax4 = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
|
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||||
}
|
}
|
||||||
|
|
||||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
|
@ -527,7 +523,7 @@ kernel void kernel_soft_max_4(
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float4 lsum4 = 0.0f;
|
float4 lsum4 = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
|
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||||
lsum4 += exp_psrc4;
|
lsum4 += exp_psrc4;
|
||||||
pdst4[i00] = exp_psrc4;
|
pdst4[i00] = exp_psrc4;
|
||||||
}
|
}
|
||||||
|
@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
|
||||||
device const float * src0,
|
|
||||||
device float * dst,
|
|
||||||
constant int64_t & ne00,
|
|
||||||
constant int64_t & ne01,
|
|
||||||
constant int64_t & ne02,
|
|
||||||
constant int64_t & ne03,
|
|
||||||
constant uint64_t & nb00,
|
|
||||||
constant uint64_t & nb01,
|
|
||||||
constant uint64_t & nb02,
|
|
||||||
constant uint64_t & nb03,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
constant int64_t & ne1,
|
|
||||||
constant int64_t & ne2,
|
|
||||||
constant int64_t & ne3,
|
|
||||||
constant uint64_t & nb0,
|
|
||||||
constant uint64_t & nb1,
|
|
||||||
constant uint64_t & nb2,
|
|
||||||
constant uint64_t & nb3,
|
|
||||||
constant float & m0,
|
|
||||||
constant float & m1,
|
|
||||||
constant int & n_heads_log2_floor,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
|
||||||
const int64_t i03 = tgpig[2];
|
|
||||||
const int64_t i02 = tgpig[1];
|
|
||||||
const int64_t i01 = tgpig[0];
|
|
||||||
|
|
||||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
||||||
|
|
||||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
||||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
||||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
||||||
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
||||||
|
|
||||||
const int64_t k = i3*ne3 + i2;
|
|
||||||
|
|
||||||
float m_k;
|
|
||||||
if (k < n_heads_log2_floor) {
|
|
||||||
m_k = pow(m0, k + 1);
|
|
||||||
} else {
|
|
||||||
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
|
||||||
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
||||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
||||||
const float src_v = *(device float *)(src_row + i00*nb00);
|
|
||||||
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
|
||||||
*dst_v = i00 * m_k + src_v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||||
return 1.0f - min(1.0f, max(0.0f, y));
|
return 1.0f - min(1.0f, max(0.0f, y));
|
||||||
|
|
|
@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
|
||||||
#define SYCL_SCALE_BLOCK_SIZE 256
|
#define SYCL_SCALE_BLOCK_SIZE 256
|
||||||
#define SYCL_CLAMP_BLOCK_SIZE 256
|
#define SYCL_CLAMP_BLOCK_SIZE 256
|
||||||
#define SYCL_ROPE_BLOCK_SIZE 256
|
#define SYCL_ROPE_BLOCK_SIZE 256
|
||||||
#define SYCL_ALIBI_BLOCK_SIZE 32
|
|
||||||
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
|
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
|
||||||
#define SYCL_QUANTIZE_BLOCK_SIZE 256
|
#define SYCL_QUANTIZE_BLOCK_SIZE 256
|
||||||
#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
|
#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
|
@ -9316,32 +9315,6 @@ static void rope_glm_f32(
|
||||||
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
|
|
||||||
const int n_heads_log2_floor, const float m0, const float m1,
|
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
|
||||||
const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
||||||
item_ct1.get_local_id(2);
|
|
||||||
|
|
||||||
if (col >= ncols) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
||||||
item_ct1.get_local_id(1);
|
|
||||||
const int i = row*ncols + col;
|
|
||||||
|
|
||||||
const int k = row/k_rows;
|
|
||||||
|
|
||||||
float m_k;
|
|
||||||
if (k < n_heads_log2_floor) {
|
|
||||||
m_k = dpct::pow(m0, k + 1);
|
|
||||||
} else {
|
|
||||||
m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[i] = col * m_k + x[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int row = item_ct1.get_group(1);
|
const int row = item_ct1.get_group(1);
|
||||||
|
@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
|
|
||||||
const int nrows, const int k_rows,
|
|
||||||
const int n_heads_log2_floor, const float m0,
|
|
||||||
const float m1, dpct::queue_ptr stream) {
|
|
||||||
const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
|
|
||||||
const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
|
|
||||||
const sycl::range<3> block_nums(1, nrows, num_blocks_x);
|
|
||||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
|
||||||
alibi_f32(x, dst, ncols, k_rows,
|
|
||||||
n_heads_log2_floor, m0, m1, item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
||||||
const int nrows, dpct::queue_ptr stream) {
|
const int nrows, dpct::queue_ptr stream) {
|
||||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||||
|
@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
(void) src1_dd;
|
(void) src1_dd;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
||||||
ggml_tensor *dst, const float *src0_dd,
|
|
||||||
const float *src1_dd, float *dst_dd,
|
|
||||||
const dpct::queue_ptr &main_stream) {
|
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
||||||
|
|
||||||
GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
||||||
float max_bias;
|
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
||||||
|
|
||||||
//GGML_ASSERT(ne01 + n_past == ne00);
|
|
||||||
GGML_ASSERT(n_head == ne02);
|
|
||||||
|
|
||||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
||||||
|
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
||||||
|
|
||||||
alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
|
|
||||||
|
|
||||||
(void) src1;
|
|
||||||
(void) src1_dd;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
|
static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
|
||||||
const ggml_tensor *src1, ggml_tensor *dst,
|
const ggml_tensor *src1, ggml_tensor *dst,
|
||||||
const float *src0_dd, const float *src1_dd,
|
const float *src0_dd, const float *src1_dd,
|
||||||
|
@ -16232,10 +16161,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||||
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
|
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
||||||
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
|
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
|
||||||
}
|
}
|
||||||
|
@ -16612,9 +16537,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
func = ggml_sycl_rope;
|
func = ggml_sycl_rope;
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
func = ggml_sycl_alibi;
|
|
||||||
break;
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
func = ggml_sycl_im2col;
|
func = ggml_sycl_im2col;
|
||||||
break;
|
break;
|
||||||
|
@ -17744,7 +17666,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
|
|
275
ggml.c
275
ggml.c
|
@ -2185,7 +2185,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"SOFT_MAX_BACK",
|
"SOFT_MAX_BACK",
|
||||||
"ROPE",
|
"ROPE",
|
||||||
"ROPE_BACK",
|
"ROPE_BACK",
|
||||||
"ALIBI",
|
|
||||||
"CLAMP",
|
"CLAMP",
|
||||||
"CONV_TRANSPOSE_1D",
|
"CONV_TRANSPOSE_1D",
|
||||||
"IM2COL",
|
"IM2COL",
|
||||||
|
@ -2227,7 +2226,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
|
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 77");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
@ -2276,7 +2275,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"soft_max_back(x)",
|
"soft_max_back(x)",
|
||||||
"rope(x)",
|
"rope(x)",
|
||||||
"rope_back(x)",
|
"rope_back(x)",
|
||||||
"alibi(x)",
|
|
||||||
"clamp(x)",
|
"clamp(x)",
|
||||||
"conv_transpose_1d(x)",
|
"conv_transpose_1d(x)",
|
||||||
"im2col(x)",
|
"im2col(x)",
|
||||||
|
@ -2318,7 +2316,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
|
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 77");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
|
@ -5646,7 +5644,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
struct ggml_tensor * pos,
|
|
||||||
float scale,
|
float scale,
|
||||||
float max_bias,
|
float max_bias,
|
||||||
bool inplace) {
|
bool inplace) {
|
||||||
|
@ -5660,20 +5657,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pos) {
|
|
||||||
GGML_ASSERT(ggml_is_vector(pos));
|
|
||||||
GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
|
|
||||||
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pos && mask) {
|
|
||||||
GGML_ASSERT(pos->type == mask->type);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (max_bias > 0.0f) {
|
|
||||||
GGML_ASSERT(pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (a->grad) {
|
if (a->grad) {
|
||||||
|
@ -5689,7 +5672,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = mask;
|
result->src[1] = mask;
|
||||||
result->src[2] = pos;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -5697,23 +5679,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
struct ggml_tensor * ggml_soft_max(
|
struct ggml_tensor * ggml_soft_max(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a) {
|
struct ggml_tensor * a) {
|
||||||
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
|
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_soft_max_inplace(
|
struct ggml_tensor * ggml_soft_max_inplace(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a) {
|
struct ggml_tensor * a) {
|
||||||
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
|
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_soft_max_ext(
|
struct ggml_tensor * ggml_soft_max_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
struct ggml_tensor * pos,
|
|
||||||
float scale,
|
float scale,
|
||||||
float max_bias) {
|
float max_bias) {
|
||||||
return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
|
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_soft_max_back
|
// ggml_soft_max_back
|
||||||
|
@ -5928,37 +5909,6 @@ struct ggml_tensor * ggml_rope_back(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_alibi
|
|
||||||
|
|
||||||
struct ggml_tensor * ggml_alibi(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
int n_past,
|
|
||||||
int n_head,
|
|
||||||
float bias_max) {
|
|
||||||
GGML_ASSERT(n_past >= 0);
|
|
||||||
bool is_node = false;
|
|
||||||
|
|
||||||
if (a->grad) {
|
|
||||||
GGML_ASSERT(false); // TODO: implement backward
|
|
||||||
is_node = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: when implement backward, fix this:
|
|
||||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
|
||||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
|
||||||
|
|
||||||
int32_t op_params[3] = { n_past, n_head };
|
|
||||||
memcpy(op_params + 2, &bias_max, sizeof(float));
|
|
||||||
ggml_set_op_params(result, op_params, sizeof(op_params));
|
|
||||||
|
|
||||||
result->op = GGML_OP_ALIBI;
|
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
|
||||||
result->src[0] = a;
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_clamp
|
// ggml_clamp
|
||||||
|
|
||||||
struct ggml_tensor * ggml_clamp(
|
struct ggml_tensor * ggml_clamp(
|
||||||
|
@ -13333,7 +13283,6 @@ static void ggml_compute_forward_soft_max_f32(
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
const struct ggml_tensor * src1 = dst->src[1];
|
const struct ggml_tensor * src1 = dst->src[1];
|
||||||
const struct ggml_tensor * src2 = dst->src[2];
|
|
||||||
|
|
||||||
assert(ggml_is_contiguous(dst));
|
assert(ggml_is_contiguous(dst));
|
||||||
assert(ggml_are_same_shape(src0, dst));
|
assert(ggml_are_same_shape(src0, dst));
|
||||||
|
@ -13377,13 +13326,13 @@ static void ggml_compute_forward_soft_max_f32(
|
||||||
|
|
||||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||||
|
|
||||||
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||||
ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
|
|
||||||
float * pos_f32 = src2 ? (float *) src2->data : src0->data;
|
|
||||||
|
|
||||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
|
||||||
|
|
||||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
// ALiBi
|
||||||
|
const uint32_t h = (i1/ne01)%ne02; // head
|
||||||
|
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||||
|
|
||||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||||
|
|
||||||
|
@ -13396,27 +13345,11 @@ static void ggml_compute_forward_soft_max_f32(
|
||||||
if (mp_f32) {
|
if (mp_f32) {
|
||||||
if (use_f16) {
|
if (use_f16) {
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
|
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
wp[i] += mp_f32[i];
|
wp[i] += slope*mp_f32[i];
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ALiBi bias
|
|
||||||
if (max_bias > 0.0f) {
|
|
||||||
const uint32_t h = (i1/ne01)%ne02; // head
|
|
||||||
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
|
||||||
|
|
||||||
if (use_f16) {
|
|
||||||
for (int i = 0; i < nc; ++i) {
|
|
||||||
wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = 0; i < nc; ++i) {
|
|
||||||
wp[i] += slope*pos_f32[i];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13578,178 +13511,6 @@ static void ggml_compute_forward_soft_max_back(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_alibi
|
|
||||||
|
|
||||||
static void ggml_compute_forward_alibi_f32(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
|
||||||
|
|
||||||
assert(params->ith == 0);
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
||||||
float max_bias;
|
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
||||||
|
|
||||||
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
|
||||||
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
|
|
||||||
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
|
|
||||||
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
|
|
||||||
|
|
||||||
const int64_t n = ggml_nrows(src0);
|
|
||||||
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
|
|
||||||
|
|
||||||
const size_t nb0 = src0->nb[0];
|
|
||||||
const size_t nb1 = src0->nb[1];
|
|
||||||
const size_t nb2 = src0->nb[2];
|
|
||||||
//const int nb3 = src0->nb[3];
|
|
||||||
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
GGML_ASSERT(n_head == ne2);
|
|
||||||
|
|
||||||
// add alibi to src0 (KQ_scaled)
|
|
||||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
||||||
|
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
||||||
|
|
||||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
|
||||||
// TODO: k*nb2 or k*nb3
|
|
||||||
float m_k;
|
|
||||||
|
|
||||||
if (k < n_heads_log2_floor) {
|
|
||||||
m_k = powf(m0, k + 1);
|
|
||||||
} else {
|
|
||||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < ne0; i++) {
|
|
||||||
for (int64_t j = 0; j < ne1; j++) {
|
|
||||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
pdst[0] = i * m_k + src[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_compute_forward_alibi_f16(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
|
||||||
|
|
||||||
assert(params->ith == 0);
|
|
||||||
|
|
||||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
||||||
float max_bias;
|
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
||||||
|
|
||||||
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
|
||||||
const int ne1 = src0->ne[1]; // seq_len_without_past
|
|
||||||
const int ne2 = src0->ne[2]; // n_head -> this is k
|
|
||||||
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
|
||||||
|
|
||||||
const int n = ggml_nrows(src0);
|
|
||||||
const int ne2_ne3 = n/ne1; // ne2*ne3
|
|
||||||
|
|
||||||
const int nb0 = src0->nb[0];
|
|
||||||
const int nb1 = src0->nb[1];
|
|
||||||
const int nb2 = src0->nb[2];
|
|
||||||
//const int nb3 = src0->nb[3];
|
|
||||||
|
|
||||||
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
|
||||||
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
|
||||||
GGML_ASSERT(n_head == ne2);
|
|
||||||
|
|
||||||
// add alibi to src0 (KQ_scaled)
|
|
||||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
||||||
|
|
||||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
||||||
|
|
||||||
for (int k = 0; k < ne2_ne3; k++) {
|
|
||||||
// TODO: k*nb2 or k*nb3
|
|
||||||
float m_k;
|
|
||||||
|
|
||||||
if (k < n_heads_log2_floor) {
|
|
||||||
m_k = powf(m0, k + 1);
|
|
||||||
} else {
|
|
||||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < ne0; i++) {
|
|
||||||
for (int j = 0; j < ne1; j++) {
|
|
||||||
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
|
||||||
|
|
||||||
// we return F32
|
|
||||||
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_compute_forward_alibi(
|
|
||||||
const struct ggml_compute_params * params,
|
|
||||||
struct ggml_tensor * dst) {
|
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
|
||||||
|
|
||||||
switch (src0->type) {
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_alibi_f16(params, dst);
|
|
||||||
} break;
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_alibi_f32(params, dst);
|
|
||||||
} break;
|
|
||||||
case GGML_TYPE_BF16:
|
|
||||||
case GGML_TYPE_Q4_0:
|
|
||||||
case GGML_TYPE_Q4_1:
|
|
||||||
case GGML_TYPE_Q5_0:
|
|
||||||
case GGML_TYPE_Q5_1:
|
|
||||||
case GGML_TYPE_Q8_0:
|
|
||||||
case GGML_TYPE_Q8_1:
|
|
||||||
case GGML_TYPE_Q2_K:
|
|
||||||
case GGML_TYPE_Q3_K:
|
|
||||||
case GGML_TYPE_Q4_K:
|
|
||||||
case GGML_TYPE_Q5_K:
|
|
||||||
case GGML_TYPE_Q6_K:
|
|
||||||
case GGML_TYPE_IQ2_XXS:
|
|
||||||
case GGML_TYPE_IQ2_XS:
|
|
||||||
case GGML_TYPE_IQ3_XXS:
|
|
||||||
case GGML_TYPE_IQ1_S:
|
|
||||||
case GGML_TYPE_IQ1_M:
|
|
||||||
case GGML_TYPE_IQ4_NL:
|
|
||||||
case GGML_TYPE_IQ4_XS:
|
|
||||||
case GGML_TYPE_IQ3_S:
|
|
||||||
case GGML_TYPE_IQ2_S:
|
|
||||||
case GGML_TYPE_Q8_K:
|
|
||||||
case GGML_TYPE_I8:
|
|
||||||
case GGML_TYPE_I16:
|
|
||||||
case GGML_TYPE_I32:
|
|
||||||
case GGML_TYPE_I64:
|
|
||||||
case GGML_TYPE_F64:
|
|
||||||
case GGML_TYPE_COUNT:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ggml_compute_forward_clamp
|
// ggml_compute_forward_clamp
|
||||||
|
|
||||||
static void ggml_compute_forward_clamp_f32(
|
static void ggml_compute_forward_clamp_f32(
|
||||||
|
@ -17630,10 +17391,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_rope_back(params, tensor);
|
ggml_compute_forward_rope_back(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
{
|
|
||||||
ggml_compute_forward_alibi(params, tensor);
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_clamp(params, tensor);
|
ggml_compute_forward_clamp(params, tensor);
|
||||||
|
@ -18652,10 +18409,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
|
@ -19428,10 +19181,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
|
||||||
{
|
|
||||||
n_tasks = 1; //TODO
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
{
|
{
|
||||||
n_tasks = 1; //TODO
|
n_tasks = 1; //TODO
|
||||||
|
|
15
ggml.h
15
ggml.h
|
@ -468,7 +468,6 @@ extern "C" {
|
||||||
GGML_OP_SOFT_MAX_BACK,
|
GGML_OP_SOFT_MAX_BACK,
|
||||||
GGML_OP_ROPE,
|
GGML_OP_ROPE,
|
||||||
GGML_OP_ROPE_BACK,
|
GGML_OP_ROPE_BACK,
|
||||||
GGML_OP_ALIBI,
|
|
||||||
GGML_OP_CLAMP,
|
GGML_OP_CLAMP,
|
||||||
GGML_OP_CONV_TRANSPOSE_1D,
|
GGML_OP_CONV_TRANSPOSE_1D,
|
||||||
GGML_OP_IM2COL,
|
GGML_OP_IM2COL,
|
||||||
|
@ -1428,15 +1427,13 @@ extern "C" {
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
// fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
|
// fused soft_max(a*scale + mask*(ALiBi slope))
|
||||||
// mask is optional
|
// mask is optional
|
||||||
// pos is required when max_bias > 0.0f
|
|
||||||
// max_bias = 0.0f for no ALiBi
|
// max_bias = 0.0f for no ALiBi
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
struct ggml_tensor * pos,
|
|
||||||
float scale,
|
float scale,
|
||||||
float max_bias);
|
float max_bias);
|
||||||
|
|
||||||
|
@ -1538,16 +1535,6 @@ extern "C" {
|
||||||
float xpos_base,
|
float xpos_base,
|
||||||
bool xpos_down);
|
bool xpos_down);
|
||||||
|
|
||||||
// alibi position embedding
|
|
||||||
// in-place, returns view(a)
|
|
||||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
int n_past,
|
|
||||||
int n_head,
|
|
||||||
float bias_max),
|
|
||||||
"use ggml_soft_max_ext instead (will be removed in Mar 2024)");
|
|
||||||
|
|
||||||
// clamp
|
// clamp
|
||||||
// in-place, returns view(a)
|
// in-place, returns view(a)
|
||||||
GGML_API struct ggml_tensor * ggml_clamp(
|
GGML_API struct ggml_tensor * ggml_clamp(
|
||||||
|
|
170
llama.cpp
170
llama.cpp
|
@ -1845,7 +1845,7 @@ struct llama_hparams {
|
||||||
float f_logit_scale = 0.0f;
|
float f_logit_scale = 0.0f;
|
||||||
|
|
||||||
bool causal_attn = true;
|
bool causal_attn = true;
|
||||||
bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models
|
bool use_alibi = false;
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||||
|
@ -2317,7 +2317,6 @@ struct llama_context {
|
||||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||||
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
||||||
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
||||||
struct ggml_tensor * inp_KQ_pos; // F32 [n_kv]
|
|
||||||
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
||||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
||||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||||
|
@ -6500,7 +6499,6 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_tensor * wo_b,
|
struct ggml_tensor * wo_b,
|
||||||
struct ggml_tensor * q_cur,
|
struct ggml_tensor * q_cur,
|
||||||
struct ggml_tensor * kq_mask,
|
struct ggml_tensor * kq_mask,
|
||||||
struct ggml_tensor * kq_pos,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
float kq_scale,
|
float kq_scale,
|
||||||
|
@ -6530,10 +6528,6 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
GGML_UNUSED(model);
|
GGML_UNUSED(model);
|
||||||
GGML_UNUSED(n_ctx);
|
GGML_UNUSED(n_ctx);
|
||||||
|
|
||||||
// note: if this assert triggers, then some check has failed earlier
|
|
||||||
// the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
|
|
||||||
GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
|
|
||||||
|
|
||||||
// split cached v into n_head heads (not transposed)
|
// split cached v into n_head heads (not transposed)
|
||||||
struct ggml_tensor * v =
|
struct ggml_tensor * v =
|
||||||
ggml_view_3d(ctx, kv.v_l[il],
|
ggml_view_3d(ctx, kv.v_l[il],
|
||||||
|
@ -6574,28 +6568,8 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
kq = ggml_scale(ctx, kq, 30);
|
kq = ggml_scale(ctx, kq, 30);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_KOMPUTE)
|
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||||
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
|
cb(kq, "kq_soft_max_ext", il);
|
||||||
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
|
|
||||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
|
|
||||||
if (hparams.use_alibi) {
|
|
||||||
kq = ggml_scale(ctx, kq, kq_scale);
|
|
||||||
cb(kq, "kq_scaled", il);
|
|
||||||
|
|
||||||
kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
|
|
||||||
cb(kq, "kq_scaled_alibi", il);
|
|
||||||
|
|
||||||
kq = ggml_add(ctx, kq, kq_mask);
|
|
||||||
cb(kq, "kq_masked", il);
|
|
||||||
|
|
||||||
kq = ggml_soft_max(ctx, kq);
|
|
||||||
cb(kq, "kq_soft_max", il);
|
|
||||||
} else
|
|
||||||
#endif
|
|
||||||
{
|
|
||||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
|
|
||||||
cb(kq, "kq_soft_max_ext", il);
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT(kv.size == n_ctx);
|
GGML_ASSERT(kv.size == n_ctx);
|
||||||
|
|
||||||
|
@ -6645,7 +6619,6 @@ static struct ggml_tensor * llm_build_kv(
|
||||||
struct ggml_tensor * v_cur,
|
struct ggml_tensor * v_cur,
|
||||||
struct ggml_tensor * q_cur,
|
struct ggml_tensor * q_cur,
|
||||||
struct ggml_tensor * kq_mask,
|
struct ggml_tensor * kq_mask,
|
||||||
struct ggml_tensor * kq_pos,
|
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t kv_head,
|
int32_t kv_head,
|
||||||
int32_t n_kv,
|
int32_t n_kv,
|
||||||
|
@ -6664,7 +6637,7 @@ static struct ggml_tensor * llm_build_kv(
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
||||||
q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il);
|
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
|
@ -6771,18 +6744,17 @@ struct llm_build_context {
|
||||||
|
|
||||||
ctx0 = ggml_init(params);
|
ctx0 = ggml_init(params);
|
||||||
|
|
||||||
lctx.inp_tokens = nullptr;
|
lctx.inp_tokens = nullptr;
|
||||||
lctx.inp_embd = nullptr;
|
lctx.inp_embd = nullptr;
|
||||||
lctx.inp_pos = nullptr;
|
lctx.inp_pos = nullptr;
|
||||||
lctx.inp_out_ids = nullptr;
|
lctx.inp_out_ids = nullptr;
|
||||||
lctx.inp_KQ_mask = nullptr;
|
lctx.inp_KQ_mask = nullptr;
|
||||||
lctx.inp_KQ_pos = nullptr;
|
|
||||||
lctx.inp_K_shift = nullptr;
|
lctx.inp_K_shift = nullptr;
|
||||||
lctx.inp_mean = nullptr;
|
lctx.inp_mean = nullptr;
|
||||||
lctx.inp_cls = nullptr;
|
lctx.inp_cls = nullptr;
|
||||||
lctx.inp_s_copy = nullptr;
|
lctx.inp_s_copy = nullptr;
|
||||||
lctx.inp_s_mask = nullptr;
|
lctx.inp_s_mask = nullptr;
|
||||||
lctx.inp_s_seq = nullptr;
|
lctx.inp_s_seq = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void free() {
|
void free() {
|
||||||
|
@ -6932,19 +6904,6 @@ struct llm_build_context {
|
||||||
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
|
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
|
|
||||||
if (causal) {
|
|
||||||
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
|
|
||||||
} else {
|
|
||||||
// TODO: this will be needed for ALiBi-based BERT models
|
|
||||||
// https://github.com/ggerganov/llama.cpp/pull/6826
|
|
||||||
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
|
|
||||||
}
|
|
||||||
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
|
|
||||||
ggml_set_input(lctx.inp_KQ_pos);
|
|
||||||
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * build_inp_mean() {
|
struct ggml_tensor * build_inp_mean() {
|
||||||
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
||||||
cb(lctx.inp_mean, "inp_mean", -1);
|
cb(lctx.inp_mean, "inp_mean", -1);
|
||||||
|
@ -7050,7 +7009,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7143,9 +7102,6 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
// positions of the tokens in the KV cache
|
|
||||||
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
|
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
@ -7190,7 +7146,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7260,9 +7216,6 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
// positions of the tokens in the KV cache
|
|
||||||
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
|
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
@ -7297,7 +7250,7 @@ struct llm_build_context {
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7417,7 +7370,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7542,7 +7495,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7694,7 +7647,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -7806,7 +7759,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8010,7 +7963,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8076,9 +8029,6 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
// positions of the tokens in the KV cache
|
|
||||||
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
|
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
@ -8106,7 +8056,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8246,7 +8196,7 @@ struct llm_build_context {
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
||||||
cb(kq, "kq_soft_max_ext", il);
|
cb(kq, "kq_soft_max_ext", il);
|
||||||
|
|
||||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
||||||
|
@ -8363,9 +8313,6 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
// positions of the tokens in the KV cache
|
|
||||||
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
|
|
||||||
|
|
||||||
inpL = llm_build_norm(ctx0, inpL, hparams,
|
inpL = llm_build_norm(ctx0, inpL, hparams,
|
||||||
model.tok_norm,
|
model.tok_norm,
|
||||||
model.tok_norm_b,
|
model.tok_norm_b,
|
||||||
|
@ -8399,7 +8346,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8464,9 +8411,6 @@ struct llm_build_context {
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
// positions of the tokens in the KV cache
|
|
||||||
struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
|
|
||||||
|
|
||||||
if (model.pos_embd) {
|
if (model.pos_embd) {
|
||||||
// inp_pos - contains the positions
|
// inp_pos - contains the positions
|
||||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
@ -8530,13 +8474,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
} else {
|
} else {
|
||||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8680,7 +8624,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8798,7 +8742,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -8911,7 +8855,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9025,7 +8969,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9180,7 +9124,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9297,7 +9241,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9410,7 +9354,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
struct ggml_tensor * sa_out = cur;
|
struct ggml_tensor * sa_out = cur;
|
||||||
|
|
||||||
|
@ -9513,7 +9457,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9620,7 +9564,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9736,7 +9680,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9853,7 +9797,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -9983,7 +9927,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -10104,7 +10048,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, NULL,
|
model.layers[il].wo, NULL,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -10223,7 +10167,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -10513,7 +10457,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, model.layers[il].bo,
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -10644,7 +10588,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
model.layers[il].wo, nullptr,
|
model.layers[il].wo, nullptr,
|
||||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1) {
|
if (il == n_layer - 1) {
|
||||||
|
@ -11032,7 +10976,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
f = 0.0f;
|
if (hparams.use_alibi) {
|
||||||
|
f = -fabs(lctx.kv_self.cells[i].pos - pos);
|
||||||
|
} else {
|
||||||
|
f = 0.0f;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||||
}
|
}
|
||||||
|
@ -11055,7 +11003,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
float f = -INFINITY;
|
float f = -INFINITY;
|
||||||
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||||
if (batch.seq_id[i][s] == seq_id) {
|
if (batch.seq_id[i][s] == seq_id) {
|
||||||
f = 0.0f;
|
if (hparams.use_alibi) {
|
||||||
|
f = -fabs(batch.pos[i] - batch.pos[j]);
|
||||||
|
} else {
|
||||||
|
f = 0.0f;
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11071,21 +11023,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
|
|
||||||
// this allows to process multiple sequences in parallel with ALiBi-based models
|
|
||||||
if (hparams.use_alibi) {
|
|
||||||
const int64_t n_kv = kv_self.n;
|
|
||||||
|
|
||||||
GGML_ASSERT(lctx.inp_KQ_pos);
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
|
|
||||||
|
|
||||||
float * data = (float *) lctx.inp_KQ_pos->data;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
|
||||||
data[i] = float(lctx.kv_self.cells[i].pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||||
const int64_t n_tokens = batch.n_tokens;
|
const int64_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
|
@ -15509,11 +15446,6 @@ struct llama_context * llama_new_context_with_model(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cparams.flash_attn && hparams.use_alibi) {
|
|
||||||
LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
|
|
||||||
cparams.flash_attn = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
|
if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
|
||||||
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
|
||||||
cparams.flash_attn = false;
|
cparams.flash_attn = false;
|
||||||
|
|
|
@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case {
|
||||||
if (this->mask) {
|
if (this->mask) {
|
||||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
||||||
}
|
}
|
||||||
ggml_tensor * pos = nullptr;
|
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
|
||||||
if (max_bias > 0.0f) {
|
|
||||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
|
||||||
}
|
|
||||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1611,7 +1607,7 @@ public:
|
||||||
|
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
|
|
||||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
|
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
|
||||||
|
|
||||||
// split cached v into n_head heads
|
// split cached v into n_head heads
|
||||||
struct ggml_tensor * v =
|
struct ggml_tensor * v =
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue