Merge commit 'c43c2da8af' into concedo_experimental

# Conflicts:
#	llama.cpp
Concedo 2023-11-02 11:17:59 +08:00
commit 1ab18ecb53
14 changed files with 1368 additions and 1398 deletions


@@ -103,9 +103,24 @@ void process_escapes(std::string& input) {
 }
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+    bool result = true;
+    try {
+        if (!gpt_params_parse_ex(argc, argv, params)) {
+            gpt_print_usage(argc, argv, gpt_params());
+            exit(0);
+        }
+    }
+    catch (const std::invalid_argument & ex) {
+        fprintf(stderr, "%s\n", ex.what());
+        gpt_print_usage(argc, argv, gpt_params());
+        exit(1);
+    }
+    return result;
+}
+
+bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
-    gpt_params default_params;
     const std::string arg_prefix = "--";
     llama_sampling_params & sparams = params.sparams;
 
@@ -554,11 +569,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
         } else if (arg == "-h" || arg == "--help") {
-            gpt_print_usage(argc, argv, default_params);
-#ifndef LOG_DISABLE_LOGS
-            log_print_usage();
-#endif // LOG_DISABLE_LOGS
-            exit(0);
+            return false;
         } else if (arg == "--random-prompt") {
             params.random_prompt = true;
         } else if (arg == "--in-prefix-bos") {
@@ -617,22 +629,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         // End of Parse args for logging parameters
 #endif // LOG_DISABLE_LOGS
         } else {
-            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            gpt_print_usage(argc, argv, default_params);
-            exit(1);
+            throw std::invalid_argument("error: unknown argument: " + arg);
         }
     }
     if (invalid_param) {
-        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
-        gpt_print_usage(argc, argv, default_params);
-        exit(1);
+        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
     }
     if (params.prompt_cache_all &&
             (params.interactive || params.interactive_first ||
             params.instruct)) {
-        fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
-        gpt_print_usage(argc, argv, default_params);
-        exit(1);
+
+        throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
     if (params.escape) {
@@ -651,6 +658,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     const llama_sampling_params & sparams = params.sparams;
 
+    printf("\n");
     printf("usage: %s [options]\n", argv[0]);
     printf("\n");
     printf("options:\n");
@@ -762,6 +770,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("\n");
+#ifndef LOG_DISABLE_LOGS
+    log_print_usage();
+#endif // LOG_DISABLE_LOGS
 }
 
 std::string get_system_info(const gpt_params & params) {
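The hunks above move argument-error handling onto exceptions: gpt_params_parse() is now a thin wrapper that prints usage and exits when gpt_params_parse_ex() either reports -h/--help (by returning false) or throws std::invalid_argument. A minimal sketch of a caller that wants to handle those cases itself instead of exiting; the main() below is illustrative only, not part of the commit, and assumes common.h is on the include path:

#include <cstdio>
#include <stdexcept>
#include "common.h"

int main(int argc, char ** argv) {
    gpt_params params;
    try {
        // gpt_params_parse_ex() returns false when -h/--help was requested
        if (!gpt_params_parse_ex(argc, argv, params)) {
            gpt_print_usage(argc, argv, gpt_params());
            return 0;
        }
    } catch (const std::invalid_argument & ex) {
        // unknown arguments and bad values are now reported via exceptions
        fprintf(stderr, "%s\n", ex.what());
        gpt_print_usage(argc, argv, gpt_params());
        return 1;
    }
    // ... run with params ...
    return 0;
}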


@@ -124,6 +124,8 @@ struct gpt_params {
     std::string image = ""; // path to an image file
 };
 
+bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
 void gpt_print_usage(int argc, char ** argv, const gpt_params & params);


@@ -97,38 +97,56 @@
 #define LOG_TEE_TARGET stderr
 #endif
 
-// NOTE: currently disabled as it produces too many log files
+// Utility for synchronizing log configuration state
+// since std::optional was introduced only in c++17
+enum LogTriState
+{
+    LogTriStateSame,
+    LogTriStateFalse,
+    LogTriStateTrue
+};
+
 // Utility to obtain "pid" like unique process id and use it when creating log files.
-//inline std::string log_get_pid()
-//{
-//    static std::string pid;
-//    if (pid.empty())
-//    {
-//        // std::this_thread::get_id() is the most portable way of obtaining a "process id"
-//        // it's not the same as "pid" but is unique enough to solve multiple instances
-//        // trying to write to the same log.
-//        std::stringstream ss;
-//        ss << std::this_thread::get_id();
-//        pid = ss.str();
-//    }
-//
-//    return pid;
-//}
+inline std::string log_get_pid()
+{
+    static std::string pid;
+    if (pid.empty())
+    {
+        // std::this_thread::get_id() is the most portable way of obtaining a "process id"
+        // it's not the same as "pid" but is unique enough to solve multiple instances
+        // trying to write to the same log.
+        std::stringstream ss;
+        ss << std::this_thread::get_id();
+        pid = ss.str();
+    }
+
+    return pid;
+}
 
 // Utility function for generating log file names with unique id based on thread id.
 // invocation with log_filename_generator( "llama", "log" ) creates a string "llama.<number>.log"
 // where the number is a runtime id of the current thread.
-#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(log_file_basename, log_file_extension)
+#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(LogTriStateSame, log_file_basename, log_file_extension)
 
 // INTERNAL, DO NOT USE
-inline std::string log_filename_generator_impl(const std::string & log_file_basename, const std::string & log_file_extension)
+inline std::string log_filename_generator_impl(LogTriState multilog, const std::string & log_file_basename, const std::string & log_file_extension)
 {
+    static bool _multilog = false;
+
+    if (multilog != LogTriStateSame)
+    {
+        _multilog = multilog == LogTriStateTrue;
+    }
+
     std::stringstream buf;
 
     buf << log_file_basename;
-    //buf << ".";
-    //buf << log_get_pid();
+    if (_multilog)
+    {
+        buf << ".";
+        buf << log_get_pid();
+    }
     buf << ".";
     buf << log_file_extension;
@@ -213,15 +231,6 @@ inline std::string log_filename_generator_impl(const std::string & log_file_base
 #define LOG_TEE_FLF_VAL ,""
 #endif
 
-// Utility for synchronizing log configuration state
-// since std::optional was introduced only in c++17
-enum LogTriState
-{
-    LogTriStateSame,
-    LogTriStateFalse,
-    LogTriStateTrue
-};
-
 // INTERNAL, DO NOT USE
 // USE LOG() INSTEAD
 //
@@ -315,16 +324,23 @@ enum LogTriState
 #endif
 
 // INTERNAL, DO NOT USE
-inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr)
+inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr)
 {
-    static bool _initialized{false};
-    static bool _disabled{(filename.empty() && target == nullptr)};
+    static bool _initialized = false;
+    static bool _append = false;
+    static bool _disabled = filename.empty() && target == nullptr;
     static std::string log_current_filename{filename};
     static FILE *log_current_target{target};
     static FILE *logfile = nullptr;
 
     if (change)
     {
+        if (append != LogTriStateSame)
+        {
+            _append = append == LogTriStateTrue;
+            return logfile;
+        }
+
         if (disable == LogTriStateTrue)
         {
             // Disable primary target
@@ -377,7 +393,7 @@ inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTri
             }
         }
 
-        logfile = fopen(filename.c_str(), "w");
+        logfile = fopen(filename.c_str(), _append ? "a" : "w");
     }
 
     if (!logfile)
@@ -398,9 +414,9 @@ inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTri
 }
 
 // INTERNAL, DO NOT USE
-inline FILE *log_handler2_impl(bool change = false, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME)
+inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME)
 {
-    return log_handler1_impl(change, disable, filename, target);
+    return log_handler1_impl(change, append, disable, filename, target);
 }
 
 // Disables logs entirely at runtime.
@@ -411,7 +427,7 @@ inline FILE *log_handler2_impl(bool change = false, LogTriState disable = LogTri
 // INTERNAL, DO NOT USE
 inline FILE *log_disable_impl()
 {
-    return log_handler1_impl(true, LogTriStateTrue);
+    return log_handler1_impl(true, LogTriStateSame, LogTriStateTrue);
 }
 
 // Enables logs at runtime.
@@ -420,19 +436,31 @@ inline FILE *log_disable_impl()
 // INTERNAL, DO NOT USE
 inline FILE *log_enable_impl()
 {
-    return log_handler1_impl(true, LogTriStateFalse);
+    return log_handler1_impl(true, LogTriStateSame, LogTriStateFalse);
 }
 
 // Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*)
 #define log_set_target(target) log_set_target_impl(target)
 
 // INTERNAL, DO NOT USE
-inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, filename); }
-inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, target); }
+inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, LogTriStateSame, filename); }
+inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, LogTriStateSame, target); }
 
 // INTERNAL, DO NOT USE
 inline FILE *log_handler() { return log_handler1_impl(); }
 
+// Enable or disable creating separate log files for each run.
+// can ONLY be invoked BEFORE first log use.
+#define log_multilog(enable) log_filename_generator_impl((enable) ? LogTriStateTrue : LogTriStateFalse, "", "")
+// Enable or disable append mode for log file.
+// can ONLY be invoked BEFORE first log use.
+#define log_append(enable) log_append_impl(enable)
+// INTERNAL, DO NOT USE
+inline FILE *log_append_impl(bool enable)
+{
+    return log_handler1_impl(true, enable ? LogTriStateTrue : LogTriStateFalse, LogTriStateSame);
+}
+
 inline void log_test()
 {
     log_disable();
@@ -494,6 +522,18 @@ inline bool log_param_single_parse(const std::string & param)
         return true;
     }
 
+    if (param == "--log-new")
+    {
+        log_multilog(true);
+        return true;
+    }
+
+    if (param == "--log-append")
+    {
+        log_append(true);
+        return true;
+    }
+
     return false;
 }
@@ -523,7 +563,9 @@ inline void log_print_usage()
     printf("  --log-disable         Disable trace logs\n");
     printf("  --log-enable          Enable trace logs\n");
     printf("  --log-file            Specify a log filename (without extension)\n");
-    printf("                        Log file will be tagged with unique ID and written as \"<name>.<ID>.log\"\n");
+    /*  */ printf("  --log-new             Create a separate new log file on start. "
+                                           "Each log file will have unique name: \"<name>.<ID>.log\"\n");
+    printf("  --log-append          Don't truncate the old log file.\n");
 }
 
 #define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv)
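The log.h changes above make per-run log file naming opt-in (--log-new / log_multilog) and add an append mode (--log-append / log_append), both routed through the extended LogTriState plumbing of log_handler1_impl(). A minimal sketch of the new controls (illustrative only; per the header comments, both switches only take effect if they run before the first log call):

#include "log.h"

int main() {
    log_multilog(true); // name the file "<base>.<thread-id>.log" instead of one shared "<base>.log"
    log_append(true);   // open the log file with "a" rather than "w", keeping previous contents

    LOG_TEE("logging configured\n");
    return 0;
}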


@@ -39,6 +39,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
 void llama_sampling_reset(llama_sampling_context * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
+        ctx->grammar = NULL;
     }
 
     if (!ctx->parsed_grammar.rules.empty()) {
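Clearing ctx->grammar right after llama_grammar_free() avoids leaving a dangling pointer when the grammar is not re-created later in the reset path. A hedged sketch of the failure mode this guards against, assuming (as the surrounding code suggests) that llama_sampling_free() also frees a non-NULL ctx->grammar; the include path and helper name are illustrative:

#include "sampling.h" // common/sampling.h from this tree

void reset_then_free(llama_sampling_context * ctx_sampling) {
    llama_sampling_reset(ctx_sampling); // frees ctx->grammar and, with this change, nulls the pointer
    llama_sampling_free(ctx_sampling);  // a stale ctx->grammar would otherwise risk being freed a second time here
}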


@@ -1045,6 +1045,7 @@ struct train_params_common get_default_train_params_common() {
     params.n_batch                 = 8;
     params.n_gradient_accumulation = 1;
     params.n_epochs                = -1;
+    params.n_gpu_layers            = 0;
 
     params.custom_n_ctx            = false;
@@ -1080,6 +1081,7 @@ struct train_params_common get_default_train_params_common() {
     params.adam_beta2              = 0.999f;
     params.adam_gclip              = 1.0f;
     params.adam_eps_f              = 0.0f;
+
     return params;
 }


@@ -44,6 +44,7 @@ struct train_params_common {
     int n_batch;
     int n_gradient_accumulation;
     int n_epochs;
+    int n_gpu_layers;
 
     bool custom_n_ctx;


@@ -652,7 +652,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
 
     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type)) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);
@@ -1459,6 +1459,17 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
             }
             params->n_rank_w3 = std::stoi(argv[i]);
             params->custom_n_rank_w3 = true;
+        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+            params->common.n_gpu_layers = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             train_print_usage(argc, argv, &default_params);
@@ -1545,6 +1556,7 @@ int main(int argc, char ** argv) {
     srand(params.common.seed);
 
     struct llama_model_params llama_mparams = llama_model_default_params();
+    llama_mparams.n_gpu_layers = params.common.n_gpu_layers;
     llama_mparams.vocab_only = false;
 
     printf("%s: model base = '%s'\n", __func__, params.fn_model_base);


@@ -0,0 +1,34 @@
#!/bin/bash
cd `dirname $0`
cd ../..
EXE="./finetune"
if [[ ! $LLAMA_MODEL_DIR ]]; then LLAMA_MODEL_DIR="./models"; fi
if [[ ! $LLAMA_TRAINING_DIR ]]; then LLAMA_TRAINING_DIR="."; fi
# MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2-q8_0.gguf" # This is the model the readme uses.
MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note in this case with "-g", you get an f32-format .BIN file that isn't yet supported if you use it with "main --lora" with GPU inferencing.
while getopts "dg" opt; do
case $opt in
d)
DEBUGGER="gdb --args"
;;
g)
EXE="./build/bin/Release/finetune"
GPUARG="--gpu-layers 25"
;;
esac
done
$DEBUGGER $EXE \
--model-base $MODEL \
$GPUARG \
--checkpoint-in chk-ol3b-shakespeare-LATEST.gguf \
--checkpoint-out chk-ol3b-shakespeare-ITERATION.gguf \
--lora-out lora-ol3b-shakespeare-ITERATION.bin \
--train-data "$LLAMA_TRAINING_DIR/shakespeare.txt" \
--save-every 10 \
--threads 10 --adam-iter 30 --batch 4 --ctx 64 \
--use-checkpointing


@@ -514,6 +514,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
     dst[i] = __hadd(x[i], __float2half(y[i]));
 }
 
+static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __half2float(x[i]) + y[i];
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4694,6 +4703,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5993,7 +6007,10 @@ inline void ggml_cuda_op_add(
         add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
     } else {
+        fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
         GGML_ASSERT(false);
     }


@@ -1001,11 +1001,15 @@ void ggml_metal_graph_compute(
                    } break;
                case GGML_OP_SOFT_MAX:
                    {
-                        const int nth = MIN(32, ne00);
+                        int nth = 32; // SIMD width
 
                        if (ne00%4 == 0) {
                            [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                        } else {
+                            do {
+                                nth *= 2;
+                            } while (nth <= ne00 && nth <= 1024);
+                            nth /= 2;
                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
                        }
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1013,8 +1017,9 @@ void ggml_metal_graph_compute(
                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    } break;
                case GGML_OP_DIAG_MASK_INF:
                    {


@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
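For reference, both soft_max kernels above compute the numerically stable softmax: the row maximum m is reduced first (per SIMD group, then across the whole threadgroup through buf[] and the barriers), the exponentials are taken relative to m, and the same two-level reduction produces the normalizer:

\[ \operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_k e^{x_k - m}}, \qquad m = \max_j x_j \]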
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
     }
 }
 
 kernel void kernel_diag_mask_inf_8(


@@ -718,6 +718,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
         __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
     }
 #else
+    GGML_UNUSED(nb);
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
@@ -971,6 +972,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
         y[i].s = sum*d;
     }
 #else
+    GGML_UNUSED(nb);
     // scalar
     quantize_row_q8_1_reference(x, y, k);
 #endif

ggml.c

@@ -3153,7 +3153,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
     // TODO: support less-strict constraint
     //       GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
-    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+    GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
 
     bool is_node = false;
 
@@ -6927,9 +6927,15 @@ static void ggml_compute_forward_add_f16_f32(
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // rows per thread
@@ -6940,18 +6946,35 @@ static void ggml_compute_forward_add_f16_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
             }
         }
     }
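Together with the CUDA and finetune changes earlier in this commit, the relaxed assert in ggml_add_cast_impl() and the new F32-destination branch in ggml_compute_forward_add_f16_f32() let an F16 tensor and an F32 tensor be added with an F32 result, which is what add_to_f32() in finetune.cpp now relies on for f16 base models. A minimal sketch against the ggml C API (tensor shapes and the memory budget are arbitrary; graph building/compute is elided since the point is only the relaxed type rule):

#include "ggml.h"

int main() {
    struct ggml_init_params ip = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // f16 base weight plus an f32 delta, as in the LoRA finetune path above
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 64);
    struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

    // before this commit ggml_add_cast asserted unless w was quantized;
    // now an F16 input is accepted and the sum is produced as F32
    struct ggml_tensor * out = ggml_add_cast(ctx, w, d, GGML_TYPE_F32);
    (void) out;

    ggml_free(ctx);
    return 0;
}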

llama.cpp

File diff suppressed because it is too large.