Allow disabling unused logit skipping code w/ cmake / make options

cmake -DLLAMA_SKIP_UNUSED_LOGITS=OFF ...
LLAMA_NO_SKIP_UNUSED_LOGITS=1 make ...
This commit is contained in:
Olivier Chafik 2023-08-25 14:00:24 +01:00
parent 7ec7ef94a9
commit 5553820d90
3 changed files with 10 additions and 6 deletions

View file

@ -79,6 +79,7 @@ option(LLAMA_METAL "llama: use Metal"
option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_MPI "llama: use MPI" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_K_QUANTS "llama: use k-quants" ON)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_SKIP_UNUSED_LOGITS "llama: skip computation of unused logits" ON)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@ -352,6 +353,10 @@ if (LLAMA_CLBLAST)
endif() endif()
endif() endif()
if (LLAMA_SKIP_UNUSED_LOGITS)
add_compile_definitions(LLAMA_SKIP_UNUSED_LOGITS)
endif()
if (LLAMA_ALL_WARNINGS) if (LLAMA_ALL_WARNINGS)
if (NOT MSVC) if (NOT MSVC)
set(c_flags set(c_flags

View file

@ -302,6 +302,11 @@ k_quants.o: k_quants.c k_quants.h
$(CC) $(CFLAGS) -c $< -o $@ $(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_NO_K_QUANTS endif # LLAMA_NO_K_QUANTS
ifndef LLAMA_NO_SKIP_UNUSED_LOGITS
CFLAGS += -DLLAMA_SKIP_UNUSED_LOGITS
CXXFLAGS += -DLLAMA_SKIP_UNUSED_LOGITS
endif
# #
# Print build information # Print build information
# #

View file

@ -56,12 +56,6 @@
#include <stdio.h> // for _fseeki64 #include <stdio.h> // for _fseeki64
#endif #endif
// TODO: Fix unused logit skipping crashes on ROCm
// (see https://github.com/ggerganov/llama.cpp/pull/2700#issuecomment-1689548127)
#ifndef LLAMA_USE_HIPBLAS
#define LLAMA_SKIP_UNUSED_LOGITS
#endif
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cassert> #include <cassert>