llama : Metal inference (#1642)

* mtl : export the LLaMA computation graph

* ci : disable temporary

* mtl : adapt the MNIST example as starter

* mtl : no need for mtl-export tool, add cli arg for main instead

* mtl : export just a small part of the graph for now to make it easier

* mtl : move MSL code into separate file for easy editing

* mtl : initial get_rows_q4_0 kernel

* mtl : confirmed get_rows_q4_0 is working correctly

* mtl : add rms_norm kernel + confirm working

* mtl : add mul kernel + confirm working

* mtl : initial mul_mat Q4 kernel (wrong results)

* mtl : mul_mat fixes (still wrong)

* mtl : another mul_mat Q4 (still does not work)

* mtl : working mul_mat q4

* ggml : fix handling of "view" ops in ggml_graph_import()

* mtl : add rope kernel

* mtl : add reshape and transpose handling

* ggml : store offset as opt arg for ggml_view_xd() operators

* mtl : add cpy kernel + handle view ops

* mtl : confirm f16 x f32 attention mul mat

* mtl : add scale kernel

* mtl : add diag_mask_inf kernel

* mtl : fix soft_max kernel

* ggml : update ggml_nbytes() to handle non-contiguous tensors

* mtl : verify V tensor contents

* mtl : add f32 -> f32 cpy kernel

* mtl : add silu kernel

* mtl : add non-broadcast mul kernel

* mtl : full GPU inference of the computation graph

* mtl : optimize rms_norm and soft_max kernels

* mtl : add f16 mat x f32 vec multiplication kernel

* mtl : fix bug in f16 x f32 mul mat + speed-up computation

* mtl : faster mul_mat_q4_0_f32 kernel

* mtl : fix kernel signature + roll inner loop

* mtl : more threads for rms_norm + better timing

* mtl : remove printfs from inner loop

* mtl : simplify implementation

* mtl : add save/load vocab to ggml file

* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

* mtl : make it work with main example

Lots of hacks but at least now it generates text

* mtl : preparing for merge

* mtl : clean-up ggml mtl interface + suport scratch / inplace

* mtl : remove temp / debug code

* metal : final refactoring and simplification

* Revert "ci : disable temporary"

This reverts commit 98c267fc77.

* metal : add comments

* metal : clean-up stuff, fix typos

* readme : add Metal instructions

* readme : add example for main
This commit is contained in:
Georgi Gerganov 2023-06-04 23:34:30 +03:00 committed by GitHub
parent dcb2ed4826
commit ecb217db4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1677 additions and 94 deletions

6
ggml.h
View file

@ -425,6 +425,7 @@ extern "C" {
GGML_API void ggml_print_objects(const struct ggml_context * ctx);
GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API int ggml_blck_size (enum ggml_type type);
@ -441,13 +442,16 @@ extern "C" {
// TODO: temporary until model loading of ggml examples is refactored
GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
// use this to compute the memory overhead of a tensor
GGML_API size_t ggml_tensor_overhead(void);
// main
GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);