diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index 7289bf028..866dd4889 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import argparse import asyncio import os import sys @@ -769,6 +770,7 @@ void main() { } """ +GLSLC = "glslc" VK_NUM_TYPES = 16 @@ -809,41 +811,44 @@ K_QUANTS_PER_ITERATION = 1 async def string_to_spv_file(name, code, defines, fp16): - with NamedTemporaryFile(mode="w") as f: - f.write(code) - f.flush() + f = NamedTemporaryFile(mode="w", delete=False) + f.write(code) + f.flush() - cmd = ["glslc", "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", os.path.join("vk_shaders", f"{name}{'_fp32' if not fp16 else ''}.comp")] + cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", os.path.join("vk_shaders", f"{name}{'_fp32' if not fp16 else ''}.comp")] + cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) + + proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + + stdout, stderr = await proc.communicate() + + stdout = stdout.decode() + error = stderr.decode() + + if proc.returncode: + # Generate preprocessed code + cmd = [GLSLC, "-E", f.name] cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) stdout, stderr = await proc.communicate() - stdout = stdout.decode() - error = stderr.decode() + print(" ".join(cmd)) if proc.returncode: - # Generate preprocessed code - cmd = ["glslc", "-E", f.name] - cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) + raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}") - proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + preprocessed_code = stdout.decode() - stdout, stderr = await proc.communicate() + cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) + code_with_lines = "\n".join([f"{i}: {line}" for i, line in enumerate(preprocessed_code.splitlines())]) + print(f"ERROR compiling {name}\n\n{code_with_lines}\n\n{error=}") + os.remove(f.name) + sys.exit(proc.returncode) - print(" ".join(cmd)) - - if proc.returncode: - raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}") - - preprocessed_code = stdout.decode() - - cmd.extend([f"-D{key}={value}" for key, value in defines.items()]) - code_with_lines = "\n".join([f"{i}: {line}" for i, line in enumerate(preprocessed_code.splitlines())]) - print(f"ERROR compiling {name}\n\n{code_with_lines}\n\n{error=}") - sys.exit(proc.returncode) + os.remove(f.name) async def main(): @@ -963,4 +968,14 @@ async def main(): await asyncio.gather(*tasks) -asyncio.run(main()) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator") + + parser.add_argument("--glslc", help="Path to glslc") + + args = parser.parse_args() + + if args.glslc: + GLSLC = args.glslc + + asyncio.run(main())