This commit is contained in:
John 2023-04-10 22:37:56 +00:00 committed by GitHub
commit 40e9d4acba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,7 +3,7 @@
#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H
#define _PRELOAD_MMAP_FILE 1 // when using mmap, preload the entire file to prevent loading during first token inference
#include <cstdio>
#include <cstdint>
#include <cerrno>
@ -11,6 +11,7 @@
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <thread>
#include <string>
#include <vector>
@ -30,6 +31,34 @@
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
// Minimal Windows stand-ins for the POSIX threading/atomic names used below.
// NOTE(review): volatile LONG is not a true atomic — presumably updates go
// through Interlocked* elsewhere; confirm before relying on atomicity.
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef HANDLE pthread_t;     // a pthread handle maps to a Win32 thread HANDLE
typedef DWORD thread_ret_t;   // Win32 thread entry points return DWORD
// pthread_create shim over CreateThread. The POSIX attribute argument is
// accepted for signature compatibility but ignored.
// Returns 0 on success, EAGAIN if the thread could not be created.
static int pthread_create(pthread_t *out, void *unused, thread_ret_t (*func)(void *), void *arg)
{
    (void) unused;
    HANDLE h = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
    if (h == NULL) {
        return EAGAIN;
    }
    *out = h;
    return 0;
}
// pthread_join shim: wait for the thread to finish, then release its handle.
// Returns the wait status (0 == WAIT_OBJECT_0 on success), mirroring the
// original's return convention.
static int pthread_join(pthread_t thread, void *unused)
{
    (void) unused;
    int ret = (int) WaitForSingleObject(thread, INFINITE);
    // pthread_join must release the thread's resources; without CloseHandle
    // every joined thread leaked its HANDLE.
    CloseHandle(thread);
    return ret;
}
#else
#include <unistd.h>
#include <pthread.h>
#include <stdatomic.h>
typedef void *thread_ret_t;
#endif
#define LLAMA_ASSERT(x) \
@ -156,7 +185,96 @@ static std::string llama_format_win_err(DWORD err) {
struct llama_mmap {
void * addr;
size_t size;
// Work description handed to each worker_preload_memory thread: the workers
// of one chunk share the same [start, end] byte range and interleave their
// page accesses by thread index.
typedef struct
{
size_t start; // byte offset where this chunk begins
size_t end; // byte offset bounding this chunk (worker loops while offset <= end)
void *addr; // base address of the mmap'd region
int n_threads; // total number of workers striding over the chunk
int n_thread; // this worker's index in [0, n_threads)
int page_size; // system page size in bytes
} thread_data_t;
// Worker: fault this worker's share of the mapping into memory by reading one
// byte from every n_threads-th page of [start, end).
// The volatile load cannot be elided by the compiler, so no dummy buffer,
// memcpy, or never-taken exit() branch is needed (the original's
// `n_threads < n_thread` test was always false, i.e. dead code).
static thread_ret_t worker_preload_memory(void *arg)
{
    thread_data_t *data = (thread_data_t *) arg;
    const volatile char *base = (const volatile char *) data->addr;
    size_t stride = (size_t) data->n_threads * (size_t) data->page_size;
    // '<' rather than the original '<=': when end == length, offset == end
    // would read one byte past the end of the mapping.
    for (size_t offset = data->start + (size_t) data->n_thread * (size_t) data->page_size;
         offset < data->end;
         offset += stride)
    {
        (void) base[offset]; // volatile read: touches the page, is never optimized away
    }
    return NULL;
}
// Touch every page of an mmap'd region from several threads so the OS loads
// the whole file into the page cache before first-token inference.
// addr/length describe the mapping; n_threads is clamped to [1, 32].
// Compiles to a no-op when _PRELOAD_MMAP_FILE is not defined.
void preload_mmap_file(void *addr, size_t length, int n_threads)
{
#ifndef _PRELOAD_MMAP_FILE
    return;
#endif
    if (addr == NULL || length == 0) {
        return;
    }
    // Get the page size of the system.
#if defined(_WIN32)
    SYSTEM_INFO si;
    GetSystemInfo(&si);
    long page_size = si.dwPageSize;
#else
    long page_size = sysconf(_SC_PAGE_SIZE);
#endif
    if (page_size <= 0) {
        perror("sysconf");
        return;
    }
#ifdef _WIN32
    // Prefetch part of the data and signal read-ahead to the file system.
    // (VirtualLock showed no benefit; on low-RAM systems one could lock a
    // portion and restrict the preload to that length.)
    HANDLE hProcess = GetCurrentProcess();
    WIN32_MEMORY_RANGE_ENTRY range;
    range.VirtualAddress = addr;
    range.NumberOfBytes = length;
    if (!PrefetchVirtualMemory(hProcess, 1, &range, 0)) { }
#else
    // todo: readahead() / posix_madvise(POSIX_MADV_WILLNEED) could hint the
    // kernel here; the threaded touch below works without OS-level readahead.
#endif
    if (n_threads < 1) {
        n_threads = 1; // callers may pass hardware_concurrency()/2 == 0
    }
    if (n_threads > 32) {
        n_threads = 32;
    }
    pthread_t threads[32];
    thread_data_t thread_data[32];
    // All page math in size_t: the original's `int page_start * long page_size`
    // overflowed for files > 2 GiB on LLP64 Windows, and `ceil(length/page_size)`
    // applied ceil after integer truncation (no rounding at all).
    size_t pages = (length + (size_t) page_size - 1) / (size_t) page_size;
    size_t num_pages_per_thread = pages / (size_t) n_threads;
    if (num_pages_per_thread == 0) {
        // Fewer pages than threads: the original computed a stride of 0 here,
        // which made the outer loop spin forever.
        num_pages_per_thread = 1;
    }
    size_t chunk_pages = (size_t) n_threads * num_pages_per_thread;
    // Split the pages between the threads: each chunk is shared by all workers,
    // which interleave their page accesses by thread index.
    for (size_t page_start = 0; page_start < pages; page_start += chunk_pages)
    {
        size_t chunk_start = page_start * (size_t) page_size;
        size_t chunk_end = chunk_start + chunk_pages * (size_t) page_size;
        if (chunk_end > length) {
            chunk_end = length;
        }
        int started = 0;
        for (int i = 0; i < n_threads; ++i)
        {
            thread_data[i].start = chunk_start;
            thread_data[i].end = chunk_end;
            thread_data[i].addr = addr;
            thread_data[i].page_size = (int) page_size;
            thread_data[i].n_threads = n_threads;
            thread_data[i].n_thread = i;
            if (pthread_create(&threads[i], NULL, worker_preload_memory, &thread_data[i]) != 0) {
                break; // creation failed; join only what was actually started
            }
            started++;
        }
        // Join exactly the threads that were created. The original joined all
        // n_threads even after breaking early, waiting on uninitialized handles.
        for (int i = 0; i < started; ++i)
        {
            pthread_join(threads[i], NULL);
        }
    }
}
llama_mmap(const llama_mmap &) = delete;
#ifdef _POSIX_MAPPED_FILES
@ -180,6 +298,8 @@ struct llama_mmap {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
// if _PRELOAD_MMAP_FILE is defined, this will preload the file into the page cache efficiently
preload_mmap_file(addr, file->size);
}
~llama_mmap() {
@ -217,6 +337,9 @@ struct llama_mmap {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
// if _PRELOAD_MMAP_FILE is defined, this will preload the file into the page cache efficiently
preload_mmap_file(addr, file->size, std::thread::hardware_concurrency()/2);
}
~llama_mmap() {