[CANN] Add Ascend NPU backend (#6035)
* [CANN] Add Ascend NPU backend

  Ascend is a full-stack AI computing infrastructure for industry applications and services based on Huawei Ascend processors and software. CANN (Compute Architecture of Neural Networks), developed by Huawei, is a heterogeneous computing architecture for AI.

  Co-authored-by: wangshuai09 <391746016@qq.com>

* delete trailing whitespaces
* Modify the code based on review comments
* Rename LLAMA_CANN to GGML_CANN
* Make ggml-common.h private
* add ggml_cann prefix for acl funcs
* Add logging for CANN backend
* Delete trailing whitespace

---------

Co-authored-by: wangshuai09 <391746016@qq.com>
parent da3913d8f9
commit 1bdd8ae19f

27 changed files with 10756 additions and 8 deletions
		|  | @ -106,6 +106,7 @@ llama_option_depr(WARNING     LLAMA_NATIVE              GGML_NATIVE) | |||
| llama_option_depr(WARNING     LLAMA_RPC                 GGML_RPC) | ||||
| llama_option_depr(WARNING     LLAMA_SYCL                GGML_SYCL) | ||||
| llama_option_depr(WARNING     LLAMA_SYCL_F16            GGML_SYCL_F16) | ||||
| llama_option_depr(WARNING     LLAMA_CANN                GGML_CANN) | ||||
| 
 | ||||
| # | ||||
| # build the library | ||||
|  |  | |||
|  | @ -23,6 +23,10 @@ | |||
| #include "ggml-cuda.h" | ||||
| #include "ggml-sycl.h" | ||||
| 
 | ||||
| #ifdef GGML_USE_CANN | ||||
| #include "ggml-cann.h" | ||||
| #endif | ||||
| 
 | ||||
| // utils
 | ||||
| static uint64_t get_time_ns() { | ||||
|     using clock = std::chrono::high_resolution_clock; | ||||
|  | @ -120,6 +124,17 @@ static std::string get_gpu_info() { | |||
|             id += "/"; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| #ifdef GGML_USE_CANN | ||||
|     uint32_t count = ggml_backend_cann_get_device_count(); | ||||
|     for (uint32_t i = 0; i < count; i++) { | ||||
|         char buf[128]; | ||||
|         ggml_backend_cann_get_device_description(i, buf, sizeof(buf)); | ||||
|         id += buf; | ||||
|         if (i < count - 1) { | ||||
|             id += "/"; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|     // TODO: other backends
 | ||||
|     return id; | ||||
|  |  | |||
|  | @ -16,6 +16,10 @@ | |||
| #include "ggml-metal.h" | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CANN | ||||
| #include "ggml-cann.h" | ||||
| #endif | ||||
| 
 | ||||
| #define STB_IMAGE_IMPLEMENTATION | ||||
| #include "stb_image.h" | ||||
| 
 | ||||
|  | @ -1001,6 +1005,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { | |||
|     LOG_TEE("%s: CLIP using Metal backend\n", __func__); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CANN | ||||
|     new_clip->backend = ggml_backend_cann_init(0); | ||||
|     LOG_TEE("%s: CLIP using CANN backend\n", __func__); | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
|     if (!new_clip->backend) { | ||||
|         new_clip->backend = ggml_backend_cpu_init(); | ||||
|  |  | |||
							
								
								
									
ggml/include/ggml-cann.h (new file, 125 lines)
							|  | @ -0,0 +1,125 @@ | |||
| /*
 | ||||
|  * Copyright (c) 2023-2024 The ggml authors | ||||
|  * | ||||
|  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
|  * of this software and associated documentation files (the "Software"), to | ||||
|  * deal in the Software without restriction, including without limitation the | ||||
|  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||||
|  * sell copies of the Software, and to permit persons to whom the Software is | ||||
|  * furnished to do so, subject to the following conditions: | ||||
|  * | ||||
|  * The above copyright notice and this permission notice shall be included in | ||||
|  * all copies or substantial portions of the Software. | ||||
|  * | ||||
|  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
|  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
|  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
|  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
|  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
|  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||||
|  * IN THE SOFTWARE. | ||||
|  */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include "ggml-backend.h" | ||||
| #include "ggml.h" | ||||
| 
 | ||||
| #ifdef __cplusplus | ||||
| extern "C" { | ||||
| #endif | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Maximum number of CANN devices supported. | ||||
|  */ | ||||
| #define GGML_CANN_MAX_DEVICES 16 | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Initializes the CANN backend for a specified device. | ||||
|  * | ||||
|  * This function initializes the CANN backend for the given device. | ||||
|  * It verifies the device index, allocates a context, and creates a backend | ||||
|  * instance. | ||||
|  * | ||||
|  * @param device The index of the device to initialize. | ||||
|  * @return A pointer to the initialized backend instance, or nullptr on failure. | ||||
|  */ | ||||
| GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Checks if a given backend is a CANN backend. | ||||
|  * | ||||
|  * This function verifies if the provided backend is a CANN backend by comparing | ||||
|  * its GUID with the CANN backend's GUID. | ||||
|  * | ||||
|  * @param backend The backend instance to check. | ||||
|  * @return True if the backend is a CANN backend, false otherwise. | ||||
|  */ | ||||
| GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend); | ||||
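For orientation, here is a minimal, hypothetical usage sketch of the two declarations above. It assumes a CANN-enabled ggml build and that Ascend device 0 exists; `ggml_backend_free` is the standard teardown call from ggml-backend.h.

```cpp
// Minimal, hypothetical sketch: initialize the CANN backend on device 0 and verify it.
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-cann.h"

int main() {
    ggml_backend_t backend = ggml_backend_cann_init(0);   // device index 0
    if (backend == nullptr) {
        fprintf(stderr, "failed to initialize the CANN backend\n");
        return 1;
    }
    // GUID comparison distinguishes CANN from other registered backends.
    printf("is CANN backend: %s\n", ggml_backend_is_cann(backend) ? "yes" : "no");
    ggml_backend_free(backend);
    return 0;
}
```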
| 
 | ||||
| /**
 | ||||
|  * @brief Retrieves the CANN buffer type for a specified device. | ||||
|  * | ||||
|  * This function initializes and returns the buffer type interface associated | ||||
|  * with the given device. It ensures thread-safe access using a mutex. | ||||
|  * | ||||
|  * @param device The device index for which to retrieve the buffer type. | ||||
|  * @return A pointer to the buffer type interface for the specified device, or | ||||
|  * nullptr if the device index is out of range. | ||||
|  */ | ||||
| GGML_API GGML_CALL ggml_backend_buffer_type_t | ||||
| ggml_backend_cann_buffer_type(int32_t device); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Retrieves the number of CANN devices available. | ||||
|  * | ||||
|  * This function returns the number of CANN devices available based on | ||||
|  * information obtained from `ggml_cann_info()`. | ||||
|  * | ||||
|  * @return The number of CANN devices available. | ||||
|  */ | ||||
| GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Retrieves the description of a specific CANN device. | ||||
|  * | ||||
|  * This function sets the specified device, retrieves the SoC name, | ||||
|  * and writes it into the provided description buffer. | ||||
|  * | ||||
|  * @param device The device index to retrieve the description for. | ||||
|  * @param description Pointer to a buffer where the description will be written. | ||||
|  * @param description_size Size of the description buffer. | ||||
|  */ | ||||
| GGML_API GGML_CALL void ggml_backend_cann_get_device_description( | ||||
|     int32_t device, char* description, size_t description_size); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Retrieves the memory information of a specific CANN device. | ||||
|  * | ||||
|  * This function sets the specified device, retrieves the free and total | ||||
|  * memory information of the specified type (ACL_HBM_MEM), and stores them | ||||
|  * in the provided pointers. | ||||
|  * | ||||
|  * @param device The device index to retrieve memory information for. | ||||
|  * @param free Pointer to a variable where the free memory size will be stored. | ||||
|  * @param total Pointer to a variable where the total memory size will be | ||||
|  * stored. | ||||
|  */ | ||||
| GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, | ||||
|                                                             size_t* free, | ||||
|                                                             size_t* total); | ||||
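A small sketch combining the three query functions above, in the same spirit as the llama-bench hunk earlier in this diff; the buffer size and output formatting are illustrative only.

```cpp
// Hypothetical helper: list every visible CANN device with its SoC name and memory.
#include <cstdio>

#include "ggml-cann.h"

static void print_cann_devices(void) {
    int32_t count = ggml_backend_cann_get_device_count();
    for (int32_t i = 0; i < count; i++) {
        char   desc[128];
        size_t free_mem  = 0;
        size_t total_mem = 0;
        ggml_backend_cann_get_device_description(i, desc, sizeof(desc));
        ggml_backend_cann_get_device_memory(i, &free_mem, &total_mem);
        printf("CANN device %d: %s, %zu/%zu MiB free\n", i, desc,
               free_mem / (1024 * 1024), total_mem / (1024 * 1024));
    }
}
```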
| 
 | ||||
| /**
 | ||||
|  * @brief Set the logging callback for GGML. | ||||
|  * | ||||
|  * This function sets the logging callback and user data for logging. | ||||
|  * | ||||
|  * @param log_callback The logging callback to set. | ||||
|  * @param user_data User data to pass to the logging callback. | ||||
|  */ | ||||
| GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback, | ||||
|                                                  void* user_data); | ||||
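A hedged example of installing a custom logger through the callback above; the `ggml_log_callback` signature (level, text, user_data) is the one declared in ggml.h.

```cpp
// Hypothetical sketch: forward CANN backend log messages to stderr.
#include <cstdio>

#include "ggml.h"
#include "ggml-cann.h"

static void cann_log_to_stderr(enum ggml_log_level level, const char * text, void * user_data) {
    (void) level;      // the severity could be used for filtering
    (void) user_data;  // no extra state is needed here
    fputs(text, stderr);
}

// During startup:
//     ggml_backend_cann_log_set_callback(cann_log_to_stderr, nullptr);
```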
| 
 | ||||
| #ifdef __cplusplus | ||||
| } | ||||
| #endif | ||||
|  | @ -753,6 +753,8 @@ extern "C" { | |||
|     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); | ||||
|     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); | ||||
| 
 | ||||
|     GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1); | ||||
| 
 | ||||
|     // use this to compute the memory overhead of a tensor
 | ||||
|     GGML_API size_t ggml_tensor_overhead(void); | ||||
| 
 | ||||
|  | @ -2397,6 +2399,7 @@ extern "C" { | |||
|     GGML_API int ggml_cpu_has_rpc        (void); | ||||
|     GGML_API int ggml_cpu_has_vsx        (void); | ||||
|     GGML_API int ggml_cpu_has_matmul_int8(void); | ||||
|     GGML_API int ggml_cpu_has_cann       (void); | ||||
| 
 | ||||
|     //
 | ||||
|     // Internal types and functions exposed for tests and benchmarks
 | ||||
|  |  | |||
|  | @ -770,6 +770,74 @@ if (GGML_CPU_HBM) | |||
|     target_link_libraries(ggml PUBLIC memkind) | ||||
| endif() | ||||
| 
 | ||||
| if (GGML_CANN) | ||||
|     if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME}) | ||||
|         set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME}) | ||||
|         message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}") | ||||
|     endif() | ||||
| 
 | ||||
|     if (CANN_INSTALL_DIR) | ||||
|         # Only Support Linux. | ||||
|         if (GGML_CANN) | ||||
|             if (NOT UNIX) | ||||
|                 set(GGML_CANN OFF) | ||||
|                 message(WARNING "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}. Turning off GGML_CANN") | ||||
|             endif() | ||||
|         endif() | ||||
| 
 | ||||
|         # Supported platforms: x86-64, arm64 | ||||
|         if (GGML_CANN) | ||||
|             if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") | ||||
|             elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64") | ||||
|             else() | ||||
|                 set(GGML_CANN OFF) | ||||
|                 message(WARNING "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}. Turning off GGML_CANN") | ||||
|             endif() | ||||
|         endif() | ||||
| 
 | ||||
|         # Set header and libs | ||||
|         if(GGML_CANN) | ||||
|             set(CANN_INCLUDE_DIRS | ||||
|                 ${CANN_INSTALL_DIR}/include | ||||
|                 ${CANN_INSTALL_DIR}/include/aclnn | ||||
|                 ${CANN_INSTALL_DIR}/acllib/include | ||||
|             ) | ||||
| 
 | ||||
|             # TODO: find libs | ||||
|             link_directories( | ||||
|                 ${CANN_INSTALL_DIR}/lib64 | ||||
|             ) | ||||
| 
 | ||||
|             add_subdirectory(ggml-cann/kernels) | ||||
|             list(APPEND CANN_LIBRARIES | ||||
|                 ascendcl | ||||
|                 nnopbase | ||||
|                 opapi | ||||
|                 acl_op_compiler | ||||
|                 ascendc_kernels | ||||
|             ) | ||||
| 
 | ||||
|             set(GGML_HEADERS_CANN "../include/ggml-cann.h") | ||||
|             file(GLOB GGML_SOURCES_CANN "ggml-cann/*.cpp") | ||||
|             list(APPEND GGML_SOURCES_CANN "ggml-cann.cpp") | ||||
| 
 | ||||
|             message(STATUS "CANN: CANN_INCLUDE_DIRS =  ${CANN_INCLUDE_DIRS}") | ||||
|             message(STATUS "CANN: CANN_LIBRARIES =  ${CANN_LIBRARIES}") | ||||
| 
 | ||||
|             set(GGML_EXTRA_LIBS     ${GGML_EXTRA_LIBS}     ${CANN_LIBRARIES} ) | ||||
|             set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CANN_INCLUDE_DIRS}) | ||||
|             list(APPEND GGML_CDEF_PUBLIC GGML_USE_CANN) | ||||
|         endif() | ||||
|     else() | ||||
|         set(GGML_CANN OFF) | ||||
|         message(WARNING "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh? Turning off GGML_CANN") | ||||
|     endif() | ||||
| 
 | ||||
|     if(NOT GGML_CANN) | ||||
|         message(WARNING "CANN: GGML_CANN is turned OFF, see above for details.") | ||||
|     endif() | ||||
| endif() | ||||
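Since `GGML_USE_CANN` is appended to `GGML_CDEF_PUBLIC` above, code that links against ggml sees the definition. A minimal, hypothetical application-side sketch of the resulting pattern (the same one used by the llama-bench and clip.cpp hunks earlier in this diff), with a CPU fallback:

```cpp
// Hypothetical backend selection: prefer CANN when the build enables it.
#include "ggml-backend.h"
#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif

static ggml_backend_t pick_backend(void) {
#ifdef GGML_USE_CANN
    if (ggml_backend_cann_get_device_count() > 0) {
        return ggml_backend_cann_init(0);   // first Ascend NPU
    }
#endif
    return ggml_backend_cpu_init();         // always-available fallback
}
```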
| 
 | ||||
| function(get_flags CCID CCVER) | ||||
|     set(C_FLAGS "") | ||||
|     set(CXX_FLAGS "") | ||||
|  | @ -1184,6 +1252,7 @@ add_library(ggml | |||
|             ${GGML_SOURCES_ROCM}      ${GGML_HEADERS_ROCM} | ||||
|             ${GGML_SOURCES_BLAS}      ${GGML_HEADERS_BLAS} | ||||
|             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} | ||||
|             ${GGML_SOURCES_CANN}      ${GGML_HEADERS_CANN} | ||||
|             ggml-aarch64.c            ggml-aarch64.h | ||||
|             ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -445,6 +445,11 @@ GGML_CALL static void ggml_backend_registry_init(void) { | |||
|     extern GGML_CALL void ggml_backend_kompute_reg_devices(void); | ||||
|     ggml_backend_kompute_reg_devices(); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_CANN | ||||
|     extern GGML_CALL int ggml_backend_cann_reg_devices(void); | ||||
|     ggml_backend_cann_reg_devices(); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { | ||||
|  |  | |||
							
								
								
									
ggml/src/ggml-cann.cpp (new file, 2023 lines; diff suppressed because it is too large)
							
								
								
									
ggml/src/ggml-cann/.clang-format (new file, 168 lines)
							|  | @ -0,0 +1,168 @@ | |||
| --- | ||||
| Language:        Cpp | ||||
| # BasedOnStyle:  Google | ||||
| AccessModifierOffset: -1 | ||||
| AlignAfterOpenBracket: Align | ||||
| AlignConsecutiveMacros: false | ||||
| AlignConsecutiveAssignments: false | ||||
| AlignConsecutiveDeclarations: false | ||||
| AlignEscapedNewlines: Left | ||||
| AlignOperands:   true | ||||
| AlignTrailingComments: true | ||||
| AllowAllArgumentsOnNextLine: true | ||||
| AllowAllConstructorInitializersOnNextLine: true | ||||
| AllowAllParametersOfDeclarationOnNextLine: true | ||||
| AllowShortBlocksOnASingleLine: Never | ||||
| AllowShortCaseLabelsOnASingleLine: false | ||||
| AllowShortFunctionsOnASingleLine: All | ||||
| AllowShortLambdasOnASingleLine: All | ||||
| AllowShortIfStatementsOnASingleLine: WithoutElse | ||||
| AllowShortLoopsOnASingleLine: true | ||||
| AlwaysBreakAfterDefinitionReturnType: None | ||||
| AlwaysBreakAfterReturnType: None | ||||
| AlwaysBreakBeforeMultilineStrings: true | ||||
| AlwaysBreakTemplateDeclarations: Yes | ||||
| BinPackArguments: true | ||||
| BinPackParameters: true | ||||
| BraceWrapping: | ||||
|   AfterCaseLabel:  false | ||||
|   AfterClass:      false | ||||
|   AfterControlStatement: false | ||||
|   AfterEnum:       false | ||||
|   AfterFunction:   false | ||||
|   AfterNamespace:  false | ||||
|   AfterObjCDeclaration: false | ||||
|   AfterStruct:     false | ||||
|   AfterUnion:      false | ||||
|   AfterExternBlock: false | ||||
|   BeforeCatch:     false | ||||
|   BeforeElse:      false | ||||
|   IndentBraces:    false | ||||
|   SplitEmptyFunction: true | ||||
|   SplitEmptyRecord: true | ||||
|   SplitEmptyNamespace: true | ||||
| BreakBeforeBinaryOperators: None | ||||
| BreakBeforeBraces: Attach | ||||
| BreakBeforeInheritanceComma: false | ||||
| BreakInheritanceList: BeforeColon | ||||
| BreakBeforeTernaryOperators: true | ||||
| BreakConstructorInitializersBeforeComma: false | ||||
| BreakConstructorInitializers: BeforeColon | ||||
| BreakAfterJavaFieldAnnotations: false | ||||
| BreakStringLiterals: true | ||||
| ColumnLimit:     80 | ||||
| CommentPragmas:  '^ IWYU pragma:' | ||||
| CompactNamespaces: false | ||||
| ConstructorInitializerAllOnOneLineOrOnePerLine: true | ||||
| ConstructorInitializerIndentWidth: 4 | ||||
| ContinuationIndentWidth: 4 | ||||
| Cpp11BracedListStyle: true | ||||
| DeriveLineEnding: true | ||||
| DerivePointerAlignment: true | ||||
| DisableFormat:   false | ||||
| ExperimentalAutoDetectBinPacking: false | ||||
| FixNamespaceComments: true | ||||
| ForEachMacros: | ||||
|   - foreach | ||||
|   - Q_FOREACH | ||||
|   - BOOST_FOREACH | ||||
| IncludeBlocks:   Regroup | ||||
| IncludeCategories: | ||||
|   - Regex:           '^<ext/.*\.h>' | ||||
|     Priority:        2 | ||||
|     SortPriority:    0 | ||||
|   - Regex:           '^<.*\.h>' | ||||
|     Priority:        1 | ||||
|     SortPriority:    0 | ||||
|   - Regex:           '^<.*' | ||||
|     Priority:        2 | ||||
|     SortPriority:    0 | ||||
|   - Regex:           '.*' | ||||
|     Priority:        3 | ||||
|     SortPriority:    0 | ||||
| IncludeIsMainRegex: '([-_](test|unittest))?$' | ||||
| IncludeIsMainSourceRegex: '' | ||||
| IndentCaseLabels: true | ||||
| IndentGotoLabels: true | ||||
| IndentPPDirectives: None | ||||
| IndentWidth:     4 | ||||
| IndentWrappedFunctionNames: false | ||||
| JavaScriptQuotes: Leave | ||||
| JavaScriptWrapImports: true | ||||
| KeepEmptyLinesAtTheStartOfBlocks: false | ||||
| MacroBlockBegin: '' | ||||
| MacroBlockEnd:   '' | ||||
| MaxEmptyLinesToKeep: 1 | ||||
| NamespaceIndentation: None | ||||
| ObjCBinPackProtocolList: Never | ||||
| ObjCBlockIndentWidth: 2 | ||||
| ObjCSpaceAfterProperty: false | ||||
| ObjCSpaceBeforeProtocolList: true | ||||
| PenaltyBreakAssignment: 2 | ||||
| PenaltyBreakBeforeFirstCallParameter: 1 | ||||
| PenaltyBreakComment: 300 | ||||
| PenaltyBreakFirstLessLess: 120 | ||||
| PenaltyBreakString: 1000 | ||||
| PenaltyBreakTemplateDeclaration: 10 | ||||
| PenaltyExcessCharacter: 1000000 | ||||
| PenaltyReturnTypeOnItsOwnLine: 200 | ||||
| PointerAlignment: Left | ||||
| RawStringFormats: | ||||
|   - Language:        Cpp | ||||
|     Delimiters: | ||||
|       - cc | ||||
|       - CC | ||||
|       - cpp | ||||
|       - Cpp | ||||
|       - CPP | ||||
|       - 'c++' | ||||
|       - 'C++' | ||||
|     CanonicalDelimiter: '' | ||||
|     BasedOnStyle:    google | ||||
|   - Language:        TextProto | ||||
|     Delimiters: | ||||
|       - pb | ||||
|       - PB | ||||
|       - proto | ||||
|       - PROTO | ||||
|     EnclosingFunctions: | ||||
|       - EqualsProto | ||||
|       - EquivToProto | ||||
|       - PARSE_PARTIAL_TEXT_PROTO | ||||
|       - PARSE_TEST_PROTO | ||||
|       - PARSE_TEXT_PROTO | ||||
|       - ParseTextOrDie | ||||
|       - ParseTextProtoOrDie | ||||
|     CanonicalDelimiter: '' | ||||
|     BasedOnStyle:    google | ||||
| ReflowComments:  true | ||||
| SortIncludes:    true | ||||
| SortUsingDeclarations: true | ||||
| SpaceAfterCStyleCast: false | ||||
| SpaceAfterLogicalNot: false | ||||
| SpaceAfterTemplateKeyword: true | ||||
| SpaceBeforeAssignmentOperators: true | ||||
| SpaceBeforeCpp11BracedList: false | ||||
| SpaceBeforeCtorInitializerColon: true | ||||
| SpaceBeforeInheritanceColon: true | ||||
| SpaceBeforeParens: ControlStatements | ||||
| SpaceBeforeRangeBasedForLoopColon: true | ||||
| SpaceInEmptyBlock: false | ||||
| SpaceInEmptyParentheses: false | ||||
| SpacesBeforeTrailingComments: 2 | ||||
| SpacesInAngles:  false | ||||
| SpacesInConditionalStatement: false | ||||
| SpacesInContainerLiterals: true | ||||
| SpacesInCStyleCastParentheses: false | ||||
| SpacesInParentheses: false | ||||
| SpacesInSquareBrackets: false | ||||
| SpaceBeforeSquareBrackets: false | ||||
| Standard:        Auto | ||||
| StatementMacros: | ||||
|   - Q_UNUSED | ||||
|   - QT_REQUIRE_VERSION | ||||
| TabWidth:        8 | ||||
| UseCRLF:         false | ||||
| UseTab:          Never | ||||
| ... | ||||
| 
 | ||||
							
								
								
									
ggml/src/ggml-cann/Doxyfile (new file, 2579 lines; diff suppressed because it is too large)
							
								
								
									
ggml/src/ggml-cann/acl_tensor.cpp (new file, 198 lines)
							|  | @ -0,0 +1,198 @@ | |||
| /*
 | ||||
|  * Copyright (c) 2023-2024 The ggml authors | ||||
|  * | ||||
|  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
|  * of this software and associated documentation files (the "Software"), to | ||||
|  * deal in the Software without restriction, including without limitation the | ||||
|  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||||
|  * sell copies of the Software, and to permit persons to whom the Software is | ||||
|  * furnished to do so, subject to the following conditions: | ||||
|  * | ||||
|  * The above copyright notice and this permission notice shall be included in | ||||
|  * all copies or substantial portions of the Software. | ||||
|  * | ||||
|  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
|  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
|  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
|  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
|  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
|  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||||
|  * IN THE SOFTWARE. | ||||
|  */ | ||||
| 
 | ||||
| #include "acl_tensor.h" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cstring> | ||||
| 
 | ||||
| aclDataType ggml_cann_type_mapping(ggml_type type) { | ||||
|     switch (type) { | ||||
|         case GGML_TYPE_F32: | ||||
|             return ACL_FLOAT; | ||||
|         case GGML_TYPE_F16: | ||||
|             return ACL_FLOAT16; | ||||
|         case GGML_TYPE_I8: | ||||
|             return ACL_INT8; | ||||
|         case GGML_TYPE_I16: | ||||
|             return ACL_INT16; | ||||
|         case GGML_TYPE_I32: | ||||
|             return ACL_INT32; | ||||
|         default: | ||||
|             return ACL_DT_UNDEFINED; | ||||
|     } | ||||
|     return ACL_DT_UNDEFINED; | ||||
| } | ||||
| 
 | ||||
| aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, | ||||
|                                    size_t* nb, int64_t dims, aclFormat format, | ||||
|                                    size_t offset) { | ||||
|     // If tensor is bcasted, up to GGML_MAX_DIMS additional dimensions will be
 | ||||
|     // added.
 | ||||
|     int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2]; | ||||
| 
 | ||||
|     int64_t acl_storage_len = 0; | ||||
|     if (ne == nullptr) { | ||||
|         acl_storage_len = ggml_nbytes(tensor); | ||||
|         for (int i = 0; i < GGML_MAX_DIMS; i++) { | ||||
|             acl_ne[i] = tensor->ne[i]; | ||||
|             // The step size of acl is in elements.
 | ||||
|             acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor); | ||||
|         } | ||||
|     } else { | ||||
|         // With bcast
 | ||||
|         for (int i = 0; i < dims; i++) { | ||||
|             acl_storage_len += (ne[i] - 1) * nb[i]; | ||||
|             acl_ne[i] = ne[i]; | ||||
|             acl_stride[i] = nb[i] / ggml_element_size(tensor); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // Reverse ne and stride.
 | ||||
|     int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims); | ||||
|     std::reverse(acl_ne, acl_ne + final_dims); | ||||
|     std::reverse(acl_stride, acl_stride + final_dims); | ||||
| 
 | ||||
|     aclTensor* acl_tensor = aclCreateTensor( | ||||
|         acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, | ||||
|         offset / ggml_element_size(tensor), format, &acl_storage_len, 1, | ||||
|         tensor->data); | ||||
| 
 | ||||
|     return acl_tensor; | ||||
| } | ||||
| 
 | ||||
| bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) { | ||||
|     for (int i = 0; i < GGML_MAX_DIMS; i++) { | ||||
|         if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, | ||||
|                                    size_t type_size, int64_t* ne, size_t* nb, | ||||
|                                    int64_t dims, aclFormat format, | ||||
|                                    size_t offset) { | ||||
|     int64_t tmp_ne[GGML_MAX_DIMS * 2]; | ||||
|     int64_t tmp_stride[GGML_MAX_DIMS * 2]; | ||||
| 
 | ||||
|     memcpy(tmp_ne, ne, dims * sizeof(int64_t)); | ||||
|     for (int i = 0; i < dims; i++) { | ||||
|         tmp_stride[i] = nb[i] / type_size; | ||||
|     } | ||||
| 
 | ||||
|     std::reverse(tmp_ne, tmp_ne + dims); | ||||
|     std::reverse(tmp_stride, tmp_stride + dims); | ||||
| 
 | ||||
|     int64_t acl_storage_len = 0; | ||||
|     for (int i = 0; i < dims; i++) { | ||||
|         acl_storage_len += (ne[i] - 1) * nb[i]; | ||||
|     } | ||||
| 
 | ||||
|     aclTensor* acl_tensor = | ||||
|         aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, | ||||
|                         format, &acl_storage_len, 1, data_ptr); | ||||
| 
 | ||||
|     return acl_tensor; | ||||
| } | ||||
| 
 | ||||
| int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, | ||||
|                                   const ggml_tensor* src1, | ||||
|                                   int64_t* bcast_src0_ne, | ||||
|                                   int64_t* bcast_src1_ne, size_t* bcast_src0_nb, | ||||
|                                   size_t* bcast_src1_nb) { | ||||
|     GGML_ASSERT(ggml_can_repeat(src1, src0)); | ||||
|     int bcast_dim_cnt = 0; | ||||
|     for (int i = 0; i < GGML_MAX_DIMS; i++) { | ||||
|         int64_t nr = src0->ne[i] / src1->ne[i]; | ||||
|         bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr; | ||||
|         bcast_src1_ne[bcast_dim_cnt] = src1->ne[i]; | ||||
|         bcast_src0_nb[bcast_dim_cnt] = src0->nb[i]; | ||||
|         bcast_src1_nb[bcast_dim_cnt] = src1->nb[i]; | ||||
|         bcast_dim_cnt++; | ||||
|         if (nr != 1) { | ||||
|             // Need to add an extra dim.
 | ||||
|             bcast_src0_ne[bcast_dim_cnt] = nr; | ||||
|             bcast_src1_ne[bcast_dim_cnt] = 1; | ||||
|             bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * | ||||
|                                            bcast_src0_ne[bcast_dim_cnt - 1]; | ||||
|             bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * | ||||
|                                            bcast_src1_ne[bcast_dim_cnt - 1]; | ||||
|             bcast_dim_cnt++; | ||||
|         } | ||||
|     } | ||||
|     return bcast_dim_cnt; | ||||
| } | ||||
| 
 | ||||
| int64_t ggml_cann_get_mulmat_bcast_shape( | ||||
|     const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne, | ||||
|     const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb, | ||||
|     int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne, | ||||
|     size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) { | ||||
|     // input and dst should be in the same shape, except the first two dims.
 | ||||
|     GGML_ASSERT(input_ne[2] == dst_ne[2]); | ||||
|     GGML_ASSERT(input_ne[3] == dst_ne[3]); | ||||
| 
 | ||||
|     int bcast_dim_cnt = 0; | ||||
| 
 | ||||
|     // For mul_mat, a dimension needs to be added before the dimension that
 | ||||
|     // weight needs to be expanded to satisfy the bcast rule of matrix
 | ||||
|     // multiplication.
 | ||||
|     for (int i = 0; i < GGML_MAX_DIMS; i++) { | ||||
|         int64_t nr = input_ne[i] / weight_ne[i]; | ||||
|         // Do not use bcast in the first two dimensions because we only support
 | ||||
|         // the bcast batch dimension. Just copy them.
 | ||||
|         if (i < 2 || nr == 1) { | ||||
|             bcast_input_ne[bcast_dim_cnt] = input_ne[i]; | ||||
|             bcast_weight_ne[bcast_dim_cnt] = weight_ne[i]; | ||||
|             bcast_dst_ne[bcast_dim_cnt] = dst_ne[i]; | ||||
| 
 | ||||
|             bcast_input_nb[bcast_dim_cnt] = input_nb[i]; | ||||
|             bcast_weight_nb[bcast_dim_cnt] = weight_nb[i]; | ||||
|             bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; | ||||
|             bcast_dim_cnt++; | ||||
|         } else { | ||||
|             // Need to add an extra dim.
 | ||||
|             bcast_input_ne[bcast_dim_cnt] = nr; | ||||
|             bcast_dst_ne[bcast_dim_cnt] = nr; | ||||
|             bcast_weight_ne[bcast_dim_cnt] = 1; | ||||
|             bcast_input_nb[bcast_dim_cnt] = input_nb[i]; | ||||
|             bcast_dst_nb[bcast_dim_cnt] = dst_nb[i]; | ||||
|             bcast_weight_nb[bcast_dim_cnt] = weight_nb[i]; | ||||
|             bcast_dim_cnt++; | ||||
| 
 | ||||
|             bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr; | ||||
|             bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr; | ||||
|             bcast_weight_ne[bcast_dim_cnt] = weight_ne[i]; | ||||
|             bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * | ||||
|                                             bcast_input_ne[bcast_dim_cnt - 1]; | ||||
|             bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * | ||||
|                                           bcast_dst_ne[bcast_dim_cnt - 1]; | ||||
|             bcast_weight_nb[bcast_dim_cnt] = | ||||
|                 bcast_weight_nb[bcast_dim_cnt - 1] * | ||||
|                 bcast_weight_ne[bcast_dim_cnt - 1]; | ||||
|             bcast_dim_cnt++; | ||||
|         } | ||||
|     } | ||||
|     return bcast_dim_cnt; | ||||
| } | ||||
							
								
								
									
ggml/src/ggml-cann/acl_tensor.h (new file, 230 lines)
							|  | @ -0,0 +1,230 @@ | |||
| /*
 | ||||
|  * Copyright (c) 2023-2024 The ggml authors | ||||
|  * | ||||
|  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
|  * of this software and associated documentation files (the "Software"), to | ||||
|  * deal in the Software without restriction, including without limitation the | ||||
|  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||||
|  * sell copies of the Software, and to permit persons to whom the Software is | ||||
|  * furnished to do so, subject to the following conditions: | ||||
|  * | ||||
|  * The above copyright notice and this permission notice shall be included in | ||||
|  * all copies or substantial portions of the Software. | ||||
|  * | ||||
|  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
|  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
|  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
|  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
|  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
|  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||||
|  * IN THE SOFTWARE. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef CANN_ACL_TENSOR_H | ||||
| #define CANN_ACL_TENSOR_H | ||||
| 
 | ||||
| #include <aclnn/aclnn_base.h> | ||||
| #include "common.h" | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief	Maps a ggml_type to its corresponding aclDataType. | ||||
|  * | ||||
|  * @details	This function takes a ggml_type as input and returns the corresponding | ||||
|  *			aclDataType. It supports mapping for various ggml_types. If the input type | ||||
|  *			does not match any of the predefined ggml_types, the function returns | ||||
|  *          ACL_DT_UNDEFINED. | ||||
|  * | ||||
|  * @param	type    The ggml_type to be mapped. | ||||
|  * @return	The corresponding aclDataType. If the input type is not recognized, | ||||
|  *			ACL_DT_UNDEFINED is returned. | ||||
|  */ | ||||
| aclDataType ggml_cann_type_mapping(ggml_type type); | ||||
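An illustrative wrapper around the mapping, mirroring the switch in acl_tensor.cpp shown earlier in this diff; the helper name is hypothetical.

```cpp
// Illustrative only: quantized ggml types are not directly representable as ACL types.
#include "acl_tensor.h"

static bool cann_supports_type(ggml_type type) {
    // e.g. GGML_TYPE_F16  -> ACL_FLOAT16
    //      GGML_TYPE_Q4_0 -> ACL_DT_UNDEFINED
    return ggml_cann_type_mapping(type) != ACL_DT_UNDEFINED;
}
```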
| 
 | ||||
| /**
 | ||||
|  * @brief   Creates an ACL tensor from a ggml_tensor with optional shape. | ||||
|  * | ||||
|  * @details This function creates an ACL tensor based on the properties of the | ||||
|  *          provided ggml_tensor. It supports custom shapes by adjusting dimensions | ||||
|  *          and strides accordingly. If a custom shape is applied, additional | ||||
|  *          dimensions and strides are calculated based on the provided parameters. | ||||
|  * | ||||
|  * @param   tensor      Pointer to the ggml_tensor to be converted to an ACL tensor. | ||||
|  * @param   ne          Pointer to an array containing dimensions. Defaults to nullptr | ||||
|  *                      if no custom shape is applied. | ||||
|  * @param   nb          Pointer to an array containing strides. Defaults to nullptr | ||||
|  *                      if no custom shape is applied. | ||||
|  * @param   dims        Number of dimensions in the tensor. Defaults to 0 if no custom | ||||
|  *                      shape is applied. | ||||
|  * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND. | ||||
|  * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0. | ||||
|  * @return  Pointer to the created ACL tensor. | ||||
|  */ | ||||
| aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr, | ||||
|                              size_t* nb = nullptr, int64_t dims = 0, | ||||
|                              aclFormat format = ACL_FORMAT_ND, | ||||
|                              size_t offset = 0); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Creates an ACL tensor from provided parameters. | ||||
|  * | ||||
|  * @details This function creates an ACL tensor using the provided data pointer, | ||||
|  *          data type, dimensions, strides, format, offset, and additional parameters. | ||||
|  *          It calculates necessary dimensions and strides based on the provided ne and nb | ||||
|  *          arrays, adjusting them for the ACL tensor creation. The ACL storage length | ||||
|  *          is also calculated based on the provided dimensions and strides. | ||||
|  * | ||||
|  * @param   data_ptr    Pointer to the data buffer for the ACL tensor. | ||||
|  * @param   dtype       ACL data type of the tensor. | ||||
|  * @param   type_size   Size of each element in the tensor data buffer. | ||||
|  * @param   ne          Pointer to an array containing tensor dimensions. | ||||
|  * @param   nb          Pointer to an array containing tensor strides. | ||||
|  * @param   dims        Number of dimensions of the tensor. | ||||
|  * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND. | ||||
|  * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0. | ||||
|  * @return  Pointer to the created ACL tensor. | ||||
|  */ | ||||
| aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, | ||||
|                              size_t type_size, int64_t* ne, size_t* nb, | ||||
|                              int64_t dims, aclFormat format = ACL_FORMAT_ND, | ||||
|                              size_t offset = 0); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Checks if tensors require broadcasting based on their shapes. | ||||
|  * | ||||
|  * @details This function determines if two ggml_tensors need to be broadcasted for | ||||
|  *          element-wise operations. Broadcasting is needed when a dimension of t1 | ||||
|  *          differs from the corresponding dimension of t0 and is not equal to 1. | ||||
|  * | ||||
|  * @param   t0      Pointer to the first ggml_tensor. | ||||
|  * @param   t1      Pointer to the second ggml_tensor. | ||||
|  * @return  True if broadcasting is needed, False otherwise. | ||||
|  * | ||||
|  * @remarks This function iterates over the dimensions of t0 and t1. It checks if each | ||||
|  *          dimension in t1 differs from t0's corresponding dimension and is not equal | ||||
|  *          to 1. If such a dimension is found, broadcasting is required to align t1 | ||||
|  *          with t0 for element-wise operations. | ||||
|  */ | ||||
| bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1); | ||||
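A short sketch of when this check fires, using a throwaway ggml context only to build shapes (`ggml_init`, `ggml_new_tensor_2d` and `ggml_free` are the regular ggml API); the shape values are made up for illustration.

```cpp
// Hypothetical sketch: the explicit bcast path is only needed when a dimension
// of t1 differs from t0 and is not 1 (size-1 dims broadcast natively in ACL).
#include "acl_tensor.h"
#include "ggml.h"

static void need_bcast_demo(void) {
    struct ggml_init_params params = { 16 * 1024 * 1024, nullptr, false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * t0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);
    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);  // 6 % 3 == 0
    struct ggml_tensor * b  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1);  // size-1 dim

    bool ra = ggml_cann_need_bcast(t0, a);  // true:  dim 1 differs and is not 1
    bool rb = ggml_cann_need_bcast(t0, b);  // false: size-1 dim, no extra handling
    (void) ra; (void) rb;

    ggml_free(ctx);
}
```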
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes broadcast shapes and strides for two ggml_tensors. | ||||
|  * | ||||
|  * @details This function calculates the broadcast shapes and strides for two ggml_tensors, | ||||
|  *          following the broadcasting rules similar to numpy. It adjusts dimensions and | ||||
|  *          strides to ensure compatibility for element-wise operations where one tensor | ||||
|  *          can be broadcasted to match the shape of another tensor. | ||||
|  * | ||||
|  * @param   src0                Pointer to the first ggml_tensor. | ||||
|  * @param   src1                Pointer to the second ggml_tensor. | ||||
|  * @param   bcast_ne_src0       Output array to store broadcasted dimensions for src0. | ||||
|  * @param   bcast_ne_src1       Output array to store broadcasted dimensions for src1. | ||||
|  * @param   bcast_nb_src0       Output array to store broadcasted strides for src0. | ||||
|  * @param   bcast_nb_src1       Output array to store broadcasted strides for src1. | ||||
|  * @return  Number of dimensions in the broadcasted shape. | ||||
|  * | ||||
|  * @pre     ggml_can_repeat(src1, src0) must return true, indicating src1 can be broadcasted | ||||
|  *          to match src0. | ||||
|  * | ||||
|  * @remarks This function iterates over the dimensions of src0 and src1, calculating the | ||||
|  *          necessary broadcast dimensions and strides. If a dimension requires broadcasting | ||||
|  *          (i.e., its size in src1 is smaller than in src0), an additional dimension is | ||||
|  *          added with size calculated to match src0's dimension. This adjustment ensures | ||||
|  *          that src1 can be element-wise broadcasted to src0's shape. | ||||
|  * | ||||
|  *  How it works: | ||||
|  * | ||||
|  *  \code | ||||
|  *  if dim0 has padding. | ||||
|  *  a -> (2, 2) padding = 2 | ||||
|  *   a: [[1, 2, *, *] | ||||
|  *       [2, 3, *, *]] | ||||
|  *  nb = (8, 4, 2) | ||||
|  * | ||||
|  *  if a should bcast with b -> (2, 4) | ||||
|  *  b' -> (2, 2, 2) | ||||
|  *  b : [[1, 2, 3, 4, *, *] | ||||
|  *       [5, 6, 7, 8, *, *]] | ||||
|  *  nb = (12, 6, 1) | ||||
|  * | ||||
|  *  after bcast: | ||||
|  *  a' -> (2, 1, 2) | ||||
|  *  a': [[[1, 2], *, *] | ||||
|  *       [[2, 3], *, *]] | ||||
|  *  nb = (8, 4, 2, 1) | ||||
|  * | ||||
|  *  b' : [[[1, 2], [3, 4], *, *] | ||||
|  *        [[5, 6], [7, 8], *, *]] | ||||
|  *  nb = (12, 6, 2, 1) | ||||
|  *  \endcode | ||||
|  * | ||||
|  *  dim1 in a is an inserted dim; an nb entry should be added for dim1, | ||||
|  *  and all the other nb entries move to the next position in order. | ||||
|  */ | ||||
| int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1, | ||||
|                         int64_t* bcast_ne_src0, int64_t* bcast_ne_src1, | ||||
|                         size_t* bcast_nb_src0, size_t* bcast_nb_src1); | ||||
| 
 | ||||
| // Bcast macro to avoid duplicate code.
 | ||||
| #define BCAST_SHAPE(src0, src1)                                              \ | ||||
|     int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2];                            \ | ||||
|     int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2];                            \ | ||||
|     size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2];                             \ | ||||
|     size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2];                             \ | ||||
|     int64_t bcast_dims = ggml_cann_get_bcast_shape(                          \ | ||||
|         src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \ | ||||
|         bcast_##src1##_nb); | ||||
| 
 | ||||
| #define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims | ||||
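A hedged sketch of how BCAST_SHAPE and BCAST_PARAM are meant to be combined with `ggml_cann_create_tensor` from this header. The kernel launch and aclTensor cleanup are omitted, the helper name is hypothetical, and the parameter names must literally be `src0`/`src1` so the token pasting lines up.

```cpp
// Hypothetical sketch: build broadcast-compatible aclTensor views for an
// element-wise op. BCAST_SHAPE declares bcast_src0_ne/_nb, bcast_src1_ne/_nb
// and bcast_dims; BCAST_PARAM forwards them to ggml_cann_create_tensor.
#include "acl_tensor.h"

static void make_bcast_acl_tensors(ggml_tensor * src0, ggml_tensor * src1) {
    aclTensor * acl_src0 = nullptr;
    aclTensor * acl_src1 = nullptr;

    if (ggml_cann_need_bcast(src0, src1)) {
        BCAST_SHAPE(src0, src1)
        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
        acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
    } else {
        acl_src0 = ggml_cann_create_tensor(src0);
        acl_src1 = ggml_cann_create_tensor(src1);
    }

    // ... launch an aclnn element-wise kernel with acl_src0/acl_src1,
    //     then destroy the aclTensor handles ...
    (void) acl_src0;
    (void) acl_src1;
}
```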
| 
 | ||||
| /**
 | ||||
|  * @brief Calculates broadcast shapes for matrix multiplication. | ||||
|  * | ||||
|  * @details This function computes the broadcast shapes required for matrix multiplication | ||||
|  *          based on the input, weight, and destination tensor shapes. It ensures that the | ||||
|  *          dimensions of weight tensors are expanded appropriately to satisfy matrix | ||||
|  *          multiplication broadcast rules. | ||||
|  * | ||||
|  * @param input_ne      Array containing the dimensions of the input tensor. | ||||
|  * @param weight_ne     Array containing the dimensions of the weight tensor. | ||||
|  * @param dst_ne        Array containing the dimensions of the destination tensor. | ||||
|  * @param input_nb      Array containing the strides of the input tensor. | ||||
|  * @param weight_nb     Array containing the strides of the weight tensor. | ||||
|  * @param dst_nb        Array containing the strides of the destination tensor. | ||||
|  * @param bcast_input_ne    Output array for broadcasted input tensor dimensions. | ||||
|  * @param bcast_weight_ne   Output array for broadcasted weight tensor dimensions. | ||||
|  * @param bcast_dst_ne      Output array for broadcasted destination tensor dimensions. | ||||
|  * @param bcast_input_nb    Output array for broadcasted input tensor strides. | ||||
|  * @param bcast_weight_nb   Output array for broadcasted weight tensor strides. | ||||
|  * @param bcast_dst_nb      Output array for broadcasted destination tensor strides. | ||||
|  * @return The number of dimensions in the broadcasted tensors. | ||||
|  * | ||||
|  * @remarks This function iterates over the tensor dimensions and calculates the broadcast | ||||
|  *          shapes needed for matrix multiplication. It ensures that dimensions where | ||||
|  *          weight tensor requires expansion are appropriately handled to conform with | ||||
|  *          broadcasting rules. | ||||
|  * @note Compared with ggml_cann_get_bcast_shape, mul_mat broadcasting needs to add the | ||||
|  *       new dim before the broadcast dim. | ||||
|  * @sa ggml_cann_get_bcast_shape | ||||
|  */ | ||||
| int64_t ggml_cann_get_mulmat_bcast_shape( | ||||
|     const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne, | ||||
|     const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb, | ||||
|     int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne, | ||||
|     size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb); | ||||
| 
 | ||||
| // Bcast macro to avoid duplicate code.
 | ||||
| #define BCAST_MUL_MAT_SHAPE(input, weight, dst)                         \ | ||||
|     int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2];                      \ | ||||
|     int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2];                     \ | ||||
|     int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2];                        \ | ||||
|     size_t bcast_##input##_nb[GGML_MAX_DIMS * 2];                       \ | ||||
|     size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2];                      \ | ||||
|     size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2];                         \ | ||||
|     int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape(              \ | ||||
|         input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \ | ||||
|         bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne,      \ | ||||
|         bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb); | ||||
| 
 | ||||
| #define BCAST_MUL_MAT_PARAM(tensor) \ | ||||
|     bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims | ||||
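And the matching sketch for the mul_mat variant. Again the argument names must be exactly `input`, `weight` and `dst` so the pasted variable names agree with BCAST_MUL_MAT_PARAM; whether the real backend wires it up exactly this way is not visible here, since ggml-cann.cpp and aclnn_ops.cpp are suppressed in this diff.

```cpp
// Hypothetical sketch: prepare broadcast shapes for a batched mul_mat.
#include "acl_tensor.h"

static void make_mulmat_acl_tensors(ggml_tensor * input, ggml_tensor * weight,
                                    ggml_tensor * dst) {
    BCAST_MUL_MAT_SHAPE(input, weight, dst)

    aclTensor * acl_input  = ggml_cann_create_tensor(input,  BCAST_MUL_MAT_PARAM(input));
    aclTensor * acl_weight = ggml_cann_create_tensor(weight, BCAST_MUL_MAT_PARAM(weight));
    aclTensor * acl_dst    = ggml_cann_create_tensor(dst,    BCAST_MUL_MAT_PARAM(dst));

    // ... call the aclnn batched matrix multiplication on these views ...
    (void) acl_input; (void) acl_weight; (void) acl_dst;
}
```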
| 
 | ||||
| #endif  // CANN_ACL_TENSOR_H
 | ||||
							
								
								
									
ggml/src/ggml-cann/aclnn_ops.cpp (new file, 2944 lines; diff suppressed because it is too large)
							
								
								
									
ggml/src/ggml-cann/aclnn_ops.h (new file, 592 lines)
							|  | @ -0,0 +1,592 @@ | |||
| #ifndef CANN_ACLNN_OPS | ||||
| #define CANN_ACLNN_OPS | ||||
| 
 | ||||
| /**
 | ||||
|  * @file    aclnn_ops.h | ||||
|  * @brief   This file contains declarations of the CANN (aclnn) operator functions | ||||
|  *          used to execute ggml operations (add, repeat, concat, norm, argsort, | ||||
|  *          etc.) on Ascend NPUs through the CANN backend. | ||||
|  * @author  hipudding <huafengchun@gmail.com> | ||||
|  * @author  wangshuai09 <391746016@qq.com> | ||||
|  * @date    July 15, 2024 | ||||
|  * | ||||
|  * Copyright (c) 2023-2024 The ggml authors | ||||
|  * | ||||
|  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
|  * of this software and associated documentation files (the "Software"), to | ||||
|  * deal in the Software without restriction, including without limitation the | ||||
|  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||||
|  * sell copies of the Software, and to permit persons to whom the Software is | ||||
|  * furnished to do so, subject to the following conditions: | ||||
|  * | ||||
|  * The above copyright notice and this permission notice shall be included in | ||||
|  * all copies or substantial portions of the Software. | ||||
|  * | ||||
|  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
|  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
|  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
|  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
|  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
|  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||||
|  * IN THE SOFTWARE. | ||||
|  */ | ||||
| 
 | ||||
| #include <aclnnop/aclnn_add.h> | ||||
| #include <aclnnop/aclnn_arange.h> | ||||
| #include <aclnnop/aclnn_argsort.h> | ||||
| #include <aclnnop/aclnn_cat.h> | ||||
| #include <aclnnop/aclnn_clamp.h> | ||||
| #include <aclnnop/aclnn_div.h> | ||||
| #include <aclnnop/aclnn_gelu.h> | ||||
| #include <aclnnop/aclnn_hardsigmoid.h> | ||||
| #include <aclnnop/aclnn_hardswish.h> | ||||
| #include <aclnnop/aclnn_leaky_relu.h> | ||||
| #include <aclnnop/aclnn_mul.h> | ||||
| #include <aclnnop/aclnn_relu.h> | ||||
| #include <aclnnop/aclnn_silu.h> | ||||
| #include <aclnnop/aclnn_tanh.h> | ||||
| #include "acl_tensor.h" | ||||
| #include "common.h" | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Repeats a ggml tensor along each dimension to match the dimensions | ||||
|  *          of another tensor. | ||||
|  * | ||||
|  * @details This function repeats the elements of a source ggml tensor along | ||||
|  *          each dimension to create a destination tensor with the specified | ||||
|  *          dimensions. The operation is performed using the ACL backend and | ||||
|  *          executed asynchronously on the device. | ||||
|  * | ||||
|  * @param   ctx The CANN context used for operations. | ||||
|  * @param   dst The ggml tensor representing the destination, whose op is | ||||
|  *              GGML_OP_REPEAT and which specifies the desired dimensions. | ||||
|  */ | ||||
| void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Adds two ggml tensors using the CANN backend. | ||||
|  * | ||||
|  * @details This function performs an element-wise addition of two tensors. In | ||||
|  *          case the tensors do not have the same shape, one or both tensors | ||||
|  *          will be broadcasted to match the shape of the other before the | ||||
|  *          addition is performed. The formula for the operation is given by: | ||||
|  *          \f[ | ||||
|  *              \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1} | ||||
|  *          \f] | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The ggml tensor representing the destination, result of the | ||||
|  *            addition is stored at dst->data, and dst->op is `GGML_OP_ADD` | ||||
|  */ | ||||
| void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Applies the Leaky ReLU activation function to a tensor using the CANN | ||||
|  *          backend. | ||||
|  * | ||||
|  * @details This function computes the Leaky ReLU activation for each element of | ||||
|  *          the input tensor. The Leaky ReLU function allows a small gradient | ||||
|  *          when the unit is not active (i.e., when the input is negative). The | ||||
|  *          Leaky ReLU function is defined as: | ||||
|  *          \f[ | ||||
|  *              \text{dst} = \max(0, src) + \text{negativeSlope} \cdot \min(0, | ||||
|  *               src) | ||||
|  *          \f] | ||||
|  *          `negativeSlope` is in dst->params. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the result of the Leaky ReLU | ||||
|  *            activation is stored, which op is `GGML_OP_LEAKY_RELU` | ||||
|  */ | ||||
| void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
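All functions in this header share the `(ggml_backend_cann_context&, ggml_tensor*)` signature, which suggests a per-op dispatch inside the backend's graph compute. The real dispatcher lives in ggml-cann.cpp (suppressed in this diff), and `ggml_backend_cann_context` comes from common.h (also not shown), so the sketch below is only an assumed illustration of the pattern.

```cpp
// Illustrative dispatch over a few of the operators declared in this header;
// the function name and op coverage are assumptions, not the actual backend code.
#include "aclnn_ops.h"

static bool cann_compute_forward(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    switch (dst->op) {
        case GGML_OP_REPEAT:     ggml_cann_repeat(ctx, dst);     break;
        case GGML_OP_ADD:        ggml_cann_add(ctx, dst);        break;
        case GGML_OP_LEAKY_RELU: ggml_cann_leaky_relu(ctx, dst); break;
        default:
            return false;   // op not handled here
    }
    return true;
}
```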
| 
 | ||||
| /**
 | ||||
|  * @brief    Concatenates multiple tensors along a specified dimension using the | ||||
|  *           CANN backend. | ||||
|  * | ||||
|  * @param ctx        The CANN context used for operations. | ||||
|  * @param tensorList A pointer to the list of tensors to be concatenated. | ||||
|  * @param dst        The destination tensor where the result of the | ||||
|  *                   concatenation is stored. dst->op is `GGML_OP_CONCAT`. | ||||
|  * @param concat_dim The dimension along which the tensors are concatenated. | ||||
|  * | ||||
|  * @attention tensorList length should be 2 and the dimension used for concat | ||||
|  *            defaults to 1. | ||||
|  */ | ||||
| void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Generates a sequence of evenly spaced values within a specified | ||||
|  *          interval for a ggml tensor using the CANN backend. | ||||
|  * | ||||
|  * @details This function creates a sequence of numbers over a specified | ||||
|  *          interval, starting from `start`, ending before `stop`, and | ||||
|  *          incrementing by `step`. The sequence is stored in the destination | ||||
|  *          tensor `dst`. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the generated sequence will be stored. | ||||
|  *            `start`, `stop` and `step` are in dst->op_params and dst->op is | ||||
|  *            `GGML_OP_ARANGE`. | ||||
|  */ | ||||
| void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the square of the elements of a ggml tensor using the CANN | ||||
|  *          backend. | ||||
|  * @details The function sets the second source tensor of the destination | ||||
|  *          tensor `dst` to be equal to the first source tensor. This is | ||||
|  *          effectively squaring the elements since the multiplication becomes | ||||
|  *          `element * element`. | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the squared values will be stored, | ||||
|  *            which dst->op is `GGML_OP_SQR`. | ||||
|  */ | ||||
| void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Applies a clamp operation to the elements of a ggml tensor using the | ||||
|  *          CANN backend. | ||||
|  * | ||||
|  * @details This function clamps the elements of the input tensor `src` to a | ||||
|  *          specified range defined by `min` and `max` values. The result is | ||||
|  *          stored in the destination tensor `dst`. The operation is defined as: | ||||
|  *          \f[ | ||||
|  *              y = \max(\min(x, max\_value), min\_value) | ||||
|  *           \f] | ||||
|  *          where `x` is an element of the input tensor, and `y` is the | ||||
|  *          corresponding element in the output tensor. | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the clamped values will be stored. | ||||
|  *            dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params. | ||||
|  */ | ||||
| void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Scales the elements of a ggml tensor by a constant factor using the | ||||
|  *          CANN backend. | ||||
|  * | ||||
|  * @details This function multiplies each element of the input tensor `src` by | ||||
|  *          a scaling factor `scale`, storing the result in the destination | ||||
|  *          tensor `dst`. The operation is defined as: | ||||
|  *          \f[ | ||||
|  *             dst = src \times scale | ||||
|  *          \f] | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the scaled values will be stored. | ||||
|  *            dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params. | ||||
|  */ | ||||
| void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Sorts the elements of a ggml tensor and returns the indices that | ||||
|  *          would sort the tensor using the CANN backend. | ||||
|  * | ||||
|  * @details This function performs an argsort operation on the input tensor | ||||
|  *          `src`. It sorts the elements of `src` in either ascending or | ||||
 *          descending order, depending on the sort order specified in
 *          dst->op_params (`GGML_SORT_ORDER_ASC` or `GGML_SORT_ORDER_DESC`),
|  *          and returns the indices that would sort the original tensor. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the sorted indices will be stored. | ||||
|  *            dst->op is `GGML_OP_ARGSORT`. | ||||
|  */ | ||||
| void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the Layer Normalization for a ggml tensor using the CANN | ||||
|  *          backend. | ||||
|  * | ||||
|  * @details This function applies the Layer Normalization operation on the | ||||
|  *          input tensor `src` and stores the result in the destination tensor | ||||
|  *          `dst`. Layer Normalization normalizes the features at each sample in | ||||
|  *          a mini-batch independently. It is commonly used in neural networks | ||||
|  *          to normalize the activations of a layer by adjusting and scaling | ||||
|  *          the outputs. | ||||
|  *          The operation is defined as: | ||||
|  *          \f[ | ||||
|  *              \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}} | ||||
|  *          \f] | ||||
 *          `Var` is computed over dst->ne[0] elements. `eps` is in
 *          dst->op_params.
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the normalized values will be stored. | ||||
 * @attention `Var` is computed over dst->ne[0] elements.
|  */ | ||||
| void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief  Computes the Group Normalization for a ggml tensor using the CANN | ||||
|  *         backend. | ||||
|  * | ||||
 * @details This function applies the Group Normalization operation on the input
|  *         tensor `src` and stores the result in the destination tensor `dst`. | ||||
|  *         Group Normalization divides the channels into groups and normalizes | ||||
|  *         the features within each group across spatial locations. | ||||
|  *         It is commonly used in convolutional neural networks to improve | ||||
|  *         training stability and performance. | ||||
|  *         The operation is defined as: | ||||
|  *         \f[ | ||||
|  *             \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}} | ||||
|  *         \f] | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the normalized values will be stored. | ||||
 *            `n_groups` is in dst->op_params and splits the channel dimension
 *            into `n_groups` groups.
|  *            dst->op is `GGML_OP_GROUP_NORM`. | ||||
|  * | ||||
|  * @attention eps defaults to 1e-6f. | ||||
|  */ | ||||
| void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the accumulation of tensors using the CANN backend. | ||||
|  * | ||||
|  * @details This function performs an accumulation operation on two tensors. | ||||
|  *          Depending on the `inplace` flag, it either updates the destination | ||||
|  *          tensor `dst` in place by adding `alpha * src1` to it, or it creates | ||||
|  *          a new tensor as the result of `src0 + alpha * src1` and stores it in | ||||
|  *          `dst`. | ||||
|  *          The operation is defined as: | ||||
|  *          \f[ | ||||
|  *               dst = src0 + alpha \times src1 | ||||
|  *          \f] | ||||
 *          If `inplace` is `true`, `src0` is equal to `dst`.
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the accumulated values will be stored. | ||||
 *            `inplace` is in dst->op_params, and dst->op is `GGML_OP_ACC`.
|  */ | ||||
| void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the sum of elements along the last dimension of a ggml tensor | ||||
|  *          using the CANN backend. | ||||
|  * | ||||
|  * @details This function performs a reduction sum operation along the last | ||||
|  *          dimension of the input tensor `src`. The result of the sum is stored | ||||
|  *          in the destination tensor `dst`. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
 * @param dst The destination tensor where the reduced values will be stored.
|  *            dst->op is `GGML_OP_SUM_ROWS`. | ||||
|  * | ||||
|  * @attention `reduce_dims` defaults to 3, which means the last dimension. | ||||
|  */ | ||||
| void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Upsamples a ggml tensor using nearest neighbor interpolation using | ||||
|  *          the CANN backend. | ||||
|  * | ||||
|  * @details This function performs upsampling of the input tensor `src` using | ||||
|  *          nearest neighbor interpolation. The upsampling is applied to the | ||||
|  *          height and width dimensions (last two dimensions) of the tensor. The | ||||
|  *          result is stored in the destination tensor `dst`, which must have | ||||
|  *          the appropriate dimensions for the upsampled output. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the upsampled values will be stored. | ||||
|  *            dst->op is `GGML_OP_UPSCALE`. | ||||
|  */ | ||||
| void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, | ||||
|                                   ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Pads a ggml tensor to match the dimensions of the destination tensor | ||||
|  *          using the CANN backend. | ||||
|  * | ||||
|  * @details This function pads the input tensor `src` so that it matches the | ||||
|  *          dimensions of the destination tensor `dst`. The amount of padding | ||||
|  *          is calculated based on the difference in sizes between `src` and | ||||
|  *          `dst` along each dimension. The padded tensor is stored in `dst`. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor, which specifies the target dimensions for | ||||
|  *            padding. dst->op is `GGML_OP_PAD`. | ||||
|  */ | ||||
| void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Executes a 2D pooling operation on a ggml tensor using the CANN | ||||
|  *          backend. | ||||
|  * | ||||
|  * @details This function dispatches the execution of a 2D pooling operation on | ||||
|  *          the input tensor `dst`. The type of pooling (average or max) is | ||||
|  *          determined by the `op` parameter, which is read from the operation | ||||
|  *          parameters of `dst`. The function supports average pooling | ||||
|  *          (`GGML_OP_POOL_AVG`) and max pooling (`GGML_OP_POOL_MAX`). If an | ||||
|  *          invalid operation is encountered, the function asserts a failure. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor on which the pooling operation is to be | ||||
|  *            performed. dst->op is `GGML_OP_POOL_2D`. | ||||
|  */ | ||||
| void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Duplicates a ggml tensor using the CANN backend. | ||||
|  * | ||||
|  * @details This function duplicates the contents of the source tensor `src` to | ||||
|  *          the destination tensor `dst`. The function supports various tensor | ||||
|  *          types and configurations, including handling of extra data, type | ||||
|  *          conversions, and special cases for contiguous and non-contiguous | ||||
|  *          tensors. | ||||
|  * | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the duplicated data will be stored. | ||||
|  *            dst->op is `GGML_OP_DUP` | ||||
|  * | ||||
 * @attention Only FP16/FP32 are supported. The case where src and dst have
 *            different shapes and dst is non-contiguous is not supported.
 * @note      This function needs to be simplified.
|  */ | ||||
| void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the Root Mean Square (RMS) normalization of a ggml tensor | ||||
|  *          using the CANN backend. | ||||
|  * | ||||
|  * @details This function applies RMS normalization to the input tensor `src` | ||||
|  *          and stores the result in the destination tensor `dst`. RMS | ||||
|  *          normalization involves computing the root mean square of the input | ||||
|  *          tensor along a specified dimension and then dividing each element of | ||||
|  *          the tensor by this value, adjusted by a small epsilon value to | ||||
|  *          prevent division by zero. | ||||
|  *          The operation is defined as: | ||||
|  *          \f[ | ||||
|  *               \text{RmsNorm}\left(x_i\right)=\frac{x_i}{\text{Rms}(\mathbf{x})} g_i, | ||||
 *               \quad \text{where } \text{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+eps}
|  *          \f] | ||||
|  *          `eps` is in dst->op_params. | ||||
|  * @param ctx The CANN context used for operations. | ||||
|  * @param dst The destination tensor where the normalized values will be stored. | ||||
|  *            dst->op is `GGML_OP_RMS_NORM`. | ||||
|  */ | ||||
| void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Applies a diagonal mask to the tensor with a specified value. | ||||
|  * | ||||
|  * @details This function creates a mask tensor filled with ones, then applies | ||||
|  *          an upper triangular and lower triangular operation to it based on | ||||
|  *          the number of past elements specified. Afterward, it adds the masked | ||||
|  *          tensor to the destination tensor in-place. | ||||
|  * | ||||
|  * @param ctx The backend CANN context used for operations. | ||||
|  * @param dst The destination tensor where the result will be stored. dst->op is | ||||
|  *            `GGML_OP_DIAG_MASK` | ||||
|  * @param value The value to use for masking. | ||||
|  */ | ||||
| void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Performs an image-to-column transformation on the input tensor. | ||||
|  * | ||||
|  * @details This function takes an input tensor and applies an image-to-column | ||||
|  *          operation, converting spatial dimensions into column-like | ||||
|  *          structures suitable for convolutional operations. It supports both | ||||
|  *          half-precision (F16) and single-precision (F32) floating-point data | ||||
|  *          types. | ||||
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor that stores the result of the operation. | ||||
|  *            dst->op is `GGML_OP_IM2COL`. | ||||
|  */ | ||||
| void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes time step embeddings using sine and cosine functions. | ||||
|  * | ||||
|  * @details This function calculates time step embeddings by applying sine and | ||||
|  *          cosine transformations to a given input tensor, which is typically | ||||
|  *          used in temporal models like diffusion models or transformers to | ||||
|  *          encode time information effectively. | ||||
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor where the result of the embedding operation | ||||
|  *            will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`. | ||||
|  */ | ||||
| void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| // @see ggml_cann_dup.
 | ||||
| void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Computes the softmax activation with optional masking. | ||||
|  * | ||||
|  * @details This function computes the softmax activation over the input tensor, | ||||
|  *          optionally applying a mask and scaling factor. It supports both FP16 | ||||
|  *          and FP32 data types and can handle masking by broadcasting the mask | ||||
|  *          across rows if necessary. | ||||
|  *          The function performs the following steps: | ||||
|  *          1. Multiplies the input tensor by a scale factor. | ||||
|  *          2. Optionally casts the mask tensor to FP32 if it is in FP16 format. | ||||
|  *          3. Broadcasts the mask tensor if its dimensions do not match the | ||||
|  *             input tensor's dimensions. | ||||
|  *          4. Adds the mask to the scaled input tensor. | ||||
|  *          5. Applies the softmax activation function along the specified | ||||
|  *             dimension. | ||||
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor where the result will be stored. dst->op is | ||||
|  *            `GGML_OP_SOFTMAX`. | ||||
|  */ | ||||
| void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
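// Summarizing the steps above in one informal formula (not in the original
// header): dst = softmax(scale * src0 + mask), applied along the first
// dimension, where the mask term is optional and broadcast across rows if needed.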
| 
 | ||||
| /**
 | ||||
|  * @brief   Extracts specific rows from a tensor based on indices. | ||||
|  * | ||||
|  * @details This function retrieves rows from a source tensor src0 according to | ||||
|  *          the indices provided in another tensor src1 and stores the result in | ||||
|  *          a destination tensor (\p dst). It supports different data types | ||||
|  *          including F32, F16, Q4_0, and Q8_0. | ||||
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor where the extracted rows will be stored. | ||||
|  *            dst->op is `GGML_OP_GET_ROWS`. | ||||
|  */ | ||||
| void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief   Executes matrix multiplication for the given tensor. | ||||
|  * | ||||
|  * @details This function performs matrix multiplication on the source tensors | ||||
 *          associated with the destination tensor. It supports matrix
 *          multiplication of F32, F16, and Q8_0 tensors.
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor for storing the result of the matrix | ||||
|  *            multiplication. dst->op is `GGML_OP_MUL_MAT`. | ||||
|  */ | ||||
| void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor. | ||||
|  * | ||||
|  * @details This function implements the RoPE mechanism, which is a method to | ||||
|  *          encode positional information into sequence data, particularly | ||||
|  *          useful in transformer models. It supports both F32 and F16 data | ||||
|  *          types. | ||||
|  * | ||||
|  * @param ctx The backend CANN context for executing operations. | ||||
|  * @param dst The destination tensor where the RoPE-transformed data will be | ||||
|  *            stored. dst->op is `GGML_OP_ROPE`. | ||||
|  * | ||||
 * @note The function currently does not support cases where n_dims is less
 *       than the first dimension of the input tensor.
 * @note The function currently does not support cases where freq_factors is
 *       not NULL.
 * @note The function currently does not support cases where ext_factor is
 *       not equal to 0.
 * @note The function currently does not support cases where freq_scale is
 *       not equal to 1.
|  */ | ||||
| void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst); | ||||
| 
 | ||||
| template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*, | ||||
|                                        aclTensor*, uint64_t*, aclOpExecutor**), | ||||
|           aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)> | ||||
| void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|     ggml_tensor* src0 = dst->src[0]; | ||||
|     ggml_tensor* src1 = dst->src[1]; | ||||
|     GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); | ||||
| 
 | ||||
|     aclTensor* acl_src0; | ||||
|     aclTensor* acl_src1; | ||||
|     aclTensor* acl_dst; | ||||
| 
 | ||||
|     // Need bcast
 | ||||
|     if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { | ||||
|         BCAST_SHAPE(src0, src1) | ||||
|         acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); | ||||
|         acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); | ||||
|         acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); | ||||
|     } else { | ||||
|         acl_src0 = ggml_cann_create_tensor(src0); | ||||
|         acl_src1 = ggml_cann_create_tensor(src1); | ||||
|         acl_dst = ggml_cann_create_tensor(dst); | ||||
|     } | ||||
| 
 | ||||
|     uint64_t workspaceSize = 0; | ||||
|     aclOpExecutor* executor; | ||||
|     void* workspaceAddr = nullptr; | ||||
| 
 | ||||
|     ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize, | ||||
|                                &executor)); | ||||
|     if (workspaceSize > 0) { | ||||
|         ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); | ||||
|         workspaceAddr = workspace_allocator.get(); | ||||
|     } | ||||
| 
 | ||||
|     aclrtStream main_stream = ctx.stream(); | ||||
|     ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); | ||||
| 
 | ||||
|     ACL_CHECK(aclDestroyTensor(acl_src0)); | ||||
|     ACL_CHECK(aclDestroyTensor(acl_src1)); | ||||
|     ACL_CHECK(aclDestroyTensor(acl_dst)); | ||||
| } | ||||
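// A hedged usage sketch (illustrative, not part of this header): the template
// is meant to be instantiated with a binary aclnn operator's
// <op>GetWorkspaceSize/<op> pair, assuming operators such as aclnnMul/aclnnDiv
// expose matching signatures:
//
//   ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);  // GGML_OP_MUL
//   ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);  // GGML_OP_DIV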
| 
 | ||||
| // Activation functions template.
 | ||||
| template <aclnnStatus getWorkspaceSize(const aclTensor*, aclTensor*, uint64_t*, | ||||
|                                        aclOpExecutor**), | ||||
|           aclnnStatus execute(void*, uint64_t, aclOpExecutor*, | ||||
|                               const aclrtStream)> | ||||
| void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|     ggml_tensor* src = dst->src[0]; | ||||
| 
 | ||||
|     GGML_ASSERT(src->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||||
| 
 | ||||
|     aclTensor* acl_src = ggml_cann_create_tensor(src); | ||||
|     aclTensor* acl_dst = ggml_cann_create_tensor(dst); | ||||
| 
 | ||||
|     uint64_t workspaceSize = 0; | ||||
|     aclOpExecutor* executor; | ||||
|     void* workspaceAddr = nullptr; | ||||
| 
 | ||||
|     ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); | ||||
|     if (workspaceSize > 0) { | ||||
|         ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); | ||||
|         workspaceAddr = workspace_allocator.get(); | ||||
|     } | ||||
| 
 | ||||
|     aclrtStream main_stream = ctx.stream(); | ||||
|     ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); | ||||
| 
 | ||||
|     ACL_CHECK(aclDestroyTensor(acl_src)); | ||||
|     ACL_CHECK(aclDestroyTensor(acl_dst)); | ||||
| } | ||||
| 
 | ||||
| // Activation functions template for const aclTensors.
 | ||||
| template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*, | ||||
|                                        uint64_t*, aclOpExecutor**), | ||||
|           aclnnStatus execute(void*, uint64_t, aclOpExecutor*, | ||||
|                               const aclrtStream)> | ||||
| void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|     ggml_tensor* src = dst->src[0]; | ||||
| 
 | ||||
|     GGML_ASSERT(src->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||||
| 
 | ||||
|     aclTensor* acl_src = ggml_cann_create_tensor(src); | ||||
|     aclTensor* acl_dst = ggml_cann_create_tensor(dst); | ||||
| 
 | ||||
|     uint64_t workspaceSize = 0; | ||||
|     aclOpExecutor* executor; | ||||
|     void* workspaceAddr = nullptr; | ||||
| 
 | ||||
|     ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); | ||||
|     if (workspaceSize > 0) { | ||||
|         ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); | ||||
|         workspaceAddr = workspace_allocator.get(); | ||||
|     } | ||||
| 
 | ||||
|     aclrtStream main_stream = ctx.stream(); | ||||
|     ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); | ||||
| 
 | ||||
|     ACL_CHECK(aclDestroyTensor(acl_src)); | ||||
|     ACL_CHECK(aclDestroyTensor(acl_dst)); | ||||
| } | ||||
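// Note: the two ggml_cann_activation overloads differ only in whether the
// aclnn getWorkspaceSize function declares the output tensor as aclTensor* or
// const aclTensor*; different unary operators use either form. A hedged usage
// sketch, assuming a unary operator such as aclnnRelu provides a matching
// <op>GetWorkspaceSize/<op> pair:
//
//   ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(ctx, dst);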
| 
 | ||||
| #endif  // CANN_ACLNN_OPS
 | ||||
							
								
								
									
282  ggml/src/ggml-cann/common.h  Normal file
							|  | @ -0,0 +1,282 @@ | |||
| /*
 | ||||
|  * Copyright (c) 2023-2024 The ggml authors | ||||
|  * | ||||
|  * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
|  * of this software and associated documentation files (the "Software"), to | ||||
|  * deal in the Software without restriction, including without limitation the | ||||
|  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or | ||||
|  * sell copies of the Software, and to permit persons to whom the Software is | ||||
|  * furnished to do so, subject to the following conditions: | ||||
|  * | ||||
|  * The above copyright notice and this permission notice shall be included in | ||||
|  * all copies or substantial portions of the Software. | ||||
|  * | ||||
|  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
|  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
|  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
|  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
|  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
|  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||||
|  * IN THE SOFTWARE. | ||||
|  */ | ||||
| 
 | ||||
| #ifndef CANN_COMMON_H | ||||
| #define CANN_COMMON_H | ||||
| 
 | ||||
| #include <acl/acl.h> | ||||
| 
 | ||||
| #include <cstdio> | ||||
| #include <iostream> | ||||
| #include <map> | ||||
| #include <memory> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "../include/ggml-cann.h" | ||||
| #include "../include/ggml.h" | ||||
| 
 | ||||
| #define MATRIX_ROW_PADDING 512 | ||||
| #define GGML_CANN_MAX_STREAMS 8 | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Handles CANN-related errors by printing an error message and | ||||
|  *        terminating the program. | ||||
|  * @param stmt The statement that caused the error. | ||||
|  * @param func The function in which the error occurred. | ||||
|  * @param file The file in which the error occurred. | ||||
|  * @param line The line number at which the error occurred. | ||||
|  * @param msg The error message. | ||||
|  */ | ||||
| [[noreturn]] void ggml_cann_error(const char* stmt, const char* func, | ||||
|                                   const char* file, int line, const char* msg); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Checks the result of a CANN function call and invokes the error | ||||
|  *        handler if the call fails. | ||||
|  * @param stmt The CANN function call to check. | ||||
|  * @param success The success code that indicates the call was successful. | ||||
|  * @param error_fn The function to call to retrieve the error message. | ||||
|  */ | ||||
| #define ACL_CHECK_GEN(stmt, success, error_fn)                                \ | ||||
|     do {                                                                      \ | ||||
|         int err_code = (stmt);                                                \ | ||||
|         if (err_code != (success)) {                                          \ | ||||
|             ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \ | ||||
|         }                                                                     \ | ||||
|     } while (0); | ||||
| 
 | ||||
| #define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg) | ||||
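// Usage example (illustrative): wrap ACL runtime / aclnn calls that return a
// status code, so failures abort with file/line context, e.g.
//
//   ACL_CHECK(aclrtSetDevice(device));
//   ACL_CHECK(aclrtCreateStream(&stream));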
| 
 | ||||
| /**
 | ||||
|  * @brief Contains information about CANN devices. | ||||
|  */ | ||||
| struct ggml_cann_device_info { | ||||
|     /**
 | ||||
|      * @brief Number of CANN devices available. | ||||
|      */ | ||||
|     int32_t device_count; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Information about a single CANN device. | ||||
|      */ | ||||
|     struct cann_device_info { | ||||
|         int cc;                 /**< Compute capability.                   */ | ||||
|         size_t smpb;            /**< Maximum shared memory per block.      */ | ||||
|         bool vmm;               /**< Virtual memory support.               */ | ||||
|         size_t vmm_granularity; /**< Granularity of virtual memory.        */ | ||||
|         size_t total_vram;      /**< Total video RAM available on the device. */ | ||||
|     }; | ||||
| 
 | ||||
|     cann_device_info devices[GGML_CANN_MAX_DEVICES] = | ||||
|         {}; /**< Array of CANN device information. */ | ||||
| }; | ||||
| 
 | ||||
| const ggml_cann_device_info& ggml_cann_info(); | ||||
| 
 | ||||
| void ggml_cann_set_device(int32_t device); | ||||
| int32_t ggml_cann_get_device(); | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief Abstract base class for memory pools used by CANN. | ||||
|  */ | ||||
| struct ggml_cann_pool { | ||||
|     /**
 | ||||
|      * @brief Virtual destructor for the memory pool. | ||||
|      */ | ||||
|     virtual ~ggml_cann_pool() = default; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Allocates memory from the pool. | ||||
|      * | ||||
|      * @param size         The size of the memory block to allocate. | ||||
|      * @param actual_size  Pointer to a variable where the actual allocated size | ||||
|      *                     will be stored. | ||||
|      * @return             Pointer to the allocated memory block. | ||||
|      */ | ||||
|     virtual void* alloc(size_t size, size_t* actual_size) = 0; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Frees a previously allocated memory block. | ||||
|      * | ||||
|      * @param ptr   Pointer to the memory block to free. | ||||
|      * @param size  Size of the memory block to free. | ||||
     * @note All CANN operators run asynchronously. Make sure the memory is
     *       still available until the operator has finished.
|      */ | ||||
|     virtual void free(void* ptr, size_t size) = 0; | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  * @brief RAII wrapper for managing memory allocations from a CANN memory pool. | ||||
|  */ | ||||
| struct ggml_cann_pool_alloc { | ||||
|     ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */ | ||||
|     void* ptr = nullptr;    /**< Pointer to the allocated memory block. */ | ||||
|     size_t actual_size = 0; /**< Actual size of the allocated memory block. */ | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Default constructor. | ||||
|      */ | ||||
|     ggml_cann_pool_alloc() = default; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Constructor that initializes the memory pool. | ||||
|      * @param pool Reference to the memory pool. | ||||
|      */ | ||||
|     explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {} | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Constructor that initializes the memory pool and allocates memory. | ||||
|      * @param pool Reference to the memory pool. | ||||
|      * @param size Size of the memory block to allocate. | ||||
|      */ | ||||
|     ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) { | ||||
|         alloc(size); | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Destructor that frees the allocated memory block. | ||||
|      */ | ||||
|     ~ggml_cann_pool_alloc() { | ||||
|         if (ptr != nullptr) { | ||||
|             pool->free(ptr, actual_size); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Allocates memory from the pool. | ||||
|      * @param size Size of the memory block to allocate. | ||||
|      * @return Pointer to the allocated memory block. | ||||
|      */ | ||||
|     void* alloc(size_t size) { | ||||
|         GGML_ASSERT(pool != nullptr); | ||||
|         GGML_ASSERT(ptr == nullptr); | ||||
|         ptr = pool->alloc(size, &this->actual_size); | ||||
|         return ptr; | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Allocates memory from a specific memory pool. | ||||
|      * @param pool Reference to the memory pool. | ||||
|      * @param size Size of the memory block to allocate. | ||||
|      * @return Pointer to the allocated memory block. | ||||
|      */ | ||||
|     void* alloc(ggml_cann_pool& pool, size_t size) { | ||||
|         this->pool = &pool; | ||||
|         return alloc(size); | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Gets the pointer to the allocated memory block. | ||||
|      * @return Pointer to the allocated memory block. | ||||
|      */ | ||||
|     void* get() { return ptr; } | ||||
| 
 | ||||
|     // Deleted copy constructor
 | ||||
|     ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete; | ||||
| 
 | ||||
|     // Deleted move constructor
 | ||||
|     ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete; | ||||
| 
 | ||||
|     // Deleted copy assignment operator
 | ||||
|     ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete; | ||||
| 
 | ||||
|     // Deleted move assignment operator
 | ||||
|     ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete; | ||||
| }; | ||||
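// A minimal usage sketch (illustrative, not part of this header): allocate a
// temporary device buffer from a pool and let RAII release it at scope exit.
// Because CANN operators run asynchronously, the allocation must stay alive
// until the enqueued work has finished.
//
//   void scratch_example(ggml_cann_pool& pool, size_t nbytes) {   // hypothetical helper
//       ggml_cann_pool_alloc tmp(pool, nbytes);  // allocates from the pool
//       void* dev_ptr = tmp.get();               // device pointer usable by kernels
//       // ... enqueue work that uses dev_ptr ...
//   }                                            // freed by ~ggml_cann_pool_alloc()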
| 
 | ||||
| /**
 | ||||
|  * @brief Context for managing CANN backend operations. | ||||
|  */ | ||||
| struct ggml_backend_cann_context { | ||||
|     int32_t device;                  /**< Device ID. */ | ||||
|     std::string name;                /**< Name of the device. */ | ||||
|     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ | ||||
| 
 | ||||
|     aclrtStream streams[GGML_CANN_MAX_STREAMS] = { | ||||
|         {nullptr}}; /**< Array of streams for the device. */ | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Constructor for initializing the context with a given device. | ||||
|      * @param device Device ID. | ||||
|      */ | ||||
|     explicit ggml_backend_cann_context(int device) | ||||
|         : device(device), name("CANN" + std::to_string(device)) {} | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Destructor for cleaning up resources. | ||||
|      */ | ||||
|     ~ggml_backend_cann_context() { | ||||
|         if (copy_event != nullptr) { | ||||
|             ACL_CHECK(aclrtDestroyEvent(copy_event)); | ||||
|         } | ||||
|         for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) { | ||||
|             if (streams[i] != nullptr) { | ||||
|                 ACL_CHECK(aclrtDestroyStream(streams[i])); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Get or create a stream for a given index. | ||||
|      * @param stream Index of the stream. | ||||
|      * @return The stream corresponding to the given index. | ||||
|      */ | ||||
|     aclrtStream stream(int stream) { | ||||
|         if (streams[stream] == nullptr) { | ||||
|             ggml_cann_set_device(device); | ||||
|             ACL_CHECK(aclrtCreateStream(&streams[stream])); | ||||
|         } | ||||
|         return streams[stream]; | ||||
|     } | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Get or create the default stream (index 0). | ||||
|      * @return The default stream. | ||||
|      */ | ||||
|     aclrtStream stream() { return stream(0); } | ||||
| 
 | ||||
|     // TODO: each stream should have a memory pool.
 | ||||
|     std::unique_ptr<ggml_cann_pool> | ||||
|         mem_pool; /**< Memory pool for the device. */ | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Create a new memory pool for a given device. | ||||
|      * @param device Device ID. | ||||
|      * @return A unique pointer to the new memory pool. | ||||
|      */ | ||||
|     static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device); | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Get or create the memory pool for the context. | ||||
|      * @return Reference to the memory pool. | ||||
|      */ | ||||
|     ggml_cann_pool& pool() { | ||||
|         if (mem_pool == nullptr) { | ||||
|             mem_pool = new_pool_for_device(device); | ||||
|         } | ||||
|         return *mem_pool; | ||||
|     } | ||||
| }; | ||||
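// A hedged usage sketch (illustrative): typical flow inside a CANN operator,
// combining the lazily created default stream and the per-device memory pool.
//
//   void op_example(ggml_backend_cann_context& ctx) {   // hypothetical helper
//       aclrtStream stream = ctx.stream();              // default stream (index 0)
//       ggml_cann_pool_alloc ws(ctx.pool(), 1 << 20);   // 1 MiB workspace
//       // ... launch aclnn / AscendC work on `stream` using ws.get() ...
//   }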
| 
 | ||||
| #endif  // CANN_COMMON_H
 | ||||
							
								
								
									
32  ggml/src/ggml-cann/kernels/CMakeLists.txt  Normal file
							|  | @ -0,0 +1,32 @@ | |||
| if (NOT SOC_TYPE) | ||||
|     set (SOC_TYPE "Ascend910B3") | ||||
| endif() | ||||
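# Illustrative only: SOC_TYPE can be overridden at configure time, for example
#   cmake .. -DGGML_CANN=ON -DSOC_TYPE=Ascend910B3
# (the exact SOC string must match the target Ascend device).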
| 
 | ||||
| file(GLOB SRC_FILES | ||||
|     get_row_f32.cpp | ||||
|     get_row_f16.cpp | ||||
|     get_row_q4_0.cpp | ||||
|     get_row_q8_0.cpp | ||||
|     quantize_f32_q8_0.cpp | ||||
|     quantize_f16_q8_0.cpp | ||||
|     dup.cpp | ||||
| ) | ||||
| 
 | ||||
| string(TOLOWER ${SOC_TYPE} SOC_VERSION) | ||||
| set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR}) | ||||
| set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim") | ||||
| 
 | ||||
| if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) | ||||
|     set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake) | ||||
| elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake) | ||||
|     set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake) | ||||
| else() | ||||
|     message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.") | ||||
| endif() | ||||
| include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) | ||||
| 
 | ||||
| ascendc_library(ascendc_kernels STATIC | ||||
|     ${SRC_FILES} | ||||
| ) | ||||
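# Note (assumption, not verified here): compiling the kernels with
# ascendc_library is expected to generate the aclrtlaunch_ascendc_*.h launch
# headers that ascendc_kernels.h in this directory includes.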
| 
 | ||||
| #ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP) | ||||
							
								
								
									
17  ggml/src/ggml-cann/kernels/ascendc_kernels.h  Normal file
							|  | @ -0,0 +1,17 @@ | |||
| #ifndef ASCENDC_KERNELS_H | ||||
| #define ASCENDC_KERNELS_H | ||||
| 
 | ||||
| #include "aclrtlaunch_ascendc_get_row_f32.h" | ||||
| #include "aclrtlaunch_ascendc_get_row_f16.h" | ||||
| #include "aclrtlaunch_ascendc_get_row_q8_0.h" | ||||
| #include "aclrtlaunch_ascendc_get_row_q4_0.h" | ||||
| 
 | ||||
| #include "aclrtlaunch_ascendc_quantize_f32_q8_0.h" | ||||
| #include "aclrtlaunch_ascendc_quantize_f16_q8_0.h" | ||||
| 
 | ||||
| #include "aclrtlaunch_ascendc_dup_by_rows_fp16.h" | ||||
| #include "aclrtlaunch_ascendc_dup_by_rows_fp32.h" | ||||
| #include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h" | ||||
| #include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h" | ||||
| 
 | ||||
| #endif  // ASCENDC_KERNELS_H
 | ||||
							
								
								
									
223  ggml/src/ggml-cann/kernels/dup.cpp  Normal file
							|  | @ -0,0 +1,223 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| #include <cmath> | ||||
| 
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| 
 | ||||
| template <typename SRC_T, typename DST_T> | ||||
| class DupByRows { | ||||
|    public: | ||||
|     __aicore__ inline DupByRows() {} | ||||
|     __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub, | ||||
|                                 size_t *input_nb_ub) { | ||||
        /* Dup by rows when src is contiguous in the first dimension and dst is
           contiguous; each kernel processes one row.
        */
| 
 | ||||
|         // Input has four dims.
 | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         // param
 | ||||
|         num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3]; | ||||
|         num_elem = input_ne_ub[0]; | ||||
| 
 | ||||
|         // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3)
 | ||||
|         idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]); | ||||
|         idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])) | ||||
|                   / (input_ne_ub[1]); | ||||
|         idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]) | ||||
|                 - idx_ne2 * input_ne_ub[1]; | ||||
| 
 | ||||
        // src may not be contiguous in dims [1,2,3], so the stride is determined by ne & nb
|         src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2 | ||||
|                      + input_nb_ub[1] * idx_ne1; | ||||
| 
 | ||||
|         // dst is contiguous
 | ||||
|         dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T)); | ||||
| 
 | ||||
|         src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src + | ||||
|                                                                 src_stride)); | ||||
|         dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst + | ||||
|                                                                 dst_stride)); | ||||
| 
 | ||||
|         pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem + | ||||
|                                                 32 - 1) / 32 * 32); | ||||
|         pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem + | ||||
|                                                 32 - 1) / 32 * 32); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in() { | ||||
|         LocalTensor<SRC_T> src_local = src_queue.AllocTensor<SRC_T>(); | ||||
| 
 | ||||
|         DataCopyExtParams dataCopyParams; | ||||
|         dataCopyParams.blockCount = 1; | ||||
|         dataCopyParams.blockLen = num_elem * sizeof(SRC_T); | ||||
|         DataCopyPadExtParams<SRC_T> padParams; | ||||
|         DataCopyPad(src_local, src_gm, dataCopyParams, padParams); | ||||
| 
 | ||||
|         src_queue.EnQue(src_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out() { | ||||
|         LocalTensor<DST_T> dst_local = dst_queue.DeQue<DST_T>(); | ||||
| 
 | ||||
|         DataCopyExtParams dataCopyParams; | ||||
|         dataCopyParams.blockCount = 1; | ||||
|         dataCopyParams.blockLen = num_elem * sizeof(DST_T); | ||||
|         DataCopyPad(dst_gm, dst_local, dataCopyParams); | ||||
| 
 | ||||
|         dst_queue.FreeTensor(dst_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void dup() { | ||||
        // main process: copy one row of data from src to dst.
|         copy_in(); | ||||
| 
 | ||||
|         LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>(); | ||||
|         LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>(); | ||||
| 
 | ||||
|         int32_t BLOCK_NUM = 32 / sizeof(DST_T); | ||||
|         DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1) | ||||
|                                         / BLOCK_NUM * BLOCK_NUM); | ||||
|         dst_queue.EnQue<DST_T>(dst_local); | ||||
| 
 | ||||
|         src_queue.FreeTensor(src_local); | ||||
|         copy_out(); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void dup_with_cast() { | ||||
        // main process: copy one row of data from src to dst,
        // casting the dtype from SRC_T to DST_T.
|         copy_in(); | ||||
| 
 | ||||
|         LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>(); | ||||
|         LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>(); | ||||
| 
 | ||||
|         Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem); | ||||
|         dst_queue.EnQue<DST_T>(dst_local); | ||||
| 
 | ||||
|         src_queue.FreeTensor(src_local); | ||||
|         copy_out(); | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<SRC_T> src_gm; | ||||
|     GlobalTensor<DST_T> dst_gm; | ||||
| 
 | ||||
|     int64_t num_rows; | ||||
|     int64_t num_elem; | ||||
|     int64_t idx_ne3; | ||||
|     int64_t idx_ne2; | ||||
|     int64_t idx_ne1; | ||||
|     int64_t src_stride; | ||||
|     int64_t dst_stride; | ||||
| 
 | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> src_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> dst_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
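// Note: the callers below copy the 32-byte ne/nb parameter arrays (4 x 8-byte
// entries) from global memory to local stack arrays; `size` is in bytes.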
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16( | ||||
|                                                         GM_ADDR src_gm, | ||||
|                                                         GM_ADDR dst_gm, | ||||
|                                                         GM_ADDR input_ne_gm, | ||||
|                                                         GM_ADDR input_nb_gm, | ||||
|                                                         GM_ADDR output_ne_gm, | ||||
|                                                         GM_ADDR output_nb_gm) { | ||||
| 
 | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     DupByRows<half, half> op; | ||||
|     op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); | ||||
|     op.dup(); | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32( | ||||
|                                                         GM_ADDR src_gm, | ||||
|                                                         GM_ADDR dst_gm, | ||||
|                                                         GM_ADDR input_ne_gm, | ||||
|                                                         GM_ADDR input_nb_gm, | ||||
|                                                         GM_ADDR output_ne_gm, | ||||
|                                                         GM_ADDR output_nb_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     DupByRows<float_t, float_t> op; | ||||
|     op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); | ||||
|     op.dup(); | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16( | ||||
|                                                         GM_ADDR src_gm, | ||||
|                                                         GM_ADDR dst_gm, | ||||
|                                                         GM_ADDR input_ne_gm, | ||||
|                                                         GM_ADDR input_nb_gm, | ||||
|                                                         GM_ADDR output_ne_gm, | ||||
|                                                         GM_ADDR output_nb_gm) { | ||||
| 
 | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     DupByRows<float_t, half> op; | ||||
|     op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); | ||||
|     op.dup_with_cast(); | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32( | ||||
|                                                         GM_ADDR src_gm, | ||||
|                                                         GM_ADDR dst_gm, | ||||
|                                                         GM_ADDR input_ne_gm, | ||||
|                                                         GM_ADDR input_nb_gm, | ||||
|                                                         GM_ADDR output_ne_gm, | ||||
|                                                         GM_ADDR output_nb_gm) { | ||||
| 
 | ||||
|     // copy params from gm to ub.
 | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     DupByRows<half, float_t> op; | ||||
|     op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub); | ||||
|     op.dup_with_cast(); | ||||
| } | ||||
							
								
								
									
186  ggml/src/ggml-cann/kernels/get_row_f16.cpp  Normal file
							|  | @ -0,0 +1,186 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
// optimize me. Use a template to avoid duplicated code.
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| 
 | ||||
| class GET_ROW_F16 { | ||||
|    public: | ||||
|     __aicore__ inline GET_ROW_F16() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, size_t *input_nb_ub, | ||||
|                                 int64_t *indices_ne_ub, size_t *indices_nb_ub, | ||||
|                                 int64_t *output_ne_ub, size_t *output_nb_ub) { | ||||
|         // TODO, use template for F16/f32
 | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; | ||||
| 
 | ||||
|             indices_ne[i] = indices_ne_ub[i]; | ||||
|             indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; | ||||
| 
 | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|             output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; | ||||
|         } | ||||
| 
 | ||||
        // Indices have two dims. n_elements = total number of rows to fetch.
        // dr = number of rows this block should fetch.
|         uint64_t n_elements = | ||||
|             indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; | ||||
|         dr = n_elements / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = n_elements % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
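        // Rows are split as evenly as possible across blocks: the first
        // `tails` blocks each handle one extra row; `ir` is this block's
        // starting row index.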
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ half *)input); | ||||
|         indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); | ||||
|         output_gm.SetGlobalBuffer((__gm__ float *)output); | ||||
| 
 | ||||
|         uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31) | ||||
|                                              & ~31); | ||||
|         uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31) | ||||
|                                               & ~31); | ||||
| 
 | ||||
|         local_buffer_elems = input_local_buffer_size / sizeof(half); | ||||
| 
 | ||||
        // TODO: handle rows that are too long to fit in the UB.
        // All buffer sizes are rounded up to 32 bytes; this is fine because the data is 32-byte aligned.
|         pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset, size_t len) { | ||||
|         LocalTensor<half> input_local = input_queue.AllocTensor<half>(); | ||||
|         size_t tail = len % 32; | ||||
|         len = len & ~31; | ||||
|         DataCopy(input_local, input_gm[offset], len); | ||||
|         if(tail != 0) { | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = tail * sizeof(half); | ||||
|             DataCopyPadExtParams<half> padParams; | ||||
|             DataCopyPad(input_local[len], input_gm[offset + len], | ||||
|                         dataCopyParams, padParams); | ||||
|         } | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset, size_t len) { | ||||
|         LocalTensor<float> output_local = output_queue.DeQue<float>(); | ||||
|         size_t tail = len % 32; | ||||
|         len = len & ~31; | ||||
|         DataCopy(output_gm[offset], output_local, len); | ||||
|         if(tail != 0) { | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = tail * sizeof(float); | ||||
|             DataCopyPad(output_gm[offset + len], output_local[len], | ||||
|                         dataCopyParams); | ||||
|         } | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate_row(int64_t idx) { | ||||
|         const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); | ||||
|         const int64_t indices_ne1_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / | ||||
|             indices_ne[0]; | ||||
|         const int64_t indices_ne0_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - | ||||
|              indices_ne1_idx * indices_ne[0]); | ||||
| 
 | ||||
|         const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + | ||||
|                                        indices_ne1_idx * indices_stride[1] + | ||||
|                                        indices_ne2_idx * indices_stride[2]; | ||||
|         const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); | ||||
| 
 | ||||
|         const int64_t input_offset = selected_row_idx * input_stride[1] + | ||||
|                                      indices_ne1_idx * input_stride[2] + | ||||
|                                      indices_ne2_idx * input_stride[3]; | ||||
| 
 | ||||
|         const int64_t output_offset = indices_ne0_idx * output_stride[1] + | ||||
|                                       indices_ne1_idx * output_stride[2] + | ||||
|                                       indices_ne2_idx * output_stride[3]; | ||||
| 
 | ||||
|         copy_in(input_offset, input_ne[0]); | ||||
|         LocalTensor<half> input_local = input_queue.DeQue<half>(); | ||||
|         LocalTensor<float> output_local = output_queue.AllocTensor<float>(); | ||||
| 
 | ||||
|         Cast(output_local, input_local, RoundMode::CAST_NONE, | ||||
|              local_buffer_elems); | ||||
|         output_queue.EnQue(output_local); | ||||
|         copy_out(output_offset, input_ne[0]); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             calculate_row(i); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t indices_ne[4]; | ||||
|     size_t indices_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     size_t local_buffer_elems; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<half> input_gm; | ||||
|     GlobalTensor<int32_t> indices_gm; | ||||
|     GlobalTensor<float> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_get_row_f16( | ||||
|     GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, | ||||
|     GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm, | ||||
|     GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t indices_ne_ub[4]; | ||||
|     size_t indices_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(indices_ne_gm, indices_ne_ub, 32); | ||||
|     copy_to_ub(indices_nb_gm, indices_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     GET_ROW_F16 op; | ||||
|     op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub, | ||||
|             indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
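All of these CANN kernels (the get-row kernels and the quantize kernels below) partition their work the same way in init(): the flattened element count n is divided by the number of launched blocks, and the first n % block_num blocks take one extra element. The following host-side sketch is not part of the patch and the helper name partition_rows is purely illustrative; it reproduces the ir/dr arithmetic so the split can be checked outside the kernel.

    #include <cstdint>
    #include <utility>

    // Sketch of the ir/dr split used in each kernel's init():
    // every block gets n / block_num elements, and the first (n % block_num)
    // blocks get one extra. Returns {ir, dr} = {first index, count} for block_idx.
    static std::pair<int64_t, int64_t> partition_rows(int64_t n, int64_t block_num,
                                                      int64_t block_idx) {
        int64_t dr    = n / block_num;
        int64_t tails = n % block_num;
        int64_t ir    = 0;
        if (block_idx < tails) {
            dr += 1;
            ir  = dr * block_idx;          // earlier blocks all hold dr elements
        } else {
            ir  = dr * block_idx + tails;  // skip past the blocks that hold one extra
        }
        return {ir, dr};
    }

For example, 10 rows over 4 blocks yields {0,3}, {3,3}, {6,2}, {8,2}, covering every row exactly once.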
							
								
								
									
ggml/src/ggml-cann/kernels/get_row_f32.cpp (new file, 180 lines)
							|  | @ -0,0 +1,180 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| // TODO: optimize me. Use a template to avoid duplicating this code across types.
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| 
 | ||||
| class GET_ROW_F32 { | ||||
|    public: | ||||
|     __aicore__ inline GET_ROW_F32() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, size_t *input_nb_ub, | ||||
|                                 int64_t *indices_ne_ub, size_t *indices_nb_ub, | ||||
|                                 int64_t *output_ne_ub, size_t *output_nb_ub) { | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; | ||||
| 
 | ||||
|             indices_ne[i] = indices_ne_ub[i]; | ||||
|             indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; | ||||
| 
 | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|             output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; | ||||
|         } | ||||
| 
 | ||||
| // The indices tensor has two meaningful dims; n_elements = total number of rows to fetch.
 | ||||
| // dr = number of rows this block should fetch.
 | ||||
|         uint64_t n_elements = | ||||
|             indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; | ||||
|         dr = n_elements / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = n_elements % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ float *)input); | ||||
|         indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); | ||||
|         output_gm.SetGlobalBuffer((__gm__ float *)output); | ||||
| 
 | ||||
|         uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31); | ||||
|         local_buffer_elems = local_buffer_size / sizeof(float); | ||||
| 
 | ||||
| // TODO: handle rows that are too long to fit in the UB.
 | ||||
| // Buffer sizes are rounded up to a multiple of 32 bytes; this is fine because the data is 32-byte aligned.
 | ||||
|         pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset, size_t len) { | ||||
|         LocalTensor<float> input_local = input_queue.AllocTensor<float>(); | ||||
|         size_t tail = len % 32; | ||||
|         len = len & ~31; | ||||
|         DataCopy(input_local, input_gm[offset], len); | ||||
|         if(tail != 0) { | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = tail * sizeof(float); | ||||
|             DataCopyPadExtParams<float> padParams; | ||||
|             DataCopyPad(input_local[len], input_gm[offset + len], | ||||
|                         dataCopyParams, padParams); | ||||
|         } | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset, size_t len) { | ||||
|         LocalTensor<float> output_local = output_queue.DeQue<float>(); | ||||
|         size_t tail = len % 32; | ||||
|         len = len & ~31; | ||||
|         DataCopy(output_gm[offset], output_local, len); | ||||
|         if(tail != 0) { | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = tail * sizeof(float); | ||||
|             DataCopyPad(output_gm[offset + len], output_local[len], | ||||
|                         dataCopyParams); | ||||
|         } | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate_row(int64_t idx) { | ||||
|         const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); | ||||
|         const int64_t indices_ne1_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / | ||||
|             indices_ne[0]; | ||||
|         const int64_t indices_ne0_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - | ||||
|              indices_ne1_idx * indices_ne[0]); | ||||
| 
 | ||||
|         const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + | ||||
|                                        indices_ne1_idx * indices_stride[1] + | ||||
|                                        indices_ne2_idx * indices_stride[2]; | ||||
|         const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); | ||||
| 
 | ||||
|         const int64_t input_offset = selected_row_idx * input_stride[1] + | ||||
|                                      indices_ne1_idx * input_stride[2] + | ||||
|                                      indices_ne2_idx * input_stride[3]; | ||||
| 
 | ||||
|         const int64_t output_offset = indices_ne0_idx * output_stride[1] + | ||||
|                                       indices_ne1_idx * output_stride[2] + | ||||
|                                       indices_ne2_idx * output_stride[3]; | ||||
| 
 | ||||
|         copy_in(input_offset, input_ne[0]); | ||||
|         LocalTensor<float> input_local = input_queue.DeQue<float>(); | ||||
|         LocalTensor<float> output_local = output_queue.AllocTensor<float>(); | ||||
| 
 | ||||
|         DataCopy(output_local, input_local, local_buffer_elems); | ||||
|         output_queue.EnQue(output_local); | ||||
|         copy_out(output_offset, input_ne[0]); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             calculate_row(i); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t indices_ne[4]; | ||||
|     size_t indices_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     size_t local_buffer_elems; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<float> input_gm; | ||||
|     GlobalTensor<int32_t> indices_gm; | ||||
|     GlobalTensor<float> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_get_row_f32( | ||||
|     GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, | ||||
|     GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm, | ||||
|     GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t indices_ne_ub[4]; | ||||
|     size_t indices_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(indices_ne_gm, indices_ne_ub, 32); | ||||
|     copy_to_ub(indices_nb_gm, indices_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     GET_ROW_F32 op; | ||||
|     op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub, | ||||
|             indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
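The f16 and f32 get-row kernels move each row in two steps: a plain DataCopy for the 32-element-aligned head and DataCopyPad for the remaining tail, while the UB staging buffer is rounded up to a 32-byte multiple. The arithmetic is just a round-down or round-up to a multiple of 32; a host-side sketch with hypothetical helper names:

    #include <cstddef>

    // Sketch of the head/tail split performed by copy_in()/copy_out():
    // the first `aligned` elements (a multiple of 32) go through DataCopy,
    // the remaining `tail` elements go through DataCopyPad.
    struct copy_split {
        size_t aligned;
        size_t tail;
    };

    static copy_split split_row(size_t len_elems) {
        copy_split s;
        s.tail    = len_elems % 32;
        s.aligned = len_elems & ~size_t(31);
        return s;
    }

    // The UB staging buffer is rounded *up* to a 32-byte multiple instead:
    static size_t ub_buffer_bytes(size_t row_elems, size_t elem_size) {
        return (row_elems * elem_size + 31) & ~size_t(31);
    }

For a row of 100 floats this gives aligned = 96, tail = 4, and a 416-byte staging buffer.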
							
								
								
									
ggml/src/ggml-cann/kernels/get_row_q4_0.cpp (new file, 193 lines)
							|  | @ -0,0 +1,193 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| // TODO: optimize me. Use a template to avoid duplicating this code across types.
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| 
 | ||||
| #define QK4_0 32 | ||||
| 
 | ||||
| class GET_ROW_Q4_0 { | ||||
|    public: | ||||
|     __aicore__ inline GET_ROW_Q4_0() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, int64_t *indices_ne_ub, | ||||
|                                 size_t *indices_nb_ub, int64_t *output_ne_ub, | ||||
|                                 size_t *output_nb_ub) { | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             indices_ne[i] = indices_ne_ub[i]; | ||||
|             indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; | ||||
|             scale_ne[i] = input_ne_ub[i]; | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|             output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; | ||||
|         } | ||||
| 
 | ||||
|         // one scale for a group.
 | ||||
|         scale_ne[0] /= QK4_0; | ||||
| 
 | ||||
|         input_stride[0] = 1; | ||||
|         scale_stride[0] = 1; | ||||
|         output_stride[0] = 1; | ||||
|         for (int i = 1; i < 4; i++) { | ||||
|             input_stride[i] = input_stride[i - 1] * input_ne[i - 1]; | ||||
|             scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         group_size_in_row = input_ne[0] / QK4_0; | ||||
|         int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] * | ||||
|                                input_ne[3] / 2; | ||||
| 
 | ||||
| // The indices tensor has two meaningful dims; n_elements = total number of rows to fetch.
 | ||||
| // dr = number of rows this block should fetch.
 | ||||
|         uint64_t n_elements = | ||||
|             indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; | ||||
|         dr = n_elements / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = n_elements % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ int4b_t *)input); | ||||
|         scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset)); | ||||
|         indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); | ||||
|         output_gm.SetGlobalBuffer((__gm__ float *)output); | ||||
| 
 | ||||
|         pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t)); | ||||
|         pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half)); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float)); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset) { | ||||
|         LocalTensor<int4b_t> input_local = input_queue.AllocTensor<int4b_t>(); | ||||
| // Note: 32 * sizeof(int4b_t) = 16 bytes, which is below the 32-byte alignment; it is unclear why DataCopy accepts this.
 | ||||
|         DataCopy(input_local, input_gm[offset], QK4_0); | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset) { | ||||
|         LocalTensor<float> output_local = output_queue.DeQue<float>(); | ||||
|         DataCopy(output_gm[offset], output_local, QK4_0); | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate_group(int64_t idx, int64_t group) { | ||||
|         const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); | ||||
|         const int64_t indices_ne1_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / | ||||
|             indices_ne[0]; | ||||
|         const int64_t indices_ne0_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - | ||||
|              indices_ne1_idx * indices_ne[0]); | ||||
| 
 | ||||
|         const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + | ||||
|                                        indices_ne1_idx * indices_stride[1] + | ||||
|                                        indices_ne2_idx * indices_stride[2]; | ||||
|         const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); | ||||
| 
 | ||||
|         const int64_t input_offset = selected_row_idx * input_stride[1] + | ||||
|                                      indices_ne1_idx * input_stride[2] + | ||||
|                                      indices_ne2_idx * input_stride[3] + | ||||
|                                      group * QK4_0; | ||||
|         const int64_t scale_offset = selected_row_idx * scale_stride[1] + | ||||
|                                      indices_ne1_idx * scale_stride[2] + | ||||
|                                      indices_ne2_idx * scale_stride[3] + group; | ||||
|         const int64_t output_offset = indices_ne0_idx * output_stride[1] + | ||||
|                                       indices_ne1_idx * output_stride[2] + | ||||
|                                       indices_ne2_idx * output_stride[3] + | ||||
|                                       group * QK4_0; | ||||
| 
 | ||||
|         copy_in(input_offset); | ||||
|         LocalTensor<int4b_t> input_local = input_queue.DeQue<int4b_t>(); | ||||
|         LocalTensor<half> cast_local = cast_queue.AllocTensor<half>(); | ||||
|         LocalTensor<float> output_local = output_queue.AllocTensor<float>(); | ||||
| 
 | ||||
| // TODO: cast more data per call to speed this up.
 | ||||
|         Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0); | ||||
|         Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0); | ||||
| 
 | ||||
| // Only the multiply needs to be computed per group, using that group's scale.
 | ||||
|         half scale = scale_gm.GetValue(scale_offset); | ||||
| 
 | ||||
|         Muls(output_local, output_local, (float)scale, QK4_0); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|         cast_queue.FreeTensor(cast_local); | ||||
|         output_queue.EnQue(output_local); | ||||
| 
 | ||||
|         copy_out(output_offset); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             for (int64_t j = 0; j < group_size_in_row; j++) { | ||||
|                 calculate_group(i, j); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t scale_ne[4]; | ||||
|     size_t scale_stride[4]; | ||||
| 
 | ||||
|     int64_t indices_ne[4]; | ||||
|     size_t indices_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     int64_t group_size_in_row; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<int4b_t> input_gm; | ||||
|     GlobalTensor<half> scale_gm; | ||||
|     GlobalTensor<int32_t> indices_gm; | ||||
|     GlobalTensor<float> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_get_row_q4_0( | ||||
|     GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, | ||||
|     GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, | ||||
|     GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     int64_t indices_ne_ub[4]; | ||||
|     size_t indices_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(indices_ne_gm, indices_ne_ub, 32); | ||||
|     copy_to_ub(indices_nb_gm, indices_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     GET_ROW_Q4_0 op; | ||||
|     op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub, | ||||
|             indices_nb_ub, output_ne_ub, output_nb_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
							
								
								
									
ggml/src/ggml-cann/kernels/get_row_q8_0.cpp (new file, 191 lines)
							|  | @ -0,0 +1,191 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| // TODO: optimize me. Use a template to avoid duplicating this code across types.
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| 
 | ||||
| #define QK8_0 32 | ||||
| 
 | ||||
| class GET_ROW_Q8_0 { | ||||
|    public: | ||||
|     __aicore__ inline GET_ROW_Q8_0() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, int64_t *indices_ne_ub, | ||||
|                                 size_t *indices_nb_ub, int64_t *output_ne_ub, | ||||
|                                 size_t *output_nb_ub) { | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             indices_ne[i] = indices_ne_ub[i]; | ||||
|             indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0]; | ||||
|             scale_ne[i] = input_ne_ub[i]; | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|             output_stride[i] = output_nb_ub[i] / output_nb_ub[0]; | ||||
|         } | ||||
| 
 | ||||
|         // one scale for a group.
 | ||||
|         scale_ne[0] /= QK8_0; | ||||
| 
 | ||||
|         input_stride[0] = 1; | ||||
|         scale_stride[0] = 1; | ||||
|         output_stride[0] = 1; | ||||
|         for (int i = 1; i < 4; i++) { | ||||
|             input_stride[i] = input_stride[i - 1] * input_ne[i - 1]; | ||||
|             scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         group_size_in_row = input_ne[0] / QK8_0; | ||||
|         int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] * | ||||
|                                input_ne[3] * sizeof(int8_t); | ||||
| 
 | ||||
| // The indices tensor has two meaningful dims; n_elements = total number of rows to fetch.
 | ||||
| // dr = number of rows this block should fetch.
 | ||||
|         uint64_t n_elements = | ||||
|             indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3]; | ||||
|         dr = n_elements / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = n_elements % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ int8_t *)input); | ||||
|         scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset)); | ||||
|         indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices); | ||||
|         output_gm.SetGlobalBuffer((__gm__ float *)output); | ||||
| 
 | ||||
|         pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); | ||||
|         pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half)); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float)); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset) { | ||||
|         LocalTensor<int8_t> input_local = input_queue.AllocTensor<int8_t>(); | ||||
|         DataCopy(input_local, input_gm[offset], QK8_0); | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset) { | ||||
|         LocalTensor<float> output_local = output_queue.DeQue<float>(); | ||||
|         DataCopy(output_gm[offset], output_local, QK8_0); | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate_group(int64_t idx, int64_t group) { | ||||
|         const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]); | ||||
|         const int64_t indices_ne1_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) / | ||||
|             indices_ne[0]; | ||||
|         const int64_t indices_ne0_idx = | ||||
|             (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] - | ||||
|              indices_ne1_idx * indices_ne[0]); | ||||
| 
 | ||||
|         const int64_t indices_offset = indices_ne0_idx * indices_stride[0] + | ||||
|                                        indices_ne1_idx * indices_stride[1] + | ||||
|                                        indices_ne2_idx * indices_stride[2]; | ||||
|         const int32_t selected_row_idx = indices_gm.GetValue(indices_offset); | ||||
| 
 | ||||
|         const int64_t input_offset = selected_row_idx * input_stride[1] + | ||||
|                                      indices_ne1_idx * input_stride[2] + | ||||
|                                      indices_ne2_idx * input_stride[3] + | ||||
|                                      group * QK8_0; | ||||
|         const int64_t scale_offset = selected_row_idx * scale_stride[1] + | ||||
|                                      indices_ne1_idx * scale_stride[2] + | ||||
|                                      indices_ne2_idx * scale_stride[3] + group; | ||||
|         const int64_t output_offset = indices_ne0_idx * output_stride[1] + | ||||
|                                       indices_ne1_idx * output_stride[2] + | ||||
|                                       indices_ne2_idx * output_stride[3] + | ||||
|                                       group * QK8_0; | ||||
| 
 | ||||
|         copy_in(input_offset); | ||||
|         LocalTensor<int8_t> input_local = input_queue.DeQue<int8_t>(); | ||||
|         LocalTensor<half> cast_local = cast_queue.AllocTensor<half>(); | ||||
|         LocalTensor<float> output_local = output_queue.AllocTensor<float>(); | ||||
| 
 | ||||
| // TODO: cast more data per call to speed this up.
 | ||||
|         Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0); | ||||
|         Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0); | ||||
| 
 | ||||
| // Only the multiply needs to be computed per group, using that group's scale.
 | ||||
|         half scale = scale_gm.GetValue(scale_offset); | ||||
|         Muls(output_local, output_local, (float)scale, QK8_0); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|         cast_queue.FreeTensor(cast_local); | ||||
|         output_queue.EnQue(output_local); | ||||
| 
 | ||||
|         copy_out(output_offset); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             for (int64_t j = 0; j < group_size_in_row; j++) { | ||||
|                 calculate_group(i, j); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t scale_ne[4]; | ||||
|     size_t scale_stride[4]; | ||||
| 
 | ||||
|     int64_t indices_ne[4]; | ||||
|     size_t indices_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     int64_t group_size_in_row; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<int8_t> input_gm; | ||||
|     GlobalTensor<half> scale_gm; | ||||
|     GlobalTensor<int32_t> indices_gm; | ||||
|     GlobalTensor<float> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_get_row_q8_0( | ||||
|     GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, | ||||
|     GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, | ||||
|     GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     int64_t indices_ne_ub[4]; | ||||
|     size_t indices_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
|     size_t output_nb_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(indices_ne_gm, indices_ne_ub, 32); | ||||
|     copy_to_ub(indices_nb_gm, indices_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
|     copy_to_ub(output_nb_gm, output_nb_ub, 32); | ||||
| 
 | ||||
|     GET_ROW_Q8_0 op; | ||||
|     op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub, | ||||
|             indices_nb_ub, output_ne_ub, output_nb_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
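Both quantized get-row kernels address the source tensor as one flat run of quantized values followed by one half-precision scale per 32-element group: the Q8_0 variant places the scales at a byte offset of n_elems * sizeof(int8_t), the Q4_0 variant at n_elems / 2. Dequantization is then simply q * scale. A minimal host-side reference for one Q8_0 group under that same split layout; this describes how the kernel addresses its input, not ggml's general block format, and the scale is passed already converted to float:

    #include <cstdint>

    #define QK8_0 32

    // Reference dequantization of one Q8_0 group, assuming the split layout
    // used by the CANN kernel: all int8 values first, then one scale per group.
    static void dequantize_q8_0_group(const int8_t * q, float scale, float * out) {
        for (int i = 0; i < QK8_0; ++i) {
            out[i] = (float) q[i] * scale;  // mirrors the Cast + Muls in calculate_group()
        }
    }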
							
								
								
									
ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp (new file, 208 lines)
							|  | @ -0,0 +1,208 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| #define QK8_0 32 | ||||
| 
 | ||||
| class QUANTIZE_F16_Q8_0 { | ||||
|    public: | ||||
|     __aicore__ inline QUANTIZE_F16_Q8_0() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, size_t *input_nb_ub, | ||||
|                                 int64_t *output_ne_ub) { | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; | ||||
| 
 | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|         } | ||||
| 
 | ||||
|         output_stride[0] = 1; | ||||
|         for (int i = 1; i < 4; i++) { | ||||
|             output_stride[i] = output_stride[i - 1] * output_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         scale_ne = input_ne; | ||||
|         scale_stride[0] = 1; | ||||
|         scale_stride[1] = input_ne[0] / QK8_0; | ||||
|         for (int i = 2; i < 4; i++) { | ||||
|             scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         // split input tensor by rows.
 | ||||
|         uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3]; | ||||
|         dr = nr / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = nr % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
| 
 | ||||
|         group_size_in_row = scale_stride[1]; | ||||
|         int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] * | ||||
|                               output_ne[3] * sizeof(uint8_t); | ||||
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ half *)input); | ||||
|         output_gm.SetGlobalBuffer((__gm__ int8_t *)output); | ||||
|         scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir * | ||||
|                                                  group_size_in_row * | ||||
|                                                  sizeof(half))); | ||||
| 
 | ||||
|         pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half)); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); | ||||
|         pipe.InitBuffer(work_queue, 1, 32); | ||||
|         pipe.InitBuffer(max_queue, 1, 32); | ||||
|         pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float)); | ||||
|         pipe.InitBuffer(scale_queue, 1, 32); | ||||
|         pipe.InitBuffer(cast_queue ,1 ,QK8_0 * sizeof(float)); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset) { | ||||
|         LocalTensor<half> input_local = input_queue.AllocTensor<half>(); | ||||
|         DataCopy(input_local, input_gm[offset], QK8_0); | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset) { | ||||
|         LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>(); | ||||
|         DataCopy(output_gm[offset], output_local, QK8_0); | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline half calculate_group(int64_t row, int64_t group) { | ||||
|         const int64_t i3 = row / (input_ne[1] * input_ne[2]); | ||||
|         const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1]; | ||||
|         const int64_t i1 = | ||||
|             row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1]; | ||||
| 
 | ||||
|         const int64_t input_offset = i1 * input_stride[1] + | ||||
|                                      i2 * input_stride[2] + | ||||
|                                      i3 * input_stride[3] + QK8_0 * group; | ||||
| 
 | ||||
|         const int64_t output_offset = i1 * output_stride[1] + | ||||
|                                       i2 * output_stride[2] + | ||||
|                                       i3 * output_stride[3] + QK8_0 * group; | ||||
| 
 | ||||
|         copy_in(input_offset); | ||||
|         LocalTensor<half> input_local = input_queue.DeQue<half>(); | ||||
|         LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>(); | ||||
|         LocalTensor<float> work_local = work_queue.AllocTensor<float>(); | ||||
|         LocalTensor<float> abs_local = abs_queue.AllocTensor<float>(); | ||||
|         LocalTensor<float> max_local = max_queue.AllocTensor<float>(); | ||||
|         LocalTensor<float> cast_local = cast_queue.AllocTensor<float>(); | ||||
| 
 | ||||
|         Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0); | ||||
|         Abs(abs_local, cast_local, QK8_0); | ||||
|         ReduceMax(max_local, abs_local, work_local, QK8_0); | ||||
| 
 | ||||
|         pipe_barrier(PIPE_ALL); | ||||
|         float d = max_local.GetValue(0); | ||||
|         d = d / ((1 << 7) - 1); | ||||
|         if (d != 0) { | ||||
|             Muls(cast_local, cast_local, 1.0f / d, QK8_0); | ||||
|         } | ||||
| 
 | ||||
|         Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         output_queue.EnQue(output_local); | ||||
|         copy_out(output_offset); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|         work_queue.FreeTensor(work_local); | ||||
|         abs_queue.FreeTensor(abs_local); | ||||
|         max_queue.FreeTensor(max_local); | ||||
|         cast_queue.FreeTensor(cast_local); | ||||
|         return (half)d; | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         LocalTensor<half> scale_local = scale_queue.AllocTensor<half>(); | ||||
|         uint32_t scale_local_offset = 0; | ||||
|         uint32_t scale_global_offset = 0; | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             for (int64_t j = 0; j < group_size_in_row; j++) { | ||||
|                 half scale = calculate_group(i, j); | ||||
|                 scale_local.SetValue(scale_local_offset++, scale); | ||||
|                 if (scale_local_offset == 16) { | ||||
|                     scale_local_offset = 0; | ||||
|                     // TODO: OPTIMIZE ME
 | ||||
|                     pipe_barrier(PIPE_ALL); | ||||
|                     DataCopy(scale_gm[scale_global_offset], scale_local, 16); | ||||
|                     pipe_barrier(PIPE_ALL); | ||||
|                     scale_global_offset += 16; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if (scale_local_offset != 0) { | ||||
|             pipe_barrier(PIPE_ALL); | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = scale_local_offset * sizeof(half); | ||||
|             DataCopyPad(scale_gm[scale_global_offset], scale_local, | ||||
|                         dataCopyParams); | ||||
|             pipe_barrier(PIPE_ALL); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t *scale_ne; | ||||
|     size_t scale_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     int64_t group_size_in_row; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<half> input_gm; | ||||
|     GlobalTensor<half> scale_gm; | ||||
|     GlobalTensor<int8_t> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
|     TQue<QuePosition::VECIN, 1> work_queue; | ||||
|     TQue<QuePosition::VECOUT, 1> max_queue; | ||||
|     TQue<QuePosition::VECIN, 1> abs_queue; | ||||
|     TQue<QuePosition::VECOUT, 1> scale_queue; | ||||
|     TQue<QuePosition::VECOUT, 1> cast_queue; | ||||
| 
 | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0( | ||||
|     GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, | ||||
|     GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
| 
 | ||||
|     QUANTIZE_F16_Q8_0 op; | ||||
|     op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
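In both quantize kernels the int8 output for the whole tensor is written first and the half-precision scales follow it, with each block starting its scale writes right after the scales of the rows owned by earlier blocks (the output + output_size + ir * group_size_in_row * sizeof(half) expression in init()). A host-side sketch of that offset calculation; fp16 is represented as a 2-byte value since standard C++ has no half type:

    #include <cstddef>
    #include <cstdint>

    // Byte offset of the first scale written by the block whose first row is `first_row`.
    // Layout assumed, as in the kernel: int8 data for all n_elems values, then one
    // 2-byte fp16 scale per 32-element group, ordered row by row.
    static size_t scale_byte_offset(int64_t n_elems, int64_t groups_per_row,
                                    int64_t first_row) {
        const size_t output_size = (size_t) n_elems * sizeof(int8_t);
        return output_size + (size_t) (first_row * groups_per_row) * 2 /* sizeof(fp16) */;
    }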
							
								
								
									
ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp (new file, 206 lines)
							|  | @ -0,0 +1,206 @@ | |||
| #include "kernel_operator.h" | ||||
| 
 | ||||
| using namespace AscendC; | ||||
| 
 | ||||
| #define BUFFER_NUM 2 | ||||
| #define QK8_0 32 | ||||
| 
 | ||||
| class QUANTIZE_F32_Q8_0 { | ||||
|    public: | ||||
|     __aicore__ inline QUANTIZE_F32_Q8_0() {} | ||||
|     __aicore__ inline void init(GM_ADDR input, GM_ADDR output, | ||||
|                                 int64_t *input_ne_ub, size_t *input_nb_ub, | ||||
|                                 int64_t *output_ne_ub) { | ||||
|         int64_t op_block_num = GetBlockNum(); | ||||
|         int64_t op_block_idx = GetBlockIdx(); | ||||
| 
 | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|             input_ne[i] = input_ne_ub[i]; | ||||
|             input_stride[i] = input_nb_ub[i] / input_nb_ub[0]; | ||||
| 
 | ||||
|             output_ne[i] = output_ne_ub[i]; | ||||
|         } | ||||
| 
 | ||||
|         output_stride[0] = 1; | ||||
|         for (int i = 1; i < 4; i++) { | ||||
|             output_stride[i] = output_stride[i - 1] * output_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         scale_ne = input_ne; | ||||
|         scale_stride[0] = 1; | ||||
|         scale_stride[1] = input_ne[0] / QK8_0; | ||||
|         for (int i = 2; i < 4; i++) { | ||||
|             scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1]; | ||||
|         } | ||||
| 
 | ||||
|         // split input tensor by rows.
 | ||||
|         uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3]; | ||||
|         dr = nr / op_block_num; | ||||
| 
 | ||||
|         uint64_t tails = nr % op_block_num; | ||||
|         if (op_block_idx < tails) { | ||||
|             dr += 1; | ||||
|             ir = dr * op_block_idx; | ||||
|         } else { | ||||
|             ir = dr * op_block_idx + tails; | ||||
|         } | ||||
| 
 | ||||
|         group_size_in_row = scale_stride[1]; | ||||
|         int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] * | ||||
|                               output_ne[3] * sizeof(uint8_t); | ||||
| 
 | ||||
|         input_gm.SetGlobalBuffer((__gm__ float *)input); | ||||
|         output_gm.SetGlobalBuffer((__gm__ int8_t *)output); | ||||
|         scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + | ||||
|                                                  ir * group_size_in_row * | ||||
|                                                  sizeof(half))); | ||||
| 
 | ||||
|         pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float)); | ||||
|         pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t)); | ||||
|         pipe.InitBuffer(work_queue, 1, 32); | ||||
|         pipe.InitBuffer(max_queue, 1, 32); | ||||
|         pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float)); | ||||
|         pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half)); | ||||
|         pipe.InitBuffer(scale_queue, 1, 32); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_in(uint32_t offset) { | ||||
|         LocalTensor<float> input_local = input_queue.AllocTensor<float>(); | ||||
|         DataCopy(input_local, input_gm[offset], QK8_0); | ||||
|         input_queue.EnQue(input_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void copy_out(uint32_t offset) { | ||||
|         LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>(); | ||||
|         DataCopy(output_gm[offset], output_local, QK8_0); | ||||
|         output_queue.FreeTensor(output_local); | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline half calculate_group(int64_t row, int64_t group) { | ||||
|         const int64_t i3 = row / (input_ne[1] * input_ne[2]); | ||||
|         const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1]; | ||||
|         const int64_t i1 = | ||||
|             row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1]; | ||||
| 
 | ||||
|         const int64_t input_offset = i1 * input_stride[1] + | ||||
|                                      i2 * input_stride[2] + | ||||
|                                      i3 * input_stride[3] + QK8_0 * group; | ||||
| 
 | ||||
|         const int64_t output_offset = i1 * output_stride[1] + | ||||
|                                       i2 * output_stride[2] + | ||||
|                                       i3 * output_stride[3] + QK8_0 * group; | ||||
| 
 | ||||
|         copy_in(input_offset); | ||||
|         LocalTensor<float> input_local = input_queue.DeQue<float>(); | ||||
|         LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>(); | ||||
|         LocalTensor<float> work_local = work_queue.AllocTensor<float>(); | ||||
|         LocalTensor<float> abs_local = abs_queue.AllocTensor<float>(); | ||||
|         LocalTensor<float> max_local = max_queue.AllocTensor<float>(); | ||||
|         LocalTensor<half> cast_local = cast_queue.AllocTensor<half>(); | ||||
| 
 | ||||
|         Abs(abs_local, input_local, QK8_0); | ||||
|         ReduceMax(max_local, abs_local, work_local, QK8_0); | ||||
|         pipe_barrier(PIPE_ALL); | ||||
|         float d = max_local.GetValue(0); | ||||
|         d = d / ((1 << 7) - 1); | ||||
|         if (d != 0) { | ||||
|             Muls(input_local, input_local, 1.0f / d, QK8_0); | ||||
|         } | ||||
| 
 | ||||
|         Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0); | ||||
|         output_queue.EnQue(output_local); | ||||
|         copy_out(output_offset); | ||||
| 
 | ||||
|         input_queue.FreeTensor(input_local); | ||||
|         work_queue.FreeTensor(work_local); | ||||
|         abs_queue.FreeTensor(abs_local); | ||||
|         max_queue.FreeTensor(max_local); | ||||
|         cast_queue.FreeTensor(cast_local); | ||||
| 
 | ||||
|         return (half)d; | ||||
|     } | ||||
| 
 | ||||
|     __aicore__ inline void calculate() { | ||||
|         LocalTensor<half> scale_local = scale_queue.AllocTensor<half>(); | ||||
|         uint32_t scale_local_offset = 0; | ||||
|         uint32_t scale_global_offset = 0; | ||||
|         for (int64_t i = ir; i < ir + dr; i++) { | ||||
|             for (int64_t j = 0; j < group_size_in_row; j++) { | ||||
|                 half scale = calculate_group(i, j); | ||||
|                 scale_local.SetValue(scale_local_offset++, scale); | ||||
|                 if (scale_local_offset == 16) { | ||||
|                     scale_local_offset = 0; | ||||
|                     // TODO: OPTIMIZE ME
 | ||||
|                     pipe_barrier(PIPE_ALL); | ||||
|                     DataCopy(scale_gm[scale_global_offset], scale_local, 16); | ||||
|                     pipe_barrier(PIPE_ALL); | ||||
|                     scale_global_offset += 16; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if (scale_local_offset != 0) { | ||||
|             pipe_barrier(PIPE_ALL); | ||||
|             DataCopyExtParams dataCopyParams; | ||||
|             dataCopyParams.blockCount = 1; | ||||
|             dataCopyParams.blockLen = scale_local_offset * sizeof(half); | ||||
|             DataCopyPad(scale_gm[scale_global_offset], scale_local, | ||||
|                         dataCopyParams); | ||||
|             pipe_barrier(PIPE_ALL); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|    private: | ||||
|     int64_t input_ne[4]; | ||||
|     size_t input_stride[4]; | ||||
| 
 | ||||
|     int64_t *scale_ne; | ||||
|     size_t scale_stride[4]; | ||||
| 
 | ||||
|     int64_t output_ne[4]; | ||||
|     size_t output_stride[4]; | ||||
| 
 | ||||
|     int64_t group_size_in_row; | ||||
| 
 | ||||
|     int64_t ir; | ||||
|     int64_t dr; | ||||
| 
 | ||||
|     TPipe pipe; | ||||
|     GlobalTensor<float> input_gm; | ||||
|     GlobalTensor<half> scale_gm; | ||||
|     GlobalTensor<int8_t> output_gm; | ||||
|     TQue<QuePosition::VECIN, BUFFER_NUM> input_queue; | ||||
|     TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue; | ||||
|     TQue<QuePosition::VECIN, 1> work_queue; | ||||
|     TQue<QuePosition::VECOUT, 1> max_queue; | ||||
|     TQue<QuePosition::VECIN, 1> abs_queue; | ||||
|     TQue<QuePosition::VECIN, 1> cast_queue; | ||||
|     TQue<QuePosition::VECOUT, 1> scale_queue; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) { | ||||
|     auto gm_ptr = (__gm__ uint8_t *)gm; | ||||
|     auto ub_ptr = (uint8_t *)(ub); | ||||
|     for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) { | ||||
|         *ub_ptr = *gm_ptr; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0( | ||||
|     GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, | ||||
|     GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { | ||||
|     int64_t input_ne_ub[4]; | ||||
|     size_t input_nb_ub[4]; | ||||
|     int64_t output_ne_ub[4]; | ||||
| 
 | ||||
|     copy_to_ub(input_ne_gm, input_ne_ub, 32); | ||||
|     copy_to_ub(input_nb_gm, input_nb_ub, 32); | ||||
|     copy_to_ub(output_ne_gm, output_ne_ub, 32); | ||||
| 
 | ||||
|     QUANTIZE_F32_Q8_0 op; | ||||
|     op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); | ||||
|     op.calculate(); | ||||
| } | ||||
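The per-group math in both quantize kernels is the standard Q8_0 scheme: the scale is the group's maximum absolute value divided by 127, and each element is divided by the scale and rounded to int8 (all-zero groups quantize to zeros). A host-side reference of the same computation for a float input group; the exact behaviour of the kernel's CAST_ROUND at ties may differ slightly from std::round:

    #include <cmath>
    #include <cstdint>

    #define QK8_0 32

    // Reference for one group of ascendc_quantize_f32_q8_0:
    // d = max|x| / 127, q_i = round(x_i / d), and d is returned for the scale table.
    static float quantize_q8_0_group(const float * x, int8_t * q) {
        float amax = 0.0f;
        for (int i = 0; i < QK8_0; ++i) {
            amax = std::fmax(amax, std::fabs(x[i]));
        }
        const float d = amax / 127.0f;       // ((1 << 7) - 1) in the kernel
        for (int i = 0; i < QK8_0; ++i) {
            q[i] = (int8_t) (d != 0.0f ? std::round(x[i] / d) : 0.0f);
        }
        return d;                            // stored as fp16 in the scale region
    }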
|  | @ -3341,7 +3341,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso | |||
| } | ||||
| 
 | ||||
| // check if t1 can be represented as a repetition of t0
 | ||||
| static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { | ||||
| bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { | ||||
|     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); | ||||
| 
 | ||||
|     return ggml_is_empty(t0) ? ggml_is_empty(t1) : | ||||
|  | @ -13699,6 +13699,7 @@ static void ggml_compute_forward_soft_max( | |||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // ggml_compute_forward_soft_max_back
 | ||||
| 
 | ||||
| static void ggml_compute_forward_soft_max_back_f32( | ||||
|  | @ -21995,6 +21996,14 @@ int ggml_cpu_has_rpc(void) { | |||
| #endif | ||||
| } | ||||
| 
 | ||||
| int ggml_cpu_has_cann(void) { | ||||
| #if defined(GGML_USE_CANN) | ||||
|     return 1; | ||||
| #else | ||||
|     return 0; | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| int ggml_cpu_has_gpublas(void) { | ||||
|     return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl(); | ||||
| } | ||||
|  |  | |||
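The new ggml_cpu_has_cann() follows the pattern of the other ggml_cpu_has_* feature checks, so callers can test at runtime whether ggml was built with the CANN backend. A minimal usage sketch, assuming the matching declaration is exposed in ggml.h like the other feature checks:

    #include <cstdio>
    #include "ggml.h"

    int main() {
        // Prints 1 when ggml was compiled with GGML_USE_CANN, 0 otherwise.
        std::printf("CANN support: %d\n", ggml_cpu_has_cann());
        return 0;
    }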
|  | @ -19,6 +19,8 @@ | |||
| #  include "ggml-sycl.h" | ||||
| #elif defined(GGML_USE_KOMPUTE) | ||||
| #   include "ggml-kompute.h" | ||||
| #elif defined(GGML_USE_CANN) | ||||
| #   include "ggml-cann.h" | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_BLAS | ||||
|  | @ -2079,6 +2081,8 @@ struct llama_state { | |||
|         ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data); | ||||
| #elif defined(GGML_USE_CUDA) | ||||
|         ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data); | ||||
| #elif defined(GGML_USE_CANN) | ||||
|         ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data); | ||||
| #endif | ||||
|     } | ||||
| 
 | ||||
|  | @ -2889,6 +2893,8 @@ static size_t llama_get_device_count(const llama_model & model) { | |||
|     count = ggml_backend_sycl_get_device_count(); | ||||
| #elif defined(GGML_USE_VULKAN) | ||||
|     count = ggml_backend_vk_get_device_count(); | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     return ggml_backend_cann_get_device_count(); | ||||
| #endif | ||||
| #if defined(GGML_USE_RPC) | ||||
|     count += model.rpc_servers.size(); | ||||
|  | @ -2921,6 +2927,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_ | |||
|     if (buft == nullptr) { | ||||
|         LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); | ||||
|     } | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     buft = ggml_backend_cann_buffer_type(gpu); | ||||
| #endif | ||||
| 
 | ||||
|     if (buft == nullptr) { | ||||
|  | @ -2981,6 +2989,11 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { | |||
|     size_t free; | ||||
|     ggml_backend_vk_get_device_memory(device, &free, &total); | ||||
|     return free; | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     size_t total; | ||||
|     size_t free; | ||||
|     ggml_backend_cann_get_device_memory(device, &total, &free); | ||||
|     return free; | ||||
| #else | ||||
|     return 1; | ||||
| #endif | ||||
|  | @ -18871,6 +18884,8 @@ size_t llama_max_devices(void) { | |||
|     return GGML_SYCL_MAX_DEVICES; | ||||
| #elif defined(GGML_USE_VULKAN) | ||||
|     return GGML_VK_MAX_DEVICES; | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     return GGML_CANN_MAX_DEVICES; | ||||
| #else | ||||
|     return 1; | ||||
| #endif | ||||
|  | @ -19212,6 +19227,30 @@ struct llama_context * llama_new_context_with_model( | |||
|             } | ||||
|             ctx->backends.push_back(backend); | ||||
|         } | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
 | ||||
| // TODO: ggml_backend_cann does not support split tensors yet; keep this code for later.
 | ||||
|     if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) { | ||||
|         ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu); | ||||
|         if (backend == nullptr) { | ||||
|             LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu); | ||||
|             llama_free(ctx); | ||||
|             return nullptr; | ||||
|         } | ||||
|         ctx->backends.push_back(backend); | ||||
|     } else { | ||||
|         // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
 | ||||
| // TODO: CANN cannot use multiple GPUs yet; keep this code for a future CANN version.
 | ||||
|         for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) { | ||||
|             ggml_backend_t backend = ggml_backend_cann_init(device); | ||||
|             if (backend == nullptr) { | ||||
|                 LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device); | ||||
|                 llama_free(ctx); | ||||
|                 return nullptr; | ||||
|             } | ||||
|             ctx->backends.push_back(backend); | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
| #ifdef GGML_USE_BLAS | ||||
|  | @ -21789,6 +21828,8 @@ void llama_log_set(ggml_log_callback log_callback, void * user_data) { | |||
|     ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); | ||||
| #elif defined(GGML_USE_CUDA) | ||||
|     ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); | ||||
| #elif defined(GGML_USE_CANN) | ||||
|     ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
|  |  | |||
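Outside of llama.cpp's own context setup, the same CANN entry points added by this patch can be used directly: enumerate devices, query their memory, and initialize a backend for device 0. A rough sketch built only from functions visible in this diff plus ggml_backend_free from the generic backend API; real code would go on to build and run a graph on the backend:

    #include <cstdint>
    #include <cstdio>
    #include "ggml-backend.h"
    #include "ggml-cann.h"

    int main() {
        const int32_t n_dev = ggml_backend_cann_get_device_count();
        std::printf("CANN devices: %d\n", n_dev);

        for (int32_t i = 0; i < n_dev; ++i) {
            size_t total = 0, free = 0;
            ggml_backend_cann_get_device_memory(i, &total, &free);
            std::printf("  device %d: %zu of %zu bytes free\n", i, free, total);
        }

        if (n_dev > 0) {
            ggml_backend_t backend = ggml_backend_cann_init(0);
            if (backend != nullptr) {
                // ... allocate buffers and compute a ggml graph here ...
                ggml_backend_free(backend);
            }
        }
        return 0;
    }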
|  | @ -759,7 +759,7 @@ struct test_dup : public test_case { | |||
|     } | ||||
| 
 | ||||
|     test_dup(ggml_type type = GGML_TYPE_F32, | ||||
|             std::array<int64_t, 4> ne = {10, 10, 10, 1}, | ||||
|             std::array<int64_t, 4> ne = {10, 10, 20, 1}, | ||||
|             std::array<int64_t, 4> permute = {0, 0, 0, 0}) | ||||
|         : type(type), ne(ne), permute(permute), | ||||
|             _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} | ||||
|  | @ -779,9 +779,11 @@ struct test_cpy : public test_case { | |||
|     const ggml_type type_src; | ||||
|     const ggml_type type_dst; | ||||
|     const std::array<int64_t, 4> ne; | ||||
|     const std::array<int64_t, 4> permute; | ||||
|     bool _src_use_permute; | ||||
| 
 | ||||
|     std::string vars() override { | ||||
|         return VARS_TO_STR3(type_src, type_dst, ne); | ||||
|         return VARS_TO_STR4(type_src, type_dst, ne, permute); | ||||
|     } | ||||
| 
 | ||||
|     double max_nmse_err() override { | ||||
|  | @ -793,12 +795,18 @@ struct test_cpy : public test_case { | |||
|     } | ||||
| 
 | ||||
|     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, | ||||
|             std::array<int64_t, 4> ne = {10, 10, 10, 1}) | ||||
|         : type_src(type_src), type_dst(type_dst), ne(ne) {} | ||||
|             std::array<int64_t, 4> ne = {10, 10, 10, 1}, | ||||
|             std::array<int64_t, 4> permute = {0, 0, 0, 0}, | ||||
|             bool _dst_use_permute = false) | ||||
|         : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute), | ||||
|           _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {} | ||||
| 
 | ||||
|     ggml_tensor * build_graph(ggml_context * ctx) override { | ||||
|         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); | ||||
|         ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne.data()); | ||||
|         if (_src_use_permute) { | ||||
|             src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]); | ||||
|         } | ||||
|         ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne); | ||||
|         ggml_tensor * out = ggml_cpy(ctx, src, dst); | ||||
|         return out; | ||||
|     } | ||||
|  | @ -1174,6 +1182,7 @@ struct test_soft_max : public test_case { | |||
|     } | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| // GGML_OP_ROPE
 | ||||
| struct test_rope : public test_case { | ||||
|     const ggml_type type; | ||||
|  | @ -2146,12 +2155,22 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | |||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_I32)); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_I16)); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {0, 2, 1, 3})); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
 | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3})); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
 | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3})); | ||||
|     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); | ||||
| 
 | ||||
|     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { | ||||
|         for (ggml_type type_dst : all_types) { | ||||
|            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); | ||||
|            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
 | ||||
|         } | ||||
|     } | ||||
|     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { | ||||
|         for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) { | ||||
|             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -2283,7 +2302,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | |||
|         for (int n = 0; n < 10; ++n) { | ||||
|             int64_t ne0 = dist_ne0(rng); | ||||
|             int64_t ne1 = dist_ne1(rng); | ||||
|             test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f)); | ||||
|             test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f)); | ||||
|         } | ||||
| 
 | ||||
|         exponent <<= 1; | ||||
|  | @ -2302,7 +2321,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | |||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f)); | ||||
|     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); | ||||
|     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f)); | ||||
|     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f)); | ||||
|  |  | |||