Skip to content

Commit 9af8762

Browse files
committed
Add an option decompression_uplo for symmetric results
1 parent 31999f1 commit 9af8762

4 files changed

Lines changed: 142 additions & 49 deletions

File tree

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,15 @@ function SMC.StarSetColoringResult(
4747
A::CuSparseMatrixCSC,
4848
ag::SMC.AdjacencyGraph{T},
4949
color::Vector{<:Integer},
50-
star_set::SMC.StarSet{<:Integer},
50+
star_set::SMC.StarSet{<:Integer};
51+
decompression_uplo::Symbol=:F,
5152
) where {T<:Integer}
53+
@assert decompression_uplo == :F
5254
group = SMC.group_by_color(T, color)
53-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
55+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
5456
additional_info = (; compressed_indices_gpu_csc=CuVector(compressed_indices))
5557
return SMC.StarSetColoringResult(
56-
A, ag, color, group, compressed_indices, additional_info
58+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
5759
)
5860
end
5961

@@ -85,13 +87,15 @@ function SMC.StarSetColoringResult(
8587
A::CuSparseMatrixCSR,
8688
ag::SMC.AdjacencyGraph{T},
8789
color::Vector{<:Integer},
88-
star_set::SMC.StarSet{<:Integer},
90+
star_set::SMC.StarSet{<:Integer};
91+
decompression_uplo::Symbol=:F,
8992
) where {T<:Integer}
93+
@assert decompression_uplo == :F
9094
group = SMC.group_by_color(T, color)
91-
compressed_indices = SMC.star_csc_indices(ag, color, star_set)
95+
compressed_indices = SMC.star_csc_indices(ag, color, star_set, decompression_uplo)
9296
additional_info = (; compressed_indices_gpu_csr=CuVector(compressed_indices))
9397
return SMC.StarSetColoringResult(
94-
A, ag, color, group, compressed_indices, additional_info
98+
A, ag, color, group, compressed_indices, decompression_uplo, additional_info
9599
)
96100
end
97101

src/decompression.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,18 @@ function decompress!(
453453
check_compatible_pattern(A, ag, uplo)
454454
fill!(A, zero(eltype(A)))
455455

456+
l = 0
456457
rvS = rowvals(S)
457458
for j in axes(S, 2)
458459
for k in nzrange(S, j)
459460
i = rvS[k]
460461
if in_triangle(i, j, uplo)
461-
A[i, j] = B[compressed_indices[k]]
462+
if result.decompression_uplo == :F
463+
A[i, j] = B[compressed_indices[k]]
464+
else
465+
l += 1
466+
A[i, j] = B[compressed_indices[l]]
467+
end
462468
end
463469
end
464470
end
@@ -472,6 +478,7 @@ function decompress_single_color!(
472478
result::StarSetColoringResult,
473479
uplo::Symbol=:F,
474480
)
481+
@assert result.decompression_uplo == :F
475482
(; ag, compressed_indices, group) = result
476483
(; S) = ag
477484
check_compatible_pattern(A, ag, uplo)
@@ -509,11 +516,12 @@ function decompress!(
509516
(; S) = ag
510517
nzA = nonzeros(A)
511518
check_compatible_pattern(A, ag, uplo)
512-
if uplo == :F
519+
if result.decompression_uplo == uplo
513520
for k in eachindex(nzA, compressed_indices)
514521
nzA[k] = B[compressed_indices[k]]
515522
end
516523
else
524+
@assert result.decompression_uplo == :F
517525
rvS = rowvals(S)
518526
l = 0 # assume A has the same pattern as the triangle
519527
for j in axes(S, 2)
@@ -529,6 +537,44 @@ function decompress!(
529537
return A
530538
end
531539

540+
function decompress_single_color!(
541+
A::SparseMatrixCSC,
542+
b::AbstractVector,
543+
c::Integer,
544+
result::StarSetColoringResult,
545+
uplo::Symbol=:F,
546+
)
547+
(; ag, compressed_indices) = result
548+
(; S) = ag
549+
lower_index = (c - 1) * S.n + 1
550+
upper_index = c * S.n
551+
nzA = nonzeros(A)
552+
if result.decompression_uplo == uplo
553+
uplo == :F && check_same_pattern(A, S)
554+
for k in eachindex(nzA, compressed_indices)
555+
if lower_index <= compressed_indices[k] <= upper_index
556+
nzA[k] = b[compressed_indices[k] - lower_index + 1]
557+
end
558+
end
559+
else
560+
@assert result.decompression_uplo == :F
561+
rvS = rowvals(S)
562+
l = 0 # assume A has the same pattern as the triangle
563+
for j in axes(S, 2)
564+
for k in nzrange(S, j)
565+
i = rvS[k]
566+
if in_triangle(i, j, uplo)
567+
l += 1
568+
if lower_index <= compressed_indices[k] <= upper_index
569+
nzA[l] = b[i]
570+
end
571+
end
572+
end
573+
end
574+
end
575+
return A
576+
end
577+
532578
## TreeSetColoringResult
533579

534580
function decompress!(

src/interface.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,11 @@ function coloring(
190190
A::AbstractMatrix,
191191
problem::ColoringProblem,
192192
algo::GreedyColoringAlgorithm;
193-
decompression_eltype::Type{R}=Float64,
194193
symmetric_pattern::Bool=false,
194+
decompression_eltype::Type{R}=Float64,
195+
decompression_uplo::Symbol=:F,
195196
) where {R}
196-
return _coloring(WithResult(), A, problem, algo, R, symmetric_pattern)
197+
return _coloring(WithResult(), A, problem, algo, symmetric_pattern, R, decompression_uplo)
197198
end
198199

199200
"""
@@ -229,8 +230,9 @@ function _coloring(
229230
A::AbstractMatrix,
230231
::ColoringProblem{:nonsymmetric,:column},
231232
algo::GreedyColoringAlgorithm,
233+
symmetric_pattern::Bool,
232234
decompression_eltype::Type,
233-
symmetric_pattern::Bool;
235+
decompression_uplo::Symbol;
234236
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
235237
)
236238
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
@@ -252,8 +254,9 @@ function _coloring(
252254
A::AbstractMatrix,
253255
::ColoringProblem{:nonsymmetric,:row},
254256
algo::GreedyColoringAlgorithm,
257+
symmetric_pattern::Bool,
255258
decompression_eltype::Type,
256-
symmetric_pattern::Bool;
259+
decompression_uplo::Symbol;
257260
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
258261
)
259262
symmetric_pattern = symmetric_pattern || A isa Union{Symmetric,Hermitian}
@@ -275,8 +278,9 @@ function _coloring(
275278
A::AbstractMatrix,
276279
::ColoringProblem{:symmetric,:column},
277280
algo::GreedyColoringAlgorithm{:direct},
281+
symmetric_pattern::Bool,
278282
decompression_eltype::Type,
279-
symmetric_pattern::Bool;
283+
decompression_uplo::Symbol;
280284
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
281285
)
282286
ag = AdjacencyGraph(A; augmented_graph=false)
@@ -286,7 +290,7 @@ function _coloring(
286290
end
287291
color, star_set = argmin(maximum first, color_and_star_set_by_order)
288292
if speed_setting isa WithResult
289-
return StarSetColoringResult(A, ag, color, star_set)
293+
return StarSetColoringResult(A, ag, color, star_set; decompression_uplo)
290294
else
291295
return color
292296
end
@@ -297,8 +301,9 @@ function _coloring(
297301
A::AbstractMatrix,
298302
::ColoringProblem{:symmetric,:column},
299303
algo::GreedyColoringAlgorithm{:substitution},
300-
decompression_eltype::Type{R},
301304
symmetric_pattern::Bool,
305+
decompression_eltype::Type{R},
306+
decompression_uplo::Symbol,
302307
) where {R}
303308
ag = AdjacencyGraph(A; augmented_graph=false)
304309
color_and_tree_set_by_order = map(algo.orders) do order
@@ -307,7 +312,7 @@ function _coloring(
307312
end
308313
color, tree_set = argmin(maximum first, color_and_tree_set_by_order)
309314
if speed_setting isa WithResult
310-
return TreeSetColoringResult(A, ag, color, tree_set, R)
315+
return TreeSetColoringResult(A, ag, color, tree_set, R; decompression_uplo)
311316
else
312317
return color
313318
end
@@ -318,8 +323,9 @@ function _coloring(
318323
A::AbstractMatrix,
319324
::ColoringProblem{:nonsymmetric,:bidirectional},
320325
algo::GreedyColoringAlgorithm{:direct},
326+
symmetric_pattern::Bool,
321327
decompression_eltype::Type{R},
322-
symmetric_pattern::Bool;
328+
decompression_uplo::Symbol;
323329
forced_colors::Union{AbstractVector{<:Integer},Nothing}=nothing,
324330
) where {R}
325331
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
@@ -345,7 +351,9 @@ function _coloring(
345351
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
346352
) # can't use ncolors without computing the full result
347353
if speed_setting isa WithResult
348-
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
354+
symmetric_result = StarSetColoringResult(
355+
A_and_Aᵀ, ag, color, star_set; decompression_uplo=:L
356+
)
349357
return BicoloringResult(
350358
A,
351359
ag,
@@ -366,8 +374,9 @@ function _coloring(
366374
A::AbstractMatrix,
367375
::ColoringProblem{:nonsymmetric,:bidirectional},
368376
algo::GreedyColoringAlgorithm{:substitution},
369-
decompression_eltype::Type{R},
370377
symmetric_pattern::Bool,
378+
decompression_eltype::Type{R},
379+
decompression_uplo::Symbol,
371380
) where {R}
372381
A_and_Aᵀ, edge_to_index = bidirectional_pattern(A; symmetric_pattern)
373382
ag = AdjacencyGraph(A_and_Aᵀ, edge_to_index, 0; augmented_graph=true)
@@ -390,7 +399,9 @@ function _coloring(
390399
t -> maximum(t[3]) + maximum(t[4]), outputs_by_order
391400
) # can't use ncolors without computing the full result
392401
if speed_setting isa WithResult
393-
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
402+
symmetric_result = TreeSetColoringResult(
403+
A_and_Aᵀ, ag, color, tree_set, R; decompression_uplo=:L
404+
)
394405
return BicoloringResult(
395406
A,
396407
ag,

0 commit comments

Comments
 (0)