diff --git a/bin/hipify-perl b/bin/hipify-perl index 525532eb..487e21ae 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -10010,12 +10010,31 @@ sub transformCubNamespace { } sub transformHostFunctions { + my $process_args = sub { + my ($f, $args_str, $argIdx, $actionType, $cast) = @_; + my $args_inner = substr($args_str, 1, -1); + my (@a, $d, $cur); + foreach my $s (split //, $args_inner) { + if ($s =~ /[\(\[\<\{]/) { $d++ } elsif ($s =~ /[\)\]\>\}]/) { $d-- } + if ($s eq ',' && $d == 0) { push @a, $cur; $cur = '' } else { $cur .= $s } + } push @a, $cur; + if (defined $a[$argIdx]) { + my $v = $a[$argIdx]; $v =~ s/^\s+|\s+$//g; + if ($actionType == 0 || $actionType == 1) { + $a[$argIdx] = "$cast($v)" unless index($v, $cast) == 0; + } elsif ($actionType == 2) { + splice(@a, $argIdx, 1); + } + } + return "$f(" . join(',', @a) . ")"; + }; + foreach $func ( "hipMemcpyToSymbol", "hipMemcpyToSymbolAsync" ) { - s/(?($1, $2, 0, 0, "HIP_SYMBOL") }ges; } foreach $func ( "hipGetSymbolAddress", @@ -10026,7 +10045,7 @@ sub transformHostFunctions { "hipMemcpyFromSymbolAsync" ) { - s/(?($1, $2, 1, 0, "HIP_SYMBOL") }ges; } foreach $func ( "hipFuncSetAttribute", @@ -10036,20 +10055,20 @@ sub transformHostFunctions { "hipLaunchKernel" ) { - s/(?\($2\),/g; + s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 0, 1, "reinterpret_cast") }ges; } foreach $func ( "hipFuncGetAttributes" ) { - s/(?\($4\)$5/g; + s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 1, 1, "reinterpret_cast") }ges; } foreach $func ( "hipGraphExecMemcpyNodeSetParamsToSymbol", "hipGraphMemcpyNodeSetParamsFromSymbol" ) { - s/(?($1, $2, 2, 0, "HIP_SYMBOL") }ges; } foreach $func ( "hipModuleOccupancyMaxPotentialBlockSize", @@ -10057,25 +10076,25 @@ sub transformHostFunctions { "hipModuleOccupancyMaxPotentialBlockSizeWithFlags" ) { - s/(?($1, $2, 3, 2, "") }ges; } foreach $func ( "hipGraphExecMemcpyNodeSetParamsFromSymbol" ) { - s/(?($1, $2, 3, 0, "HIP_SYMBOL") }ges; } foreach $func ( "hipGraphAddMemcpyNodeToSymbol" ) { - s/(?($1, $2, 4, 0, "HIP_SYMBOL") }ges; } foreach $func ( "hipGraphAddMemcpyNodeFromSymbol" ) { - s/(?($1, $2, 5, 0, "HIP_SYMBOL") }ges; } } @@ -13427,6 +13446,9 @@ my %hash_HipOnlyUnsupportedFunctions = ( 'cublasLtGetStatusName' => 1, 'cublasLtGetStatusString' => 1, 'cublasLtGetVersion' => 1, + 'cublasLtGroupedMatrixLayoutCreate' => 1, + 'cublasLtGroupedMatrixLayoutInit' => 1, + 'cublasLtGroupedMatrixLayoutInit_internal' => 1, 'cublasLtHeuristicsCacheGetCapacity' => 1, 'cublasLtHeuristicsCacheSetCapacity' => 1, 'cublasLtIntegerWidth_t' => 1, @@ -15945,6 +15967,9 @@ my %hash_RocOnlyUnsupportedFunctions = ( 'cublasLtGetStatusName' => 1, 'cublasLtGetStatusString' => 1, 'cublasLtGetVersion' => 1, + 'cublasLtGroupedMatrixLayoutCreate' => 1, + 'cublasLtGroupedMatrixLayoutInit' => 1, + 'cublasLtGroupedMatrixLayoutInit_internal' => 1, 'cublasLtHeuristicsCacheGetCapacity' => 1, 'cublasLtHeuristicsCacheSetCapacity' => 1, 'cublasLtIntegerWidth_t' => 1, diff --git a/docs/reference/tables/CUBLAS_API_supported_by_HIP.md b/docs/reference/tables/CUBLAS_API_supported_by_HIP.md index 7db6492c..afc836b3 100644 --- a/docs/reference/tables/CUBLAS_API_supported_by_HIP.md +++ b/docs/reference/tables/CUBLAS_API_supported_by_HIP.md @@ -2027,6 +2027,9 @@ |`cublasLtGetStatusName`|11.4| | | | | | | | | | | |`cublasLtGetStatusString`|11.4| | | | | | | | | | | |`cublasLtGetVersion`|10.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | | |`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | | |`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | | |`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | | diff --git a/docs/reference/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/reference/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index 37ef8dc0..8e272d0a 100644 --- a/docs/reference/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/reference/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -2027,6 +2027,9 @@ |`cublasLtGetStatusName`|11.4| | | | | | | | | | | | | | | | | |`cublasLtGetStatusString`|11.4| | | | | | | | | | | | | | | | | |`cublasLtGetVersion`|10.1| | | | | | | | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | | | | | | | | |`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | | | | | | | | |`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | | | | | | | | |`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | | | | | | | | diff --git a/docs/reference/tables/CUBLAS_API_supported_by_ROC.md b/docs/reference/tables/CUBLAS_API_supported_by_ROC.md index b31d1f1e..4b533707 100644 --- a/docs/reference/tables/CUBLAS_API_supported_by_ROC.md +++ b/docs/reference/tables/CUBLAS_API_supported_by_ROC.md @@ -2027,6 +2027,9 @@ |`cublasLtGetStatusName`|11.4| | | | | | | | | | | |`cublasLtGetStatusString`|11.4| | | | | | | | | | | |`cublasLtGetVersion`|10.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | | +|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | | |`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | | |`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | | |`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | | diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 4b3c53ca..0175a6ad 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -1120,6 +1120,9 @@ const std::map CUDA_BLAS_FUNCTION_MAP = [] { m["cublasLtMatrixLayoutSetAttribute"] = {"hipblasLtMatrixLayoutSetAttribute", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED}; m["cublasLtMatrixLayoutGetAttribute"] = {"hipblasLtMatrixLayoutGetAttribute", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED}; m["cublasLtMatmulDescInit"] = {"hipblasLtMatmulDescInit", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}; + m["cublasLtGroupedMatrixLayoutInit_internal"] = {"hipblasLtGroupedMatrixLayoutInit_internal", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}; + m["cublasLtGroupedMatrixLayoutInit"] = {"hipblasLtGroupedMatrixLayoutInit", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}; + m["cublasLtGroupedMatrixLayoutCreate"] = {"hipblasLtGroupedMatrixLayoutCreate", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED}; // [hipBLASLt] TODO: Use hipblasComputeType_t instead of incompatible hipblasLtComputeType_t // [HIPIFY] TODO: For CUDA < 11.0 throw an error cublasLtMatmulDescCreate is not supported by HIP, please use the newer version of cublasLtMatmulDescCreate (>=11.0) @@ -1674,6 +1677,9 @@ const std::map CUDA_BLAS_FUNCTION_VER_MAP = [] m["cublasSetFixedPointEmulationMantissaBitOffset"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100 m["cublasGetFixedPointEmulationMantissaBitCountPointer"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100 m["cublasSetFixedPointEmulationMantissaBitCountPointer"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100 + m["cublasLtGroupedMatrixLayoutInit_internal"] = {CUDA_131, CUDA_0, CUDA_0 }; + m["cublasLtGroupedMatrixLayoutInit"] = {CUDA_131, CUDA_0, CUDA_0 }; + m["cublasLtGroupedMatrixLayoutCreate"] = {CUDA_131, CUDA_0, CUDA_0 }; return m; }(); diff --git a/src/CUDA2HIP_Perl.cpp b/src/CUDA2HIP_Perl.cpp index 508465b0..9aa6a149 100644 --- a/src/CUDA2HIP_Perl.cpp +++ b/src/CUDA2HIP_Perl.cpp @@ -450,89 +450,83 @@ namespace perl { void generateHostFunctions(ostream &out) { out << endl << sub << "transformHostFunctions" << " {" << endl; - const string s = "s/(? DeviceSymbolFunctions0; - set DeviceSymbolFunctions1; - set DeviceSymbolFunctions2; - set DeviceSymbolFunctions3; - set DeviceSymbolFunctions4; - set DeviceSymbolFunctions5; - set ReinterpretFunctions0; - set ReinterpretFunctions1; - set RemoveArgFunctions3; + out << tab << "my $process_args = sub {" << endl; + out << tab_2 << "my ($f, $args_str, $argIdx, $actionType, $cast) = @_;" << endl; + out << tab_2 << "my $args_inner = substr($args_str, 1, -1);" << endl; + out << tab_2 << "my (@a, $d, $cur);" << endl; + out << tab_2 << "foreach my $s (split //, $args_inner) {" << endl; + out << tab_3 << "if ($s =~ /[\\(\\[\\<\\{]/) { $d++ } elsif ($s =~ /[\\)\\]\\>\\}]/) { $d-- }" << endl; + out << tab_3 << "if ($s eq ',' && $d == 0) { push @a, $cur; $cur = '' } else { $cur .= $s }" << endl; + out << tab_2 << "} push @a, $cur;" << endl; + out << tab_2 << "if (defined $a[$argIdx]) {" << endl; + out << tab_3 << "my $v = $a[$argIdx]; $v =~ s/^\\s+|\\s+$//g;" << endl; + out << tab_3 << "if ($actionType == 0 || $actionType == 1) {" << endl; + out << tab_4 << "$a[$argIdx] = \"$cast($v)\" unless index($v, $cast) == 0;" << endl; + out << tab_3 << "} elsif ($actionType == 2) {" << endl; + out << tab_4 << "splice(@a, $argIdx, 1);" << endl; + out << tab_3 << "}" << endl; + out << tab_2 << "}" << endl; + out << tab_2 << "return \"$f(\" . join(',', @a) . \")\";" << endl; + out << tab << "};" << endl_2; + set DeviceSymbolFunctions0, DeviceSymbolFunctions1, DeviceSymbolFunctions2; + set DeviceSymbolFunctions3, DeviceSymbolFunctions4, DeviceSymbolFunctions5; + set ReinterpretFunctions0, ReinterpretFunctions1, RemoveArgFunctions3; for (auto f : FuncArgCasts) { auto castStructs = f.second; for (auto cc : castStructs) { for (auto c : cc.castMap) { - switch (c.first) { - case 0: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions0.insert(f.first); break; - case e_reinterpret_cast: ReinterpretFunctions0.insert(f.first); break; - default: break; - } - break; - case 1: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions1.insert(f.first); break; - case e_reinterpret_cast: ReinterpretFunctions1.insert(f.first); break; - default: break; - } - break; - case 2: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions2.insert(f.first); break; - default: break; - } - break; - case 3: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions3.insert(f.first); break; - case e_remove_argument: RemoveArgFunctions3.insert(f.first); break; - default: break; - } - break; - case 4: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions4.insert(f.first); break; + switch (c.first) { + case 0: + switch (c.second.castType) { + case e_HIP_SYMBOL: DeviceSymbolFunctions0.insert(f.first); break; + case e_reinterpret_cast: ReinterpretFunctions0.insert(f.first); break; + default: break; + } break; + case 1: + switch (c.second.castType) { + case e_HIP_SYMBOL: DeviceSymbolFunctions1.insert(f.first); break; + case e_reinterpret_cast: ReinterpretFunctions1.insert(f.first); break; + default: break; + } break; + case 2: + if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions2.insert(f.first); break; + case 3: + switch (c.second.castType) { + case e_HIP_SYMBOL: DeviceSymbolFunctions3.insert(f.first); break; + case e_remove_argument: RemoveArgFunctions3.insert(f.first); break; + default: break; + } break; + case 4: + if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions4.insert(f.first); break; + case 5: + if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions5.insert(f.first); break; default: break; - } - break; - case 5: - switch (c.second.castType) { - case e_HIP_SYMBOL: DeviceSymbolFunctions5.insert(f.first); break; - default: break; - } - break; - default: break; + } } } - } } - set &funcSet = DeviceSymbolFunctions0; + set *funcSet = &DeviceSymbolFunctions0; for (int i = 0; i < 9; ++i) { + int argIdx = 0; + int actionType = 0; + string castStr = ""; switch (i) { - default: funcSet = DeviceSymbolFunctions0; break; - case 1: funcSet = DeviceSymbolFunctions1; break; - case 2: funcSet = ReinterpretFunctions0; break; - case 3: funcSet = ReinterpretFunctions1; break; - case 4: funcSet = DeviceSymbolFunctions2; break; - case 5: funcSet = RemoveArgFunctions3; break; - case 6: funcSet = DeviceSymbolFunctions3; break; - case 7: funcSet = DeviceSymbolFunctions4; break; - case 8: funcSet = DeviceSymbolFunctions5; break; + default: + case 0: funcSet = &DeviceSymbolFunctions0; argIdx = 0; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; + case 1: funcSet = &DeviceSymbolFunctions1; argIdx = 1; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; + case 2: funcSet = &ReinterpretFunctions0; argIdx = 0; actionType = 1; castStr = getCastType(e_reinterpret_cast); break; + case 3: funcSet = &ReinterpretFunctions1; argIdx = 1; actionType = 1; castStr = getCastType(e_reinterpret_cast); break; + case 4: funcSet = &DeviceSymbolFunctions2; argIdx = 2; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; + case 5: funcSet = &RemoveArgFunctions3; argIdx = 3; actionType = 2; break; // No cast needed + case 6: funcSet = &DeviceSymbolFunctions3; argIdx = 3; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; + case 7: funcSet = &DeviceSymbolFunctions4; argIdx = 4; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; + case 8: funcSet = &DeviceSymbolFunctions5; argIdx = 5; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break; } - if (funcSet.empty()) continue; + if (funcSet->empty()) continue; out << tab + foreach_func << endl; unsigned int count = 0; string sHIPName; - for (auto &f : funcSet) { + for (auto &f : *funcSet) { const auto found = CUDA_RUNTIME_FUNCTION_MAP.find(f); if (found != CUDA_RUNTIME_FUNCTION_MAP.end()) sHIPName = found->second.hipName.str(); else { @@ -542,19 +536,9 @@ namespace perl { out << (count ? ",\n" : "") << tab_2 << "\"" << sHIPName << "\""; count++; } - out << endl_tab << ")" << endl_tab << "{" << endl_tab_2; - switch (i) { - case 0: - default: out << s0 << getCastType(e_HIP_SYMBOL) << "\\($2\\),/g;" << endl; break; - case 1: out << s1 << getCastType(e_HIP_SYMBOL) << "\\($4\\)$5/g;" << endl; break; - case 2: out << s0 << getCastType(e_reinterpret_cast) << "\\($2\\),/g;" << endl; break; - case 3: out << s1 << getCastType(e_reinterpret_cast) << "\\($4\\)$5/g;" << endl; break; - case 4: out << s2 << getCastType(e_HIP_SYMBOL) << "\\($5\\)$6/g;" << endl; break; - case 5: out << s3 << "$7/g;" << endl; break; - case 6: out << s3 << ",$5" << getCastType(e_HIP_SYMBOL) << "\\($6\\)$7/g;" << endl; break; - case 7: out << s4 << getCastType(e_HIP_SYMBOL) << "\\($7\\)$8/g;" << endl; break; - case 8: out << s5 << getCastType(e_HIP_SYMBOL) << "\\($8\\)$9/g;" << endl; break; - } + out << endl_tab << ")" << endl_tab << "{" << endl; + out << tab_2 << "s{\\b($func)\\s*(\\((?:[^()]+|(?2))*\\))}{ $process_args->($1, $2, " + << argIdx << ", " << actionType << ", \"" << castStr << "\") }ges;" << endl; out << tab << "}" << endl; } out << "}" << endl_2;