llama : switch to floating-point token positions
ggml-ci
This commit is contained in:
parent
15499eb942
commit
fc775366f1
14 changed files with 68 additions and 61 deletions
12
ggml.c
12
ggml.c
|
@@ -5254,7 +5254,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
bool is_node = false;
|
||||
|
@@ -5377,7 +5377,7 @@ struct ggml_tensor * ggml_rope_back(
|
|||
float xpos_base,
|
||||
bool xpos_down) {
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
|
||||
|
@@ -12352,11 +12352,11 @@ static void ggml_compute_forward_rope_f32(
|
|||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * pos = (const float *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
|
@@ -12523,11 +12523,11 @@ static void ggml_compute_forward_rope_f16(
|
|||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * pos = (const float *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
||||
if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue