llama : Metal inference (#1642)

* mtl : export the LLaMA computation graph

* ci : disable temporary

* mtl : adapt the MNIST example as starter

* mtl : no need for mtl-export tool, add cli arg for main instead

* mtl : export just a small part of the graph for now to make it easier

* mtl : move MSL code into separate file for easy editing

* mtl : initial get_rows_q4_0 kernel

* mtl : confirmed get_rows_q4_0 is working correctly

* mtl : add rms_norm kernel + confirm working

* mtl : add mul kernel + confirm working

* mtl : initial mul_mat Q4 kernel (wrong results)

* mtl : mul_mat fixes (still wrong)

* mtl : another mul_mat Q4 (still does not work)

* mtl : working mul_mat q4

* ggml : fix handling of "view" ops in ggml_graph_import()

* mtl : add rope kernel

* mtl : add reshape and transpose handling

* ggml : store offset as opt arg for ggml_view_xd() operators

* mtl : add cpy kernel + handle view ops

* mtl : confirm f16 x f32 attention mul mat

* mtl : add scale kernel

* mtl : add diag_mask_inf kernel

* mtl : fix soft_max kernel

* ggml : update ggml_nbytes() to handle non-contiguous tensors

* mtl : verify V tensor contents

* mtl : add f32 -> f32 cpy kernel

* mtl : add silu kernel

* mtl : add non-broadcast mul kernel

* mtl : full GPU inference of the computation graph

* mtl : optimize rms_norm and soft_max kernels

* mtl : add f16 mat x f32 vec multiplication kernel

* mtl : fix bug in f16 x f32 mul mat + speed-up computation

* mtl : faster mul_mat_q4_0_f32 kernel

* mtl : fix kernel signature + roll inner loop

* mtl : more threads for rms_norm + better timing

* mtl : remove printfs from inner loop

* mtl : simplify implementation

* mtl : add save/load vocab to ggml file

* mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

* mtl : make it work with main example

Lots of hacks but at least now it generates text

* mtl : preparing for merge

* mtl : clean-up ggml mtl interface + suport scratch / inplace

* mtl : remove temp / debug code

* metal : final refactoring and simplification

* Revert "ci : disable temporary"

This reverts commit 98c267fc77.

* metal : add comments

* metal : clean-up stuff, fix typos

* readme : add Metal instructions

* readme : add example for main
This commit is contained in:
Georgi Gerganov 2023-06-04 23:34:30 +03:00 committed by GitHub
parent dcb2ed4826
commit ecb217db4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1677 additions and 94 deletions

View file

@ -105,6 +105,7 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
#CFLAGS += -mfma -mf16c -mavx
#CXXFLAGS += -mfma -mf16c -mavx
endif
ifneq ($(filter ppc64%,$(UNAME_M)),)
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
ifneq (,$(findstring POWER9,$(POWER9_M)))
@ -116,6 +117,7 @@ ifneq ($(filter ppc64%,$(UNAME_M)),)
CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN
endif
endif
ifndef LLAMA_NO_ACCELERATE
# Mac M1 - include Accelerate framework.
# `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time).
@ -123,7 +125,8 @@ ifndef LLAMA_NO_ACCELERATE
CFLAGS += -DGGML_USE_ACCELERATE
LDFLAGS += -framework Accelerate
endif
endif
endif # LLAMA_NO_ACCELERATE
ifdef LLAMA_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
@ -131,11 +134,13 @@ ifdef LLAMA_OPENBLAS
else
LDFLAGS += -lopenblas
endif
endif
endif # LLAMA_OPENBLAS
ifdef LLAMA_BLIS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis
LDFLAGS += -lblis -L/usr/local/lib
endif
endif # LLAMA_BLIS
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
@ -156,9 +161,10 @@ endif # LLAMA_CUDA_DMMV_Y
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif # LLAMA_CUBLAS
ifdef LLAMA_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST
# Mac provides OpenCL as a framework
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -lclblast -framework OpenCL
@ -166,23 +172,38 @@ ifdef LLAMA_CLBLAST
LDFLAGS += -lclblast -lOpenCL
endif
OBJS += ggml-opencl.o
ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
$(CXX) $(CXXFLAGS) -c $< -o $@
endif
endif # LLAMA_CLBLAST
ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
OBJS += ggml-metal.o
ggml-metal.o: ggml-metal.m ggml-metal.h
$(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_METAL
ifneq ($(filter aarch64%,$(UNAME_M)),)
# Apple M1, M2, etc.
# Raspberry Pi 3, 4, Zero 2 (64-bit)
CFLAGS += -mcpu=native
CXXFLAGS += -mcpu=native
endif
ifneq ($(filter armv6%,$(UNAME_M)),)
# Raspberry Pi 1, Zero
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
endif
ifneq ($(filter armv7%,$(UNAME_M)),)
# Raspberry Pi 2
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
endif
ifneq ($(filter armv8%,$(UNAME_M)),)
# Raspberry Pi 3, 4, Zero 2 (32-bit)
CFLAGS += -mfp16-format=ieee -mno-unaligned-access