Compare commits

5 commits

Author           SHA1        Message                                                          Date
Georgi Gerganov  f3a84b2e0d  llama : better express the KV cache dependencies in the graph   2023-09-04 21:44:48 +03:00
Georgi Gerganov  60c2ef6d92  metal : utilize view_src to see if tensor is a view              2023-09-04 20:49:09 +03:00
Georgi Gerganov  ebd3467cc8  metal : more readable kernel                                     2023-09-04 20:48:46 +03:00
Georgi Gerganov  7704db2521  ggml : just in case                                              2023-09-04 20:48:25 +03:00
Georgi Gerganov  ad80e5a4a7  llama : add ggml_cont to trigger bug with Metal                  2023-09-04 19:50:34 +03:00
4 changed files with 47 additions and 40 deletions

ggml-metal.m

@@ -541,10 +541,7 @@ void ggml_metal_graph_find_concurrency(
         int64_t data_start = (int64_t) gf->nodes[i]->data;
         int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
         for (int j = n_start; j < i; j++) {
-            if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
-                                && gf->nodes[j]->op != GGML_OP_VIEW \
-                                && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
-                                && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+            if (nodes_unused[j] && gf->nodes[j]->view_src == NULL) {
                 if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                     ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                     continue;
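
Aside: the rewritten condition works because every view-producing op (reshape, view, transpose, permute) records the tensor whose buffer it aliases in view_src, so a single NULL check is enough to tell that a node owns its own allocation. A minimal sketch of the invariant the new check relies on, assuming the ggml_tensor layout at this revision:

// Sketch only: the view_src invariant used by the new check.
// Assumes the ggml_tensor layout at this revision of ggml.h.
#include "ggml.h"
#include <stdbool.h>

static bool node_owns_its_buffer(const struct ggml_tensor * t) {
    // reshape/view/transpose/permute results carry view_src != NULL,
    // so this single test replaces the four GGML_OP_* comparisons above
    return t->view_src == NULL;
}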

ggml-metal.metal

@@ -783,11 +783,11 @@ kernel void kernel_cpy_f16_f16(
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const half * src = (device half *) ((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
+
+        device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i00*nb0);
+        *dst_data = *src;
     }
 }
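
The dst_data pointer now lives inside the loop and is advanced by i00*nb0 per element, so each write respects the destination's byte stride instead of assuming densely packed halves. A rough C analogue of the new addressing (hypothetical helper, illustrative parameters only):

// Rough C analogue of the rewritten inner loop (not the Metal source).
// nb00/nb0 are byte strides as in ggml; uint16_t stands in for half.
#include <stdint.h>
#include <string.h>

static void cpy_f16_row(const char * src_row, char * dst_row,
                        int64_t ne00, int64_t nb00, int64_t nb0) {
    for (int64_t i00 = 0; i00 < ne00; ++i00) {
        const char * src = src_row + i00*nb00;   // source element i00
        char       * dst = dst_row + i00*nb0;    // per-element byte offset on the destination
        memcpy(dst, src, sizeof(uint16_t));      // copy one f16 value
    }
}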

ggml.c

@@ -4285,7 +4285,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }

 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-    size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
+    size_t nbytes = (tensor->ne[0]*tensor->nb[0])/ggml_blck_size(tensor->type);
     for (int i = 1; i < GGML_MAX_DIMS; ++i) {
         nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
     }
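
The added parentheses only make the evaluation order explicit; the value is unchanged. As a sanity check, a worked example with made-up strides (an F32 view with padded rows, block size 1):

// Worked example of the ggml_nbytes formula with hypothetical values:
// ne = {4, 3, 1, 1} elements, nb = {4, 32, 96, 96} byte strides, block size 1 (F32).
#include <stddef.h>

static size_t nbytes_example(void) {
    const long   ne[4] = {4, 3, 1, 1};
    const size_t nb[4] = {4, 32, 96, 96};
    size_t nbytes = (ne[0]*nb[0])/1;     // first row spans 4*4 = 16 bytes
    for (int i = 1; i < 4; ++i) {
        nbytes += (ne[i] - 1)*nb[i];     // + (3-1)*32 = 64 bytes to reach the last row
    }
    return nbytes;                       // 80 bytes, even though only 48 bytes hold data
}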
@@ -5213,6 +5213,8 @@ struct ggml_tensor * ggml_view_tensor(
         result->nb[i] = src->nb[i];
     }

+    result->op = GGML_OP_VIEW;
+
     return result;
 }

llama.cpp

@@ -2347,6 +2347,10 @@ static struct ggml_cgraph * llm_build_llama(
                 offload_func_kq(tmpq);
                 ggml_set_name (tmpq, "tmpq");

+                struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                offload_func_v(tmpv);
+                ggml_set_name (tmpv, "tmpv");
+
                 struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
                 offload_func_kq(Kcur);
                 ggml_set_name (Kcur, "Kcur");
@@ -2355,31 +2359,35 @@ static struct ggml_cgraph * llm_build_llama(
                 offload_func_kq(Qcur);
                 ggml_set_name (Qcur, "Qcur");

-                // store key and value to memory
-                {
-                    // compute the transposed [N, n_embd] V matrix
-
-                    struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                    offload_func_v(tmpv);
-                    ggml_set_name(tmpv, "tmpv");
-
-                    struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
-                    offload_func_v(Vcur);
-                    ggml_set_name (Vcur, "Vcur");
+                // compute the transposed [N, n_embd] V matrix
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                offload_func_v(Vcur);
+                ggml_set_name (Vcur, "Vcur");
+
+                struct ggml_tensor * k;
+                struct ggml_tensor * v;

-                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
-                    offload_func_kq(k);
-                    ggml_set_name(k, "k");
+                // store key and value to memory
+                {
+                    struct ggml_tensor * k_view = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                    offload_func_kq(k_view);
+                    ggml_set_name (k_view, "k_view");

-                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                    struct ggml_tensor * v_view = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
                             (   n_ctx)*ggml_element_size(kv_self.v),
                             (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
-                    offload_func_v(v);
-                    ggml_set_name(v, "v");
+                    offload_func_v(v_view);
+                    ggml_set_name (v_view, "v_view");

                     // important: storing RoPE-ed version of K in the KV cache!
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                    struct ggml_tensor * k_cpy = ggml_cpy(ctx0, Kcur, k_view);
+                    struct ggml_tensor * v_cpy = ggml_cpy(ctx0, Vcur, v_view);
+
+                    // TODO: replace with ggml_dependency / ggml_depends_on
+                    k = ggml_view_tensor(ctx0, kv_self.k);
+                    v = ggml_view_tensor(ctx0, kv_self.v);
+
+                    k->src[0] = k_cpy;
+                    v->src[0] = v_cpy;
                 }

                 struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
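
In other words, the copies into the KV cache are no longer forced into the graph with ggml_build_forward_expand. Instead, k and v are full views of kv_self.k / kv_self.v whose src[0] points at the copy nodes, so every later op that reads the cache through k or v transitively depends on the writes. A minimal sketch of the pattern (hypothetical helper; it mirrors the code above rather than replacing it):

// Minimal sketch of the dependency trick, with hypothetical tensor names.
#include "ggml.h"

static struct ggml_tensor * cache_after_write(struct ggml_context * ctx,
                                              struct ggml_tensor  * cache,       // e.g. kv_self.k
                                              struct ggml_tensor  * cache_slot,  // view of the slot being written
                                              struct ggml_tensor  * new_data) {  // e.g. Kcur
    // the node that writes new_data into the cache slot
    struct ggml_tensor * cpy = ggml_cpy(ctx, new_data, cache_slot);

    // full-tensor view whose first source is the copy node; anything built on
    // top of this view now depends on (and is ordered after) the write
    struct ggml_tensor * cache_dep = ggml_view_tensor(ctx, cache);
    cache_dep->src[0] = cpy;

    return cache_dep; // read the cache through this tensor downstream
}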
@@ -2387,11 +2395,11 @@ static struct ggml_cgraph * llm_build_llama(
                 ggml_set_name(Q, "Q");

                 struct ggml_tensor * K =
-                        ggml_view_3d(ctx0, kv_self.k,
+                        ggml_view_3d(ctx0, k,
                                 n_embd_head, n_past + N, n_head_kv,
-                                ggml_element_size(kv_self.k)*n_embd_gqa,
-                                ggml_element_size(kv_self.k)*n_embd_head,
-                                ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+                                ggml_element_size(k)*n_embd_gqa,
+                                ggml_element_size(k)*n_embd_head,
+                                ggml_element_size(k)*n_embd_gqa*n_ctx*il);
                 offload_func_kq(K);
                 ggml_set_name(K, "K");
@@ -2418,11 +2426,11 @@ static struct ggml_cgraph * llm_build_llama(
                 // split cached V into n_head heads
                 struct ggml_tensor * V =
-                        ggml_view_3d(ctx0, kv_self.v,
+                        ggml_view_3d(ctx0, v,
                                 n_past + N, n_embd_head, n_head_kv,
-                                ggml_element_size(kv_self.v)*n_ctx,
-                                ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
-                                ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+                                ggml_element_size(v)*n_ctx,
+                                ggml_element_size(v)*n_ctx*n_embd_head,
+                                ggml_element_size(v)*n_ctx*n_embd_gqa*il);
                 offload_func_v(V);
                 ggml_set_name(V, "V");
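
Because K and V are now views of k and v rather than of kv_self.k / kv_self.v directly, a backward walk of the graph from the attention output reaches the copy nodes, which is what keeps the cache writes ordered before the reads. Conceptually, per layer:

// Conceptual dependency chain (not code from the tree):
//   Kcur --ggml_cpy--> k_cpy --src[0]--> k --ggml_view_3d--> K --> KQ  --> ...
//   Vcur --ggml_cpy--> v_cpy --src[0]--> v --ggml_view_3d--> V --> KQV
// Walking back from KQV visits k_cpy / v_cpy, so the scheduler cannot
// reorder the attention reads before the KV cache writes.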
@@ -2434,7 +2442,7 @@ static struct ggml_cgraph * llm_build_llama(
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
             // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, v->type, n_past + N, n_embd_head, n_head));

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif