* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
  - the result from each simdgroup now stays in the registers
  - significantly reduced SRAM usage
  - more efficient skipping of -INF blocks
  - avoid simdgroup barrier in hot loop
  - add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t
  Noticeably improves performance (thanks to Johannes)
* cuda : make loops use the same loop values
  Thanks Johannes again for the tip
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (#6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision fpr parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks ggml-ci
* ggml : fix avx512 const correctness ggml-ci
* ggml : fix soft_max with bias on CPU ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3" ggml-ci
* cuda : try to fix __hgt2_mask ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models ggml-ci
* llama : fix n_batch requirements ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH ggml-ci
* llama : support save/load state with FA enabled ggml-ci
* ci : add CUDA save-load-state tests ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg ggml-ci
* ci : fix arg order ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>

309 lines, 13 KiB, Python

```python
import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime

import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests
from statistics import mean
```
```python
def main(args_in: list[str] | None = None) -> None:
    parser = argparse.ArgumentParser(description="Start server benchmark scenario")
    parser.add_argument("--name", type=str, help="Bench name", required=True)
    parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
    parser.add_argument("--branch", type=str, help="Branch name", default="detached")
    parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
    parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
    parser.add_argument("--port", type=int, help="Server listen port", default=8080)
    parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
    parser.add_argument("--n-prompts", type=int,
                        help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
    parser.add_argument("--max-prompt-tokens", type=int,
                        help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum number of prompt tokens, used to filter the dataset",
                        required=True)
    parser.add_argument("--max-tokens", type=int,
                        help="SERVER_BENCH_MAX_CONTEXT: maximum context size (prompt + predicted tokens) of a completion request, used to filter the dataset",
                        required=True)
    parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
    parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
    parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="Number of layers to offload to the GPU", required=True)
    parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
    parser.add_argument("--parallel", type=int, help="Set the number of slots for processing requests", required=True)
    parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
    parser.add_argument("--ubatch-size", type=int, help="Physical maximum batch size", required=True)
    parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
    parser.add_argument("--duration", type=str, help="Bench scenario duration", required=True)

    args = parser.parse_args(args_in)

    start_time = time.time()

    # Start the server and performance scenario
    try:
        server_process = start_server(args)
    except Exception:
        print("bench: server start error :")
        traceback.print_exc(file=sys.stdout)
        sys.exit(1)

    # start the benchmark
    try:
        start_benchmark(args)

        iterations = 0
        with open("results.github.env", 'w') as github_env:
            # parse output
            with open('k6-results.json', 'r') as bench_results:
                # Load JSON data from file
                data = json.load(bench_results)
                for metric_name in data['metrics']:
                    for metric_metric in data['metrics'][metric_name]:
                        value = data['metrics'][metric_name][metric_metric]
                        if isinstance(value, (float, int)):
                            value = round(value, 2)
                            data['metrics'][metric_name][metric_metric] = value
                            github_env.write(
                                f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")
                iterations = data['root_group']['checks']['success completion']['passes']

    except Exception:
        print("bench: error :")
        traceback.print_exc(file=sys.stdout)

    # Stop the server
    if server_process:
        try:
            print(f"bench: shutting down server pid={server_process.pid} ...")
            if os.name == 'nt':
                interrupt = signal.CTRL_C_EVENT
            else:
                interrupt = signal.SIGINT
            server_process.send_signal(interrupt)
            server_process.wait(0.5)

        except subprocess.TimeoutExpired:
            print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
            server_process.kill()  # SIGKILL
            server_process.wait()

        while is_server_listening(args.host, args.port):
            time.sleep(0.1)

    title = (f"llama.cpp {args.name} on {args.runner_label}\n "
             f"duration={args.duration} {iterations} iterations")
    xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
              f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
              f"branch={args.branch} commit={args.commit}")

    # Prometheus
    end_time = time.time()
    prometheus_metrics = {}
    if is_server_listening("0.0.0.0", 9090):
        metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
                   'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']

        for metric in metrics:
            resp = requests.get("http://localhost:9090/api/v1/query_range",
                                params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})

            with open(f"{metric}.json", 'w') as metric_json:
                metric_json.write(resp.text)

            if resp.status_code != 200:
                print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
            else:
                metric_data = resp.json()
                values = metric_data['data']['result'][0]['values']
                timestamps, metric_values = zip(*values)
                metric_values = [float(value) for value in metric_values]
                prometheus_metrics[metric] = metric_values
                timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
                plt.figure(figsize=(16, 10), dpi=80)
                plt.plot(timestamps_dt, metric_values, label=metric)
                plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
                plt.yticks(fontsize=12, alpha=.7)

                ylabel = f"llamacpp:{metric}"
                plt.title(title,
                          fontsize=14, wrap=True)
                plt.grid(axis='both', alpha=.3)
                plt.ylabel(ylabel, fontsize=22)
                plt.xlabel(xlabel, fontsize=14, wrap=True)
                plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
                plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
                plt.gcf().autofmt_xdate()

                # Hide the top/right borders and fade the bottom/left ones
                plt.gca().spines["top"].set_alpha(0.0)
                plt.gca().spines["bottom"].set_alpha(0.3)
                plt.gca().spines["right"].set_alpha(0.0)
                plt.gca().spines["left"].set_alpha(0.3)

                # Save the plot as a jpg image
                plt.savefig(f'{metric}.jpg', dpi=60)
                plt.close()

                # Also write a Mermaid chart in case the image upload fails
                with open(f"{metric}.mermaid", 'w') as mermaid_f:
                    mermaid = (
                    f"""---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "{title}"
    y-axis "llamacpp:{metric}"
    x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
    line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
                    """)
                    mermaid_f.write(mermaid)

    # 140 chars max for commit status description
    bench_results = {
        "i": iterations,
        "req": {
            "p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
            "avg": round(data['metrics']["http_req_duration"]["avg"], 2),
        },
        "pp": {
            "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
        },
        "tg": {
            "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
            "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
            "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2),
        },
    }
    with open("results.github.env", 'a') as github_env:
        github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':'))}\n")
        github_env.write(f"BENCH_ITERATIONS={iterations}\n")

        title = title.replace('\n', ' ')
        xlabel = xlabel.replace('\n', ' ')
        github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
        github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")
```
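In the Prometheus step, `main()` queries `/api/v1/query_range` for each `llamacpp:*` series and unzips `data.result[0].values` into timestamps and values. It converts the values with `float()` because Prometheus returns sample values as strings. The sketch below isolates that parsing so it can be tried without a Prometheus server. The helper name `parse_query_range` and the sample payload are made up for illustration. Unlike the script, the helper returns empty lists when no series matched instead of raising `IndexError`.

```python
import json
from datetime import datetime
from statistics import mean


def parse_query_range(payload: str) -> tuple[list[datetime], list[float]]:
    """Hypothetical helper: extract (timestamps, values) from the first series
    of a Prometheus /api/v1/query_range JSON response."""
    body = json.loads(payload)
    results = body["data"]["result"]
    if not results:
        # No series matched the query: return empty lists instead of raising IndexError.
        return [], []
    # Each sample is [unix_timestamp, "value"]; Prometheus encodes values as strings.
    timestamps, raw_values = zip(*results[0]["values"])
    return ([datetime.fromtimestamp(int(ts)) for ts in timestamps],
            [float(v) for v in raw_values])


# Made-up response in the "matrix" shape returned by query_range.
sample = json.dumps({
    "status": "success",
    "data": {
        "resultType": "matrix",
        "result": [{
            "metric": {"__name__": "llamacpp:prompt_tokens_seconds"},
            "values": [[1714000000, "120.5"], [1714000002, "130.5"]],
        }],
    },
})

times, values = parse_query_range(sample)
print(round(mean(values), 2))  # 125.5
```

Keeping the parsing separate from the plotting also makes it easier to reuse the values for the `mean()` that feeds the `"0"` fields of `BENCH_RESULTS`.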
```python
def start_benchmark(args):
    k6_path = './k6'
    if 'BENCH_K6_BIN_PATH' in os.environ:
        k6_path = os.environ['BENCH_K6_BIN_PATH']
    k6_args = [
        'run', args.scenario,
        '--no-color',
    ]
    k6_args.extend(['--duration', args.duration])
    k6_args.extend(['--iterations', args.n_prompts])
    k6_args.extend(['--vus', args.parallel])
    k6_args.extend(['--summary-export', 'k6-results.json'])
    # The k6 scenario reads these settings from environment variables (POSIX shell prefix syntax)
    k6_cmd = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
    k6_cmd = k6_cmd + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
    print(f"bench: starting k6 with: {k6_cmd}")
    k6_completed = subprocess.run(k6_cmd, shell=True, stdout=sys.stdout, stderr=sys.stderr)
    if k6_completed.returncode != 0:
        raise Exception("bench: unable to run k6")
```
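`start_benchmark()` builds a single shell string that prefixes the `SERVER_BENCH_*` variables as `NAME=value` assignments and runs it with `shell=True`. That prefix syntax only works in POSIX shells. Below is a sketch of an alternative that uses nothing beyond the standard library. It keeps the command as an argument list and passes the variables through `env`, which avoids shell quoting and does not depend on the shell. The helper names `build_k6_invocation` and `run_k6` and the sample arguments are hypothetical. The k6 flags are the ones the script already passes.

```python
import os
import subprocess


def build_k6_invocation(k6_path: str, scenario: str, duration: str, n_prompts: int,
                        vus: int, max_prompt_tokens: int, max_context: int):
    """Hypothetical helper: return (argv, env) for subprocess.run without shell=True."""
    argv = [k6_path, "run", scenario, "--no-color",
            "--duration", duration,
            "--iterations", str(n_prompts),
            "--vus", str(vus),
            "--summary-export", "k6-results.json"]
    # Inherit the current environment and add the variables the k6 scenario reads.
    env = dict(os.environ,
               SERVER_BENCH_N_PROMPTS=str(n_prompts),
               SERVER_BENCH_MAX_PROMPT_TOKENS=str(max_prompt_tokens),
               SERVER_BENCH_MAX_CONTEXT=str(max_context))
    return argv, env


def run_k6(argv: list[str], env: dict[str, str]) -> None:
    # Requires k6 to be installed at argv[0].
    completed = subprocess.run(argv, env=env)
    if completed.returncode != 0:
        raise RuntimeError("bench: unable to run k6")


argv, env = build_k6_invocation("./k6", "script.js", "10m", 480, 8, 1024, 2048)
print(" ".join(argv))
```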
```python
def start_server(args):
    server_process = start_server_background(args)

    attempts = 0
    max_attempts = 20
    if 'GITHUB_ACTIONS' in os.environ:
        max_attempts *= 2

    while not is_server_listening(args.host, args.port):
        attempts += 1
        if attempts > max_attempts:
            # Raise instead of `assert False` so the check also runs under `python -O`
            raise RuntimeError("bench: server not started")
        print("bench:     waiting for server to start ...")
        time.sleep(0.5)

    print("bench: server started.")
    return server_process
```
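`start_server()` polls `is_server_listening()` every 0.5 s and gives up after `max_attempts`, which is doubled when `GITHUB_ACTIONS` is set. The budget is therefore roughly 10 s locally and 20 s in CI, plus the time spent in each connection attempt. The same pattern can be written as a reusable, deadline-based helper. `wait_until` below is a hypothetical name, and the predicate can be any zero-argument callable, for example `lambda: is_server_listening(host, port)`.

```python
import time
from typing import Callable


def wait_until(predicate: Callable[[], bool], timeout: float, interval: float = 0.5) -> None:
    """Call `predicate` every `interval` seconds until it returns True.

    Raises TimeoutError if it is still False after `timeout` seconds.
    time.monotonic() is used so wall-clock adjustments do not move the deadline.
    """
    deadline = time.monotonic() + timeout
    while not predicate():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout} s")
        time.sleep(interval)


# Usage: a fake readiness check that succeeds on the third poll.
calls = {"n": 0}


def ready() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3


wait_until(ready, timeout=5, interval=0.01)
print(calls["n"])  # 3
```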
```python
def start_server_background(args):
    # Start the server
    server_path = '../../../build/bin/server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_args = [
        '--host', args.host,
        '--port', args.port,
    ]
    model_file = args.model_path_prefix + os.path.sep + args.hf_file
    model_dir = os.path.dirname(model_file)
    os.makedirs(model_dir, exist_ok=True)
    server_args.extend(['--model', model_file])
    server_args.extend(['--hf-repo', args.hf_repo])
    server_args.extend(['--hf-file', args.hf_file])
    server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
    server_args.extend(['--ctx-size', args.ctx_size])
    server_args.extend(['--parallel', args.parallel])
    server_args.extend(['--batch-size', args.batch_size])
    server_args.extend(['--ubatch-size', args.ubatch_size])
    server_args.extend(['--n-predict', args.max_tokens * 2])
    server_args.extend(['--defrag-thold', "0.1"])
    server_args.append('--cont-batching')
    server_args.append('--metrics')
    server_args.append('--flash-attn')
    server_args.extend(['--log-format', "text"])
    cmd = [str(arg) for arg in [server_path, *server_args]]
    print(f"bench: starting server with: {' '.join(cmd)}")
    pkwargs = {
        'stdout': subprocess.PIPE,
        'stderr': subprocess.PIPE
    }
    server_process = subprocess.Popen(
        cmd,
        **pkwargs)

    def server_log(in_stream, out_stream):
        # Forward server output line by line until the pipe is closed (readline returns b'')
        for line in iter(in_stream.readline, b''):
            print(line.decode('utf-8'), end='', file=out_stream)

    # Drain both pipes on background threads so the server never blocks on a full pipe buffer
    thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
    thread_stdout.start()
    thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
    thread_stderr.start()

    return server_process
```
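`start_server_background()` starts the server with both stdout and stderr captured as pipes and drains each pipe on its own thread. If nobody read a pipe, the child would block once the OS pipe buffer filled up. Below is a minimal, self-contained sketch of the same technique. A short Python child process stands in for the llama.cpp server, which is an assumption for illustration. The sketch collects lines into lists instead of printing them and joins the threads so both pipes are fully drained.

```python
import subprocess
import sys
import threading


def relay(stream, sink: list[str]) -> None:
    # readline returns b'' only at EOF, i.e. when the child closes the pipe.
    for raw in iter(stream.readline, b''):
        sink.append(raw.decode('utf-8').rstrip('\r\n'))
    stream.close()


child = subprocess.Popen(
    [sys.executable, "-c", "import sys; print('out'); print('err', file=sys.stderr)"],
    stdout=subprocess.PIPE, stderr=subprocess.PIPE)

out_lines: list[str] = []
err_lines: list[str] = []
threads = [threading.Thread(target=relay, args=(child.stdout, out_lines)),
           threading.Thread(target=relay, args=(child.stderr, err_lines))]
for t in threads:
    t.start()

child.wait()
for t in threads:
    t.join()  # make sure both pipes have been read to EOF

print(out_lines, err_lines)  # ['out'] ['err']
```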
```python
def is_server_listening(server_fqdn, server_port):
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        result = sock.connect_ex((server_fqdn, server_port))
        _is_server_listening = result == 0
        if _is_server_listening:
            print(f"server is listening on {server_fqdn}:{server_port}...")
        return _is_server_listening
```
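`is_server_listening()` opens an IPv4 TCP socket and treats a `connect_ex()` result of 0 as "listening". The socket has no timeout, so a poll against an unreachable host can stall, and an `AF_INET` socket cannot reach an IPv6-only listener. Below is a sketch of a variant based on `socket.create_connection`, which tries every address returned for the host and accepts a timeout. The function name `is_port_open` and the 1 s default are assumptions.

```python
import socket


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""
    try:
        # create_connection tries each address getaddrinfo returns (IPv4 and IPv6).
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, unreachable hosts and timeouts (socket.timeout is an OSError).
        return False


# Usage: open a throwaway listener on an OS-assigned port and probe it.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
    server.bind(("127.0.0.1", 0))
    server.listen()
    port = server.getsockname()[1]
    listening = is_port_open("127.0.0.1", port)
    print(listening)  # True
```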
```python
def escape_metric_name(metric_name):
    return re.sub('[^A-Z0-9]', '_', metric_name.upper())
```
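`escape_metric_name()` upper-cases a k6 metric or statistic name and replaces every character outside `A-Z0-9` with `_`. This lets `main()` write each rounded value as a `KEY=value` line in `results.github.env`. The sketch below reproduces that loop on a made-up, trimmed summary dictionary to show the resulting variable names; a real `--summary-export` file contains many more metrics. `to_env_lines` is a hypothetical helper name.

```python
import re


def escape_metric_name(metric_name: str) -> str:
    # Same rule as bench.py: upper-case, then map anything outside A-Z/0-9 to '_'.
    return re.sub('[^A-Z0-9]', '_', metric_name.upper())


def to_env_lines(summary: dict) -> list[str]:
    """Flatten {"metrics": {name: {stat: number}}} into KEY=value lines rounded to 2 decimals."""
    lines = []
    for metric_name, stats in summary["metrics"].items():
        for stat_name, value in stats.items():
            if isinstance(value, (int, float)):
                key = f"{escape_metric_name(metric_name)}_{escape_metric_name(stat_name)}"
                lines.append(f"{key}={round(value, 2)}")
    return lines


# Made-up excerpt shaped like a k6 summary export.
summary = {"metrics": {"http_req_duration": {"avg": 1234.5678, "p(95)": 2000.123}}}
for line in to_env_lines(summary):
    print(line)
# HTTP_REQ_DURATION_AVG=1234.57
# HTTP_REQ_DURATION_P_95_=2000.12
```

Because `(` and `)` both become `_`, a percentile such as `p(95)` turns into `P_95_`, which is a valid environment variable suffix.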
```python
if __name__ == '__main__':
    main()
```
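The shutdown step in `main()` sends an interrupt (`SIGINT` on POSIX, `CTRL_C_EVENT` on Windows) and waits 0.5 s. If the server is still running after that, it falls back to `kill()`. The sketch below packages that sequence as a hypothetical `stop_process` helper. It tests the helper on a child Python process that ignores `SIGINT`, so the kill fallback is the path taken. The demo assumes a POSIX system, where a process killed by a signal reports a negative return code.

```python
import os
import signal
import subprocess
import sys


def stop_process(proc: subprocess.Popen, grace: float = 0.5) -> int:
    """Interrupt proc, wait up to `grace` seconds, then kill it. Returns the exit code."""
    # Only the branch selected by os.name is evaluated, so CTRL_C_EVENT is not touched on POSIX.
    interrupt = signal.CTRL_C_EVENT if os.name == 'nt' else signal.SIGINT
    proc.send_signal(interrupt)
    try:
        return proc.wait(grace)
    except subprocess.TimeoutExpired:
        proc.kill()  # SIGKILL on POSIX
        return proc.wait()


# Demo child: ignores SIGINT, announces readiness, then sleeps.
child_code = ("import signal, sys, time; signal.signal(signal.SIGINT, signal.SIG_IGN); "
              "print('ready', flush=True); time.sleep(30)")
child = subprocess.Popen([sys.executable, "-c", child_code], stdout=subprocess.PIPE)
child.stdout.readline()  # wait until the SIGINT handler is installed
code = stop_process(child)
child.stdout.close()
print(code)  # -9 on POSIX: the child was killed after ignoring SIGINT
```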