ggml: move op parameters from tensors to ggml_tensor::op_params (#2333)
* ggml: move op parameters from tensors to ggml_tensor::op_params * alibi: use memcpy for float params * remove `src[1] = NULL` in ops
This commit is contained in:
parent
e76d630df1
commit
95a6c595e7
4 changed files with 226 additions and 486 deletions
20
ggml-metal.m
20
ggml-metal.m
|
@ -585,7 +585,7 @@ void ggml_metal_graph_compute(
|
|||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *)(src1->data))[0];
|
||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
@ -850,9 +850,10 @@ void ggml_metal_graph_compute(
|
|||
|
||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
|
||||
const int n_head = ((int32_t *) src1->data)[1];
|
||||
const float max_bias = ((float *) src1->data)[2];
|
||||
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (__builtin_popcount(n_head) != 1) {
|
||||
GGML_ASSERT(false && "only power-of-two n_head implemented");
|
||||
|
@ -890,15 +891,14 @@ void ggml_metal_graph_compute(
|
|||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
|
||||
const int n_past = ((int32_t *)(src1->data))[0];
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rope];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue