Compare commits

...
Sign in to create a new pull request.

4 commits

Author SHA1 Message Date
Georgi Gerganov
608f449880
swift : fix build
ggml-ci
2024-02-23 19:02:09 +02:00
Georgi Gerganov
fff1e8a54a
batched.swift : fix build
ggml-ci
2024-02-23 16:15:37 +02:00
Georgi Gerganov
8772658b11
ggml : add I32 <-> F32 conversion
ggml-ci
2024-02-23 14:25:05 +02:00
Georgi Gerganov
fc775366f1
llama : switch to floating-point token positions
ggml-ci
2024-02-23 12:34:16 +02:00
18 changed files with 130 additions and 106 deletions

View file

@ -1015,9 +1015,9 @@ static struct ggml_tensor * forward_lora(
struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v; struct ggml_tensor * vc = kv_self.v;
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N);
{ {
int * data = (int *) KQ_pos->data; float * data = (float *) KQ_pos->data;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
data[i] = n_past + i; data[i] = n_past + i;
} }

View file

@ -79,7 +79,7 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() { for (i, token) in tokens.enumerated() {
batch.token[i] = token batch.token[i] = token
batch.pos[i] = Int32(i) batch.pos[i] = llama_pos(i)
batch.n_seq_id[i] = 1 batch.n_seq_id[i] = 1
// batch.seq_id[i][0] = 0 // batch.seq_id[i][0] = 0
// TODO: is this the proper way to do this? // TODO: is this the proper way to do this?
@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
} }
for i in 1 ..< n_parallel { for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) llama_kv_cache_seq_cp(context, 0, Int32(i), 0, llama_pos(batch.n_tokens))
} }
if n_parallel > 1 { if n_parallel > 1 {
@ -125,8 +125,8 @@ while n_cur <= n_len {
continue continue
} }
var n_vocab = llama_n_vocab(model) let n_vocab = llama_n_vocab(model)
var logits = llama_get_logits_ith(context, i_batch[i]) let logits = llama_get_logits_ith(context, i_batch[i])
var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
@ -173,7 +173,7 @@ while n_cur <= n_len {
// push this new token for next evaluation // push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur batch.pos[Int(batch.n_tokens)] = llama_pos(n_cur)
batch.n_seq_id[Int(batch.n_tokens)] = 1 batch.n_seq_id[Int(batch.n_tokens)] = 1
if let seq_id = batch.seq_id[Int(batch.n_tokens)] { if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i) seq_id[0] = Int32(i)

View file

@ -554,7 +554,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
}; };
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
ggml_set_input(KQ_pos); ggml_set_input(KQ_pos);
// rope has so much parameters that we make a custom function for it // rope has so much parameters that we make a custom function for it
@ -743,7 +743,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// set KQ_pos // set KQ_pos
{ {
int * data = (int *) KQ_pos->data; float * data = (float *) KQ_pos->data;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
data[i] = n_past + i; data[i] = n_past + i;
} }

View file

@ -129,7 +129,7 @@ actor LlamaContext {
for i1 in 0..<tokens_list.count { for i1 in 0..<tokens_list.count {
let i = Int(i1) let i = Int(i1)
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false) llama_batch_add(&batch, tokens_list[i], llama_pos(i), [0], false)
} }
batch.logits[Int(batch.n_tokens) - 1] = 1 // true batch.logits[Int(batch.n_tokens) - 1] = 1 // true
@ -183,7 +183,7 @@ actor LlamaContext {
// tokens_list.append(new_token_id) // tokens_list.append(new_token_id)
llama_batch_clear(&batch) llama_batch_clear(&batch)
llama_batch_add(&batch, new_token_id, n_cur, [0], true) llama_batch_add(&batch, new_token_id, llama_pos(n_cur), [0], true)
n_decode += 1 n_decode += 1
n_cur += 1 n_cur += 1
@ -210,7 +210,7 @@ actor LlamaContext {
let n_tokens = pp let n_tokens = pp
for i in 0..<n_tokens { for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false) llama_batch_add(&batch, 0, llama_pos(i), [0], false)
} }
batch.logits[Int(batch.n_tokens) - 1] = 1 // true batch.logits[Int(batch.n_tokens) - 1] = 1 // true
@ -234,7 +234,7 @@ actor LlamaContext {
llama_batch_clear(&batch) llama_batch_clear(&batch)
for j in 0..<pl { for j in 0..<pl {
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true) llama_batch_add(&batch, 0, llama_pos(i), [Int32(j)], true)
} }
if llama_decode(context, batch) != 0 { if llama_decode(context, batch) != 0 {

View file

@ -338,7 +338,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, (float) *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) { if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;

View file

@ -1281,7 +1281,7 @@ struct llama_server_context
} }
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, (float) slot.n_past, 1, 0, };
if (llama_decode(ctx, batch_img)) if (llama_decode(ctx, batch_img))
{ {
LOG_TEE("%s : failed to eval image\n", __func__); LOG_TEE("%s : failed to eval image\n", __func__);

View file

@ -291,7 +291,7 @@ static struct ggml_tensor * llama_build_train_graphs(
}; };
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
ggml_set_input(KQ_pos); ggml_set_input(KQ_pos);
// rope has so much parameters that we make a custom function for it // rope has so much parameters that we make a custom function for it
@ -419,7 +419,7 @@ static struct ggml_tensor * llama_build_train_graphs(
ggml_gallocr_alloc_graph(alloc, gb); ggml_gallocr_alloc_graph(alloc, gb);
if (!measure_only) { if (!measure_only) {
int * data = (int *) KQ_pos->data; float * data = (float *) KQ_pos->data;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
data[i] = n_past + i; data[i] = n_past + i;
} }

View file

@ -6040,7 +6040,7 @@ static __device__ void rope_yarn(
// rope == RoPE == rotary positional embedding // rope == RoPE == rotary positional embedding
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope( static __global__ void rope(
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims float ext_factor, float attn_factor, rope_corr_dims corr_dims
) { ) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -6053,7 +6053,7 @@ static __global__ void rope(
const int i = row*ncols + col; const int i = row*ncols + col;
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0; const float p = has_pos ? pos[i2] : 0.0f;
const float theta_base = p*powf(freq_base, -float(col)/ncols); const float theta_base = p*powf(freq_base, -float(col)/ncols);
float cos_theta, sin_theta; float cos_theta, sin_theta;
@ -6068,7 +6068,7 @@ static __global__ void rope(
template<typename T, bool has_pos> template<typename T, bool has_pos>
static __global__ void rope_neox( static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
) { ) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -6095,7 +6095,7 @@ static __global__ void rope_neox(
float cur_rot = inv_ndims * ic - ib; float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0; const float p = has_pos ? pos[i2] : 0.0f;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
float cos_theta, sin_theta; float cos_theta, sin_theta;
@ -6109,7 +6109,7 @@ static __global__ void rope_neox(
} }
static __global__ void rope_glm_f32( static __global__ void rope_glm_f32(
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
int n_ctx int n_ctx
) { ) {
const int col = blockDim.x*blockIdx.x + threadIdx.x; const int col = blockDim.x*blockIdx.x + threadIdx.x;
@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32(
const int i2 = row/p_delta_rows; const int i2 = row/p_delta_rows;
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
// FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0;
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; const float p = pos != nullptr ? pos[i2] : 0.0f;
const float theta = min(p, (float) n_ctx - 2)*freq_scale*col_theta_scale;
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta);
@ -6137,7 +6137,7 @@ static __global__ void rope_glm_f32(
dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; const float block_theta = max(p - n_ctx - 2, 0.0f)*col_theta_scale;
const float sin_block_theta = sinf(block_theta); const float sin_block_theta = sinf(block_theta);
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
template<typename T> template<typename T>
static void rope_cuda( static void rope_cuda(
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) { ) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
@ -7708,7 +7708,7 @@ static void rope_cuda(
template<typename T> template<typename T>
static void rope_neox_cuda( static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) { ) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
@ -7733,7 +7733,7 @@ static void rope_neox_cuda(
} }
static void rope_glm_f32_cuda( static void rope_glm_f32_cuda(
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
float freq_base, int n_ctx, cudaStream_t stream float freq_base, int n_ctx, cudaStream_t stream
) { ) {
GGML_ASSERT(ncols % 4 == 0); GGML_ASSERT(ncols % 4 == 0);
@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope(
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const int32_t * pos = nullptr; const float * pos = nullptr;
if ((mode & 1) == 0) { if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->ne[0] == ne2); GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_dd; pos = (const float *) src1_dd;
} }
const bool is_neox = mode & 2; const bool is_neox = mode & 2;

View file

@ -2057,7 +2057,13 @@ static bool ggml_metal_graph_compute(
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));

View file

@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(
typedef void (rope_t)( typedef void (rope_t)(
device const void * src0, device const void * src0,
device const int32_t * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
@ -1709,7 +1709,7 @@ typedef void (rope_t)(
template<typename T> template<typename T>
kernel void kernel_rope( kernel void kernel_rope(
device const void * src0, device const void * src0,
device const int32_t * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
@ -1749,11 +1749,11 @@ kernel void kernel_rope(
float corr_dims[2]; float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
device const int32_t * pos = src1; device const float * pos = src1;
const int64_t p = pos[i2]; const float p = pos[i2];
const float theta_0 = (float)p; const float theta_0 = p;
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/n_dims;
if (!is_neox) { if (!is_neox) {

57
ggml.c
View file

@ -355,6 +355,18 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
} }
} }
static void ggml_i32_to_f32_row(const int32_t * x, float * y, int n) {
for (int i = 0; i < n; i++) {
y[i] = (float) x[i];
}
}
static void ggml_f32_to_i32_row(const float * x, int32_t * y, int n) {
for (int i = 0; i < n; i++) {
y[i] = (int32_t) x[i];
}
}
// //
// timing // timing
// //
@ -454,6 +466,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.blck_size = 1, .blck_size = 1,
.type_size = sizeof(int32_t), .type_size = sizeof(int32_t),
.is_quantized = false, .is_quantized = false,
.to_float = (ggml_to_float_t) ggml_i32_to_f32_row,
.from_float = (ggml_from_float_t) ggml_f32_to_i32_row,
.from_float_reference = (ggml_from_float_t) ggml_f32_to_i32_row,
}, },
[GGML_TYPE_F32] = { [GGML_TYPE_F32] = {
.type_name = "f32", .type_name = "f32",
@ -482,7 +497,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q4_0), .type_size = sizeof(block_q4_0),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_0, .to_float = (ggml_to_float_t) dequantize_row_q4_0,
.from_float = quantize_row_q4_0, .from_float = (ggml_from_float_t) quantize_row_q4_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
.vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot = ggml_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
@ -498,7 +513,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q4_1), .type_size = sizeof(block_q4_1),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_1, .to_float = (ggml_to_float_t) dequantize_row_q4_1,
.from_float = quantize_row_q4_1, .from_float = (ggml_from_float_t) quantize_row_q4_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
.vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
@ -538,7 +553,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q5_0), .type_size = sizeof(block_q5_0),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_0, .to_float = (ggml_to_float_t) dequantize_row_q5_0,
.from_float = quantize_row_q5_0, .from_float = (ggml_from_float_t) quantize_row_q5_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
.vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot = ggml_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
@ -550,7 +565,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q5_1), .type_size = sizeof(block_q5_1),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_1, .to_float = (ggml_to_float_t) dequantize_row_q5_1,
.from_float = quantize_row_q5_1, .from_float = (ggml_from_float_t) quantize_row_q5_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
.vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
@ -562,7 +577,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q8_0), .type_size = sizeof(block_q8_0),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q8_0, .to_float = (ggml_to_float_t) dequantize_row_q8_0,
.from_float = quantize_row_q8_0, .from_float = (ggml_from_float_t) quantize_row_q8_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
@ -577,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.blck_size = QK8_1, .blck_size = QK8_1,
.type_size = sizeof(block_q8_1), .type_size = sizeof(block_q8_1),
.is_quantized = true, .is_quantized = true,
.from_float = quantize_row_q8_1, .from_float = (ggml_from_float_t) quantize_row_q8_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1, .nrows = 1,
@ -588,7 +603,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q2_K), .type_size = sizeof(block_q2_K),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q2_K, .to_float = (ggml_to_float_t) dequantize_row_q2_K,
.from_float = quantize_row_q2_K, .from_float = (ggml_from_float_t) quantize_row_q2_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
.vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
@ -600,7 +615,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q3_K), .type_size = sizeof(block_q3_K),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q3_K, .to_float = (ggml_to_float_t) dequantize_row_q3_K,
.from_float = quantize_row_q3_K, .from_float = (ggml_from_float_t) quantize_row_q3_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
.vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
@ -612,7 +627,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q4_K), .type_size = sizeof(block_q4_K),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_K, .to_float = (ggml_to_float_t) dequantize_row_q4_K,
.from_float = quantize_row_q4_K, .from_float = (ggml_from_float_t) quantize_row_q4_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
.vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
@ -624,7 +639,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q5_K), .type_size = sizeof(block_q5_K),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q5_K, .to_float = (ggml_to_float_t) dequantize_row_q5_K,
.from_float = quantize_row_q5_K, .from_float = (ggml_from_float_t) quantize_row_q5_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
.vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
@ -636,7 +651,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_q6_K), .type_size = sizeof(block_q6_K),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q6_K, .to_float = (ggml_to_float_t) dequantize_row_q6_K,
.from_float = quantize_row_q6_K, .from_float = (ggml_from_float_t) quantize_row_q6_K,
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
.vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
@ -672,8 +687,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_iq3_xxs), .type_size = sizeof(block_iq3_xxs),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
.from_float = quantize_row_iq3_xxs, .from_float = (ggml_from_float_t) quantize_row_iq3_xxs,
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, .from_float_reference = (ggml_from_float_t) quantize_row_iq3_xxs_reference,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
@ -696,8 +711,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_iq4_nl), .type_size = sizeof(block_iq4_nl),
.is_quantized = true, .is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq4_nl, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
.from_float = quantize_row_iq4_nl, .from_float = (ggml_from_float_t) quantize_row_iq4_nl,
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, .from_float_reference = (ggml_from_float_t) quantize_row_iq4_nl_reference,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1, .nrows = 1,
@ -5254,7 +5269,7 @@ static struct ggml_tensor * ggml_rope_impl(
bool xpos_down, bool xpos_down,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(a->ne[2] == b->ne[0]); GGML_ASSERT(a->ne[2] == b->ne[0]);
bool is_node = false; bool is_node = false;
@ -5377,7 +5392,7 @@ struct ggml_tensor * ggml_rope_back(
float xpos_base, float xpos_base,
bool xpos_down) { bool xpos_down) {
GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(a->ne[2] == b->ne[0]); GGML_ASSERT(a->ne[2] == b->ne[0]);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
@ -12352,11 +12367,11 @@ static void ggml_compute_forward_rope_f32(
// this essentially just switches the sign of sin. // this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f; const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data; const float * pos = (const float *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2]; const float p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
@ -12523,11 +12538,11 @@ static void ggml_compute_forward_rope_f16(
// this essentially just switches the sign of sin. // this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f; const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data; const float * pos = (const float *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2]; const float p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox

View file

@ -1699,8 +1699,8 @@ struct llama_layer {
}; };
struct llama_kv_cell { struct llama_kv_cell {
llama_pos pos = -1; float pos = -1.0f;
llama_pos delta = 0; float delta = 0.0f;
std::set<llama_seq_id> seq_id; std::set<llama_seq_id> seq_id;
@ -1939,10 +1939,10 @@ struct llama_context {
ggml_context * ctx_input = nullptr; ggml_context * ctx_input = nullptr;
struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_pos; // F32 [n_batch]
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx] struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
struct ggml_tensor * inp_K_shift; // I32 [n_ctx] struct ggml_tensor * inp_K_shift; // F32 [n_ctx]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch]
@ -2222,7 +2222,7 @@ static void llama_kv_cache_seq_div(
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { float d) {
if (p0 < 0) p0 = 0; if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
@ -5928,6 +5928,7 @@ struct llm_build_context {
// get input vectors with right size // get input vectors with right size
const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0); struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
@ -5938,8 +5939,9 @@ struct llm_build_context {
// token types are hardcoded to zero ("Sentence A") // token types are hardcoded to zero ("Sentence A")
struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
inpL = ggml_add(ctx0, inpL, type_row0); inpL = ggml_add(ctx0, inpL, type_row0);
if (model.arch == LLM_ARCH_BERT) { if (model.arch == LLM_ARCH_BERT) {
inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, ggml_cast(ctx0, inp_pos, GGML_TYPE_I32)), inpL);
} }
cb(inpL, "inp_embd", -1); cb(inpL, "inp_embd", -1);
@ -7744,7 +7746,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data; float * data = (float *) lctx.inp_K_shift->data;
for (int i = 0; i < n_ctx; ++i) { for (int i = 0; i < n_ctx; ++i) {
data[i] = lctx.kv_self.cells[i].delta; data[i] = lctx.kv_self.cells[i].delta;
@ -11690,10 +11692,10 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx); ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
@ -12046,7 +12048,7 @@ void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, l
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
} }
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, float d) {
if (d == 1) { if (d == 1) {
return; return;
} }
@ -12461,7 +12463,7 @@ int llama_eval_embd(
int32_t n_past) { int32_t n_past) {
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, (float) n_past, 1, 0, };
const int ret = llama_decode_internal(*ctx, batch); const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) { if (ret < 0) {

View file

@ -54,7 +54,7 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
typedef int32_t llama_pos; typedef float llama_pos;
typedef int32_t llama_token; typedef int32_t llama_token;
typedef int32_t llama_seq_id; typedef int32_t llama_seq_id;
@ -531,7 +531,7 @@ extern "C" {
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d); float d);
// //
// State / sessions // State / sessions

View file

@ -1134,14 +1134,15 @@ struct test_rope : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]);
ggml_set_name(pos, "pos");
ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx); ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
return out; return out;
} }
void initialize_tensors(ggml_context * ctx) override { void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) { if (strcmp(ggml_get_name(t), "pos") == 0) {
// pos // pos
std::vector<int> data(ne[2]); std::vector<int> data(ne[2]);
for (int i = 0; i < ne[2]; i++) { for (int i = 0; i < ne[2]; i++) {
@ -1703,7 +1704,7 @@ struct test_llama : public test_llm {
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens); inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
// inp_pos - contains the positions // inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
@ -1825,7 +1826,7 @@ struct test_falcon : public test_llm {
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens); inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
// inp_pos - contains the positions // inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);

View file

@ -1449,9 +1449,9 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) { for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) { for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i; ((float *) p->data)[i] = n_past + i;
} }
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -1489,9 +1489,9 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) { for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) { for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i; ((float *) p->data)[i] = n_past + i;
} }
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);

View file

@ -143,10 +143,10 @@ int main(int argc, char * argv[]) {
continue; continue;
} }
if (qfns.from_float && qfns.to_float && qfns.vec_dot) {
printf("Testing %s\n", ggml_type_name((ggml_type) i)); printf("Testing %s\n", ggml_type_name((ggml_type) i));
ggml_quantize_init(ei); ggml_quantize_init(ei);
if (qfns.from_float && qfns.to_float) {
const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float total_error = total_quantization_error(qfns, test_size, test_data.data());
const float max_quantization_error = const float max_quantization_error =
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :

View file

@ -275,7 +275,7 @@ int main(int argc, char * argv[]) {
continue; continue;
} }
if (qfns.from_float && qfns.to_float) { if (qfns.from_float && qfns.to_float && qfns.vec_dot) {
printf("%s\n", ggml_type_name(type)); printf("%s\n", ggml_type_name(type));
ggml_quantize_init(type); ggml_quantize_init(type);

View file

@ -146,14 +146,14 @@ int main(int /*argc*/, const char ** /*argv*/) {
const int n_past_0 = 100; const int n_past_0 = 100;
const int n_past_2 = 33; const int n_past_2 = 33;
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]);
for (int i = 0; i < ne[2]; ++i) { for (int i = 0; i < ne[2]; ++i) {
((int32_t *) p0->data)[i] = n_past_0 + i; ((float *) p0->data)[i] = n_past_0 + i;
((int32_t *) p1->data)[i] = n_past_2 - n_past_0; ((float *) p1->data)[i] = n_past_2 - n_past_0;
((int32_t *) p2->data)[i] = n_past_2 + i; ((float *) p2->data)[i] = n_past_2 + i;
} }
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM) // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)