Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -10010,12 +10010,31 @@ sub transformCubNamespace {
}

sub transformHostFunctions {
my $process_args = sub {
my ($f, $args_str, $argIdx, $actionType, $cast) = @_;
my $args_inner = substr($args_str, 1, -1);
my (@a, $d, $cur);
foreach my $s (split //, $args_inner) {
if ($s =~ /[\(\[\<\{]/) { $d++ } elsif ($s =~ /[\)\]\>\}]/) { $d-- }
if ($s eq ',' && $d == 0) { push @a, $cur; $cur = '' } else { $cur .= $s }
} push @a, $cur;
if (defined $a[$argIdx]) {
my $v = $a[$argIdx]; $v =~ s/^\s+|\s+$//g;
if ($actionType == 0 || $actionType == 1) {
$a[$argIdx] = "$cast($v)" unless index($v, $cast) == 0;
} elsif ($actionType == 2) {
splice(@a, $argIdx, 1);
}
}
return "$f(" . join(',', @a) . ")";
};

foreach $func (
"hipMemcpyToSymbol",
"hipMemcpyToSymbolAsync"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),/$func\(HIP_SYMBOL\($2\),/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 0, 0, "HIP_SYMBOL") }ges;
}
foreach $func (
"hipGetSymbolAddress",
Expand All @@ -10026,7 +10045,7 @@ sub transformHostFunctions {
"hipMemcpyFromSymbolAsync"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3HIP_SYMBOL\($4\)$5/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 1, 0, "HIP_SYMBOL") }ges;
}
foreach $func (
"hipFuncSetAttribute",
Expand All @@ -10036,46 +10055,46 @@ sub transformHostFunctions {
"hipLaunchKernel"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),/$func\(reinterpret_cast<const void*>\($2\),/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 0, 1, "reinterpret_cast<const void*>") }ges;
}
foreach $func (
"hipFuncGetAttributes"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3reinterpret_cast<const void*>\($4\)$5/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 1, 1, "reinterpret_cast<const void*>") }ges;
}
foreach $func (
"hipGraphExecMemcpyNodeSetParamsToSymbol",
"hipGraphMemcpyNodeSetParamsFromSymbol"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3,$4HIP_SYMBOL\($5\)$6/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 2, 0, "HIP_SYMBOL") }ges;
}
foreach $func (
"hipModuleOccupancyMaxPotentialBlockSize",
"hipModuleOccupancyMaxPotentialBlockSizeWithFlags",
"hipModuleOccupancyMaxPotentialBlockSizeWithFlags"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([^,\)]+),([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3,$4$7/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 3, 2, "") }ges;
}
foreach $func (
"hipGraphExecMemcpyNodeSetParamsFromSymbol"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([^,\)]+),([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3,$4,$5HIP_SYMBOL\($6\)$7/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 3, 0, "HIP_SYMBOL") }ges;
}
foreach $func (
"hipGraphAddMemcpyNodeToSymbol"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([^,\)]+),([^,\)]+),([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3,$4,$5,$6HIP_SYMBOL\($7\)$8/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 4, 0, "HIP_SYMBOL") }ges;
}
foreach $func (
"hipGraphAddMemcpyNodeFromSymbol"
)
{
s/(?<!\/\/ CHECK: )($func)\s*\(([^,\)]+),([^,\)]+),([^,\)]+),([^,\)]+),([^,\)]+),([\s]*)([^,\)]+)(,\s*|\))/$func\($2,$3,$4,$5,$6,$7HIP_SYMBOL\($8\)$9/g;
s{\b($func)\s*(\((?:[^()]+|(?2))*\))}{ $process_args->($1, $2, 5, 0, "HIP_SYMBOL") }ges;
}
}

Expand Down Expand Up @@ -13427,6 +13446,9 @@ my %hash_HipOnlyUnsupportedFunctions = (
'cublasLtGetStatusName' => 1,
'cublasLtGetStatusString' => 1,
'cublasLtGetVersion' => 1,
'cublasLtGroupedMatrixLayoutCreate' => 1,
'cublasLtGroupedMatrixLayoutInit' => 1,
'cublasLtGroupedMatrixLayoutInit_internal' => 1,
'cublasLtHeuristicsCacheGetCapacity' => 1,
'cublasLtHeuristicsCacheSetCapacity' => 1,
'cublasLtIntegerWidth_t' => 1,
Expand Down Expand Up @@ -15945,6 +15967,9 @@ my %hash_RocOnlyUnsupportedFunctions = (
'cublasLtGetStatusName' => 1,
'cublasLtGetStatusString' => 1,
'cublasLtGetVersion' => 1,
'cublasLtGroupedMatrixLayoutCreate' => 1,
'cublasLtGroupedMatrixLayoutInit' => 1,
'cublasLtGroupedMatrixLayoutInit_internal' => 1,
'cublasLtHeuristicsCacheGetCapacity' => 1,
'cublasLtHeuristicsCacheSetCapacity' => 1,
'cublasLtIntegerWidth_t' => 1,
Expand Down
3 changes: 3 additions & 0 deletions docs/reference/tables/CUBLAS_API_supported_by_HIP.md
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,9 @@
|`cublasLtGetStatusName`|11.4| | | | | | | | | | |
|`cublasLtGetStatusString`|11.4| | | | | | | | | | |
|`cublasLtGetVersion`|10.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | |
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | |
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | |
|`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,9 @@
|`cublasLtGetStatusName`|11.4| | | | | | | | | | | | | | | | |
|`cublasLtGetStatusString`|11.4| | | | | | | | | | | | | | | | |
|`cublasLtGetVersion`|10.1| | | | | | | | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | | | | | | | |
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | | | | | | | |
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | | | | | | | |
|`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | | | | | | | |
Expand Down
3 changes: 3 additions & 0 deletions docs/reference/tables/CUBLAS_API_supported_by_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,9 @@
|`cublasLtGetStatusName`|11.4| | | | | | | | | | |
|`cublasLtGetStatusString`|11.4| | | | | | | | | | |
|`cublasLtGetVersion`|10.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutCreate`|13.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit`|13.1| | | | | | | | | | |
|`cublasLtGroupedMatrixLayoutInit_internal`|13.1| | | | | | | | | | |
|`cublasLtHeuristicsCacheGetCapacity`|11.8| | | | | | | | | | |
|`cublasLtHeuristicsCacheSetCapacity`|11.8| | | | | | | | | | |
|`cublasLtLoggerForceDisable`|11.0| | | | | | | | | | |
Expand Down
6 changes: 6 additions & 0 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,9 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP = [] {
m["cublasLtMatrixLayoutSetAttribute"] = {"hipblasLtMatrixLayoutSetAttribute", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED};
m["cublasLtMatrixLayoutGetAttribute"] = {"hipblasLtMatrixLayoutGetAttribute", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, ROC_UNSUPPORTED};
m["cublasLtMatmulDescInit"] = {"hipblasLtMatmulDescInit", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED};
m["cublasLtGroupedMatrixLayoutInit_internal"] = {"hipblasLtGroupedMatrixLayoutInit_internal", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED};
m["cublasLtGroupedMatrixLayoutInit"] = {"hipblasLtGroupedMatrixLayoutInit", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED};
m["cublasLtGroupedMatrixLayoutCreate"] = {"hipblasLtGroupedMatrixLayoutCreate", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LT, UNSUPPORTED};

// [hipBLASLt] TODO: Use hipblasComputeType_t instead of incompatible hipblasLtComputeType_t
// [HIPIFY] TODO: For CUDA < 11.0 throw an error cublasLtMatmulDescCreate is not supported by HIP, please use the newer version of cublasLtMatmulDescCreate (>=11.0)
Expand Down Expand Up @@ -1674,6 +1677,9 @@ const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_FUNCTION_VER_MAP = []
m["cublasSetFixedPointEmulationMantissaBitOffset"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100
m["cublasGetFixedPointEmulationMantissaBitCountPointer"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100
m["cublasSetFixedPointEmulationMantissaBitCountPointer"] = {CUDA_130, CUDA_0, CUDA_0 }; // A: CUDA 13.0.2, CUBLAS_VERSION 130100
m["cublasLtGroupedMatrixLayoutInit_internal"] = {CUDA_131, CUDA_0, CUDA_0 };
m["cublasLtGroupedMatrixLayoutInit"] = {CUDA_131, CUDA_0, CUDA_0 };
m["cublasLtGroupedMatrixLayoutCreate"] = {CUDA_131, CUDA_0, CUDA_0 };

return m;
}();
Expand Down
148 changes: 66 additions & 82 deletions src/CUDA2HIP_Perl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,89 +450,83 @@ namespace perl {

void generateHostFunctions(ostream &out) {
out << endl << sub << "transformHostFunctions" << " {" << endl;
const string s = "s/(?<!\\/\\/ CHECK: )($func)\\s*";
const string s0 = s + "\\(([^,\\)]+),/$func\\(";
const string s1 = s + "\\(([^,\\)]+),([\\s]*)([^,\\)]+)(,\\s*|\\))/$func\\($2,$3";
const string s2 = s + "\\(([^,\\)]+),([^,\\)]+),([\\s]*)([^,\\)]+)(,\\s*|\\))/$func\\($2,$3,$4";
const string s3 = s + "\\(([^,\\)]+),([^,\\)]+),([^,\\)]+),([\\s]*)([^,\\)]+)(,\\s*|\\))/$func\\($2,$3,$4";
const string s4 = s + "\\(([^,\\)]+),([^,\\)]+),([^,\\)]+),([^,\\)]+),([\\s]*)([^,\\)]+)(,\\s*|\\))/$func\\($2,$3,$4,$5,$6";
const string s5 = s + "\\(([^,\\)]+),([^,\\)]+),([^,\\)]+),([^,\\)]+),([^,\\)]+),([\\s]*)([^,\\)]+)(,\\s*|\\))/$func\\($2,$3,$4,$5,$6,$7";
set<string> DeviceSymbolFunctions0;
set<string> DeviceSymbolFunctions1;
set<string> DeviceSymbolFunctions2;
set<string> DeviceSymbolFunctions3;
set<string> DeviceSymbolFunctions4;
set<string> DeviceSymbolFunctions5;
set<string> ReinterpretFunctions0;
set<string> ReinterpretFunctions1;
set<string> RemoveArgFunctions3;
out << tab << "my $process_args = sub {" << endl;
out << tab_2 << "my ($f, $args_str, $argIdx, $actionType, $cast) = @_;" << endl;
out << tab_2 << "my $args_inner = substr($args_str, 1, -1);" << endl;
out << tab_2 << "my (@a, $d, $cur);" << endl;
out << tab_2 << "foreach my $s (split //, $args_inner) {" << endl;
out << tab_3 << "if ($s =~ /[\\(\\[\\<\\{]/) { $d++ } elsif ($s =~ /[\\)\\]\\>\\}]/) { $d-- }" << endl;
out << tab_3 << "if ($s eq ',' && $d == 0) { push @a, $cur; $cur = '' } else { $cur .= $s }" << endl;
out << tab_2 << "} push @a, $cur;" << endl;
out << tab_2 << "if (defined $a[$argIdx]) {" << endl;
out << tab_3 << "my $v = $a[$argIdx]; $v =~ s/^\\s+|\\s+$//g;" << endl;
out << tab_3 << "if ($actionType == 0 || $actionType == 1) {" << endl;
out << tab_4 << "$a[$argIdx] = \"$cast($v)\" unless index($v, $cast) == 0;" << endl;
out << tab_3 << "} elsif ($actionType == 2) {" << endl;
out << tab_4 << "splice(@a, $argIdx, 1);" << endl;
out << tab_3 << "}" << endl;
out << tab_2 << "}" << endl;
out << tab_2 << "return \"$f(\" . join(',', @a) . \")\";" << endl;
out << tab << "};" << endl_2;
set<string> DeviceSymbolFunctions0, DeviceSymbolFunctions1, DeviceSymbolFunctions2;
set<string> DeviceSymbolFunctions3, DeviceSymbolFunctions4, DeviceSymbolFunctions5;
set<string> ReinterpretFunctions0, ReinterpretFunctions1, RemoveArgFunctions3;
for (auto f : FuncArgCasts) {
auto castStructs = f.second;
for (auto cc : castStructs) {
for (auto c : cc.castMap) {
switch (c.first) {
case 0:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions0.insert(f.first); break;
case e_reinterpret_cast: ReinterpretFunctions0.insert(f.first); break;
default: break;
}
break;
case 1:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions1.insert(f.first); break;
case e_reinterpret_cast: ReinterpretFunctions1.insert(f.first); break;
default: break;
}
break;
case 2:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions2.insert(f.first); break;
default: break;
}
break;
case 3:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions3.insert(f.first); break;
case e_remove_argument: RemoveArgFunctions3.insert(f.first); break;
default: break;
}
break;
case 4:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions4.insert(f.first); break;
switch (c.first) {
case 0:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions0.insert(f.first); break;
case e_reinterpret_cast: ReinterpretFunctions0.insert(f.first); break;
default: break;
} break;
case 1:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions1.insert(f.first); break;
case e_reinterpret_cast: ReinterpretFunctions1.insert(f.first); break;
default: break;
} break;
case 2:
if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions2.insert(f.first); break;
case 3:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions3.insert(f.first); break;
case e_remove_argument: RemoveArgFunctions3.insert(f.first); break;
default: break;
} break;
case 4:
if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions4.insert(f.first); break;
case 5:
if (c.second.castType == e_HIP_SYMBOL) DeviceSymbolFunctions5.insert(f.first); break;
default: break;
}
break;
case 5:
switch (c.second.castType) {
case e_HIP_SYMBOL: DeviceSymbolFunctions5.insert(f.first); break;
default: break;
}
break;
default: break;
}
}
}
}
}
set<string> &funcSet = DeviceSymbolFunctions0;
set<string> *funcSet = &DeviceSymbolFunctions0;
for (int i = 0; i < 9; ++i) {
int argIdx = 0;
int actionType = 0;
string castStr = "";
switch (i) {
default: funcSet = DeviceSymbolFunctions0; break;
case 1: funcSet = DeviceSymbolFunctions1; break;
case 2: funcSet = ReinterpretFunctions0; break;
case 3: funcSet = ReinterpretFunctions1; break;
case 4: funcSet = DeviceSymbolFunctions2; break;
case 5: funcSet = RemoveArgFunctions3; break;
case 6: funcSet = DeviceSymbolFunctions3; break;
case 7: funcSet = DeviceSymbolFunctions4; break;
case 8: funcSet = DeviceSymbolFunctions5; break;
default:
case 0: funcSet = &DeviceSymbolFunctions0; argIdx = 0; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
case 1: funcSet = &DeviceSymbolFunctions1; argIdx = 1; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
case 2: funcSet = &ReinterpretFunctions0; argIdx = 0; actionType = 1; castStr = getCastType(e_reinterpret_cast); break;
case 3: funcSet = &ReinterpretFunctions1; argIdx = 1; actionType = 1; castStr = getCastType(e_reinterpret_cast); break;
case 4: funcSet = &DeviceSymbolFunctions2; argIdx = 2; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
case 5: funcSet = &RemoveArgFunctions3; argIdx = 3; actionType = 2; break; // No cast needed
case 6: funcSet = &DeviceSymbolFunctions3; argIdx = 3; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
case 7: funcSet = &DeviceSymbolFunctions4; argIdx = 4; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
case 8: funcSet = &DeviceSymbolFunctions5; argIdx = 5; actionType = 0; castStr = getCastType(e_HIP_SYMBOL); break;
}
if (funcSet.empty()) continue;
if (funcSet->empty()) continue;
out << tab + foreach_func << endl;
unsigned int count = 0;
string sHIPName;
for (auto &f : funcSet) {
for (auto &f : *funcSet) {
const auto found = CUDA_RUNTIME_FUNCTION_MAP.find(f);
if (found != CUDA_RUNTIME_FUNCTION_MAP.end()) sHIPName = found->second.hipName.str();
else {
Expand All @@ -542,19 +536,9 @@ namespace perl {
out << (count ? ",\n" : "") << tab_2 << "\"" << sHIPName << "\"";
count++;
}
out << endl_tab << ")" << endl_tab << "{" << endl_tab_2;
switch (i) {
case 0:
default: out << s0 << getCastType(e_HIP_SYMBOL) << "\\($2\\),/g;" << endl; break;
case 1: out << s1 << getCastType(e_HIP_SYMBOL) << "\\($4\\)$5/g;" << endl; break;
case 2: out << s0 << getCastType(e_reinterpret_cast) << "\\($2\\),/g;" << endl; break;
case 3: out << s1 << getCastType(e_reinterpret_cast) << "\\($4\\)$5/g;" << endl; break;
case 4: out << s2 << getCastType(e_HIP_SYMBOL) << "\\($5\\)$6/g;" << endl; break;
case 5: out << s3 << "$7/g;" << endl; break;
case 6: out << s3 << ",$5" << getCastType(e_HIP_SYMBOL) << "\\($6\\)$7/g;" << endl; break;
case 7: out << s4 << getCastType(e_HIP_SYMBOL) << "\\($7\\)$8/g;" << endl; break;
case 8: out << s5 << getCastType(e_HIP_SYMBOL) << "\\($8\\)$9/g;" << endl; break;
}
out << endl_tab << ")" << endl_tab << "{" << endl;
out << tab_2 << "s{\\b($func)\\s*(\\((?:[^()]+|(?2))*\\))}{ $process_args->($1, $2, "
<< argIdx << ", " << actionType << ", \"" << castStr << "\") }ges;" << endl;
out << tab << "}" << endl;
}
out << "}" << endl_2;
Expand Down
Loading