Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -827,16 +827,209 @@ private fun ValidationScope.validateScalars() {
}
}

internal class FieldAndNode(val field: GQLInputValueDefinition, val node: Node)

internal class Node(val typeDefinition: GQLInputObjectTypeDefinition) {
val isOneOf = typeDefinition.directives.findOneOf()

/**
* Whether that node is valid (can reach a leaf node)
* This will be updated as we traverse the graph
*/
var isValid = false

/**
* Whether that node is visited
*/
var visited = false
var edgeCount = 0
val predecessors = mutableSetOf<Node>()
val sucessors = mutableSetOf<FieldAndNode>()

/**
* tarjan
*/
var index: Int? = null
var lowLink: Int? = null
var onStack = false

override fun toString() = typeDefinition.name
}

private fun ValidationScope.reverseGraph(inputObjectTypeDefinitions: List<GQLInputObjectTypeDefinition>): MutableCollection<Node> {
val nodes = mutableMapOf<String, Node>()
inputObjectTypeDefinitions.forEach {
nodes.put(it.name, Node(it))
}

nodes.values.forEach { node ->
/**
* Track the leaf fields.
* - `@oneOf` are not valid by default but may become if they have one escape hatch.
* - other types are valid by default but may become invalid if they have one non-null reference
*/
node.isValid = !node.isOneOf
node.typeDefinition.inputFields.forEach { field ->
val fieldType = field.type
if (node.isOneOf) {
if (fieldType is GQLNamedType) {
val fieldTypeDefinition = typeDefinitions.get(fieldType.name)
if (fieldTypeDefinition is GQLInputObjectTypeDefinition) {
val successor = nodes.get(fieldTypeDefinition.name)!!
successor.predecessors.add(node)
node.sucessors.add(FieldAndNode(field, successor))
} else {
// scalar or enum
node.isValid = true
}
} else {
// Maybe a list
// Should not be non-null. If it is, other validation rules will catch it.
node.isValid = true
}
} else {
if (fieldType is GQLNonNullType) {
val innerType = fieldType.type
if (innerType is GQLNamedType) {
val fieldTypeDefinition = typeDefinitions.get(innerType.name)
if (fieldTypeDefinition is GQLInputObjectTypeDefinition) {
// Not a leaf field
node.isValid = false
node.edgeCount++
val successor = nodes.get(fieldTypeDefinition.name)!!
successor.predecessors.add(node)
node.sucessors.add(FieldAndNode(field, successor))
}
} else {
// List type => escape
}
} else {
// Nullable type => escape
}
}
}
}
return nodes.values
}

/**
* walks the reverse graph, starting with the leaf nodes to find all the valid nodes
*/
private fun findValid(nodes: Collection<Node>) {
val stack = ArrayDeque<Node>()
// Start with the leaf, non-oneOf types
stack.addAll(nodes.filter { it.isValid })

while (stack.isNotEmpty()) {
val node = stack.removeFirst()
if (node.visited) continue
node.visited = true
node.predecessors.forEach { predecessor ->
if (predecessor.isOneOf) {
predecessor.isValid = true
stack.addAll(predecessor.predecessors)
} else {
predecessor.edgeCount--
if (predecessor.edgeCount == 0) {
predecessor.isValid = true
stack.addAll(predecessor.predecessors)
}
}
}
}
}

private fun removeValid(nodes: MutableCollection<Node>) {
nodes.removeAll { it.isValid }
// At this point, there shouldn't be any edge pointing to a valid node or that would be an escape input field
// nodes.forEach {
// check(it.sucessors.none { it.node.isValid })
// }
}

internal class PathElement(val typename: String, val inputField: GQLInputValueDefinition)
internal typealias Scc = Collection<Node>

/**
* For error reporting purposes, find the longest cycle inside the SCC
*/
private fun findWitnessCycle(scc: Collection<Node>): List<PathElement> {
val start = scc.first()

val path = mutableListOf<PathElement>()
val visited = mutableSetOf<Node>()

fun dfs(current: Node): Boolean {
visited.add(current)
for (fieldAndNode in current.sucessors) {
val next = fieldAndNode.node
path.add(PathElement(current.typeDefinition.name, fieldAndNode.field))
if (next == start) return true
if (next !in visited && dfs(next)) return true
path.removeAt(path.lastIndex)
}
visited.remove(current)
return false
}

dfs(start)
return path
}

internal fun tarjanScc(nodes: Collection<Node>): Collection<Scc> {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same!

var index = 0
val stack = ArrayDeque<Node>()
val result = mutableListOf<Scc>()

fun strongConnect(v: Node) {
v.index = index
v.lowLink = index
index++
stack.addLast(v)
v.onStack = true

v.sucessors.forEach {
val w = it.node
if (w.index == null) {
strongConnect(w)
v.lowLink = minOf(v.lowLink!!, w.lowLink!!)
} else if (w.onStack) {
v.lowLink = minOf(v.lowLink!!, w.index!!)
}
}

if (v.lowLink == v.index) {
val scc = mutableListOf<Node>()
while (true) {
val w = stack.removeLast()
w.onStack = false
scc.add(w)
if (w == v) break
}
result.add(scc)
}
}

nodes.forEach {
if (it.index == null) {
strongConnect(it)
}
}

return result
}

private fun ValidationScope.validateInputObjects() {
val traversalState = TraversalState()
val inputObjects = typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition>()
validateInputObjectsCycles(inputObjects)

val defaultValueTraversalState = DefaultValueTraversalState()
typeDefinitions.values.filterIsInstance<GQLInputObjectTypeDefinition>().forEach { o ->
inputObjects.forEach { o ->
if (o.inputFields.isEmpty()) {
registerIssue("Input object must specify one or more input fields", o.sourceLocation)
}

validateDirectivesInConstContext(o.directives, o)
validateInputFieldCycles(o, traversalState)
validateInputObjectDefaultValue(o, defaultValueTraversalState)

val isOneOfInputObject = o.directives.findOneOf()
Expand All @@ -853,65 +1046,46 @@ private fun ValidationScope.validateInputObjects() {
}
}

private class TraversalState {
val visitedTypes = mutableSetOf<String>()
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
val fieldPathIndexByTypeName = mutableMapOf<String, Int>()
}

private class DefaultValueTraversalState {
val visitedFields = mutableSetOf<String>()
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
val fieldPathIndex = mutableMapOf<String, Int>()
}


private fun ValidationScope.validateInputFieldCycles(inputObjectTypeDefinition: GQLInputObjectTypeDefinition, state: TraversalState) {
if (state.visitedTypes.contains(inputObjectTypeDefinition.name)) {
return
}
state.visitedTypes.add(inputObjectTypeDefinition.name)

state.fieldPathIndexByTypeName[inputObjectTypeDefinition.name] = state.fieldPath.size

inputObjectTypeDefinition.inputFields.forEach {
val type = it.type
if (type is GQLNonNullType && type.type is GQLNamedType) {
val fieldType = typeDefinitions.get(type.type.name)
if (fieldType is GQLInputObjectTypeDefinition) {
val cycleIndex = state.fieldPathIndexByTypeName.get(fieldType.name)

state.fieldPath.add("${fieldType.name}.${it.name}" to it.sourceLocation)

if (cycleIndex == null) {
validateInputFieldCycles(fieldType, state)
} else {
val cyclePath = state.fieldPath.subList(cycleIndex, state.fieldPath.size)

cyclePath.forEach {
issues.add(
OtherValidationIssue(
buildString {
append("Invalid circular reference. The Input Object '${fieldType.name}' references itself ")
if (cyclePath.size > 1) {
append("via the non-null fields: ")
} else {
append("in the non-null field ")
}
append(cyclePath.map { it.first }.joinToString(", "))
},
it.second
)
)
private fun ValidationScope.validateInputObjectsCycles(inputObjectTypeDefinitions: List<GQLInputObjectTypeDefinition>) {
val nodes = reverseGraph(inputObjectTypeDefinitions)
findValid(nodes)
removeValid(nodes)
tarjanScc(nodes).forEach { scc ->
if (scc.size == 1) {
val firstNode = scc.first()
val fieldAndNode = firstNode.sucessors.firstOrNull()
if (fieldAndNode != null && fieldAndNode.node.typeDefinition.name == firstNode.typeDefinition.name) {
registerIssue("Input object `${firstNode.typeDefinition.name}` references itself through field `${firstNode.typeDefinition.name}.${fieldAndNode.field.name}` and cannot be constructed.", fieldAndNode.field.sourceLocation)
} else {
// Trivial SCC containing a single, non self-referncing node are not an issue.
}
} else {
val cycle = findWitnessCycle(scc)
cycle.indices.forEach { i ->
val start = cycle.get(i)
val cycleAsString = buildString {
var j = i
repeat(cycle.size) {
val cur = cycle.get(j)
append("${cur.typename}.${cur.inputField.name} --> ")
j++
if (j == cycle.size) {
j = 0
}
}
append(start.typename)
}

state.fieldPath.removeLast()
registerIssue("Input object `${start.typename}` references itself through an unbreakable chain of input fields and cannot be constructed: $cycleAsString", start.inputField.sourceLocation)
}
}
}
}

state.fieldPathIndexByTypeName.remove(inputObjectTypeDefinition.name)

private class DefaultValueTraversalState {
val visitedFields = mutableSetOf<String>()
val fieldPath = mutableListOf<Pair<String, SourceLocation?>>()
val fieldPathIndex = mutableMapOf<String, Int>()
}

private fun ValidationScope.validateInputObjectDefaultValue(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.apollographql.apollo.graphql.ast.test

import com.apollographql.apollo.ast.GQLField
import com.apollographql.apollo.ast.GQLInputObjectTypeDefinition
import com.apollographql.apollo.ast.GQLInputValueDefinition
import com.apollographql.apollo.ast.GQLNamedType
import com.apollographql.apollo.ast.internal.FieldAndNode
import com.apollographql.apollo.ast.internal.Node
import com.apollographql.apollo.ast.internal.tarjanScc
import kotlin.test.Test

class TarjanTest {
fun typeDefinition(name: String) = GQLInputObjectTypeDefinition(
sourceLocation = null,
description = "",
name = name,
directives = emptyList(),
inputFields = emptyList()
)
val field = GQLInputValueDefinition(
sourceLocation = null,
name = "",
directives = emptyList(),
description = "",
type = GQLNamedType(null, ""),
defaultValue = null,
)

internal fun node(name: String) = Node(typeDefinition(name)).apply { isValid = false }

@Test
fun test1() {
val a = node("a")
val b = node("b")
val c = node("c")

a.sucessors.add(FieldAndNode(field, b))
b.sucessors.add(FieldAndNode(field, c))
c.sucessors.add(FieldAndNode(field, b))

val sccs = tarjanScc(listOf(a, b, c))
println(sccs)
}

@Test
fun test2() {
val a = node("a")
val b = node("b")
val c = node("c")
val d = node("d")

a.sucessors.add(FieldAndNode(field, b))
a.sucessors.add(FieldAndNode(field, d))
b.sucessors.add(FieldAndNode(field, c))
c.sucessors.add(FieldAndNode(field, b))
d.sucessors.add(FieldAndNode(field, a))

val sccs = tarjanScc(listOf(a, b, c))
println(sccs)
}

}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ type Query {
field(arg: SomeInputObject): String
}

# simple cycle
# Input object `SomeInputObject` references itself through field `SomeInputObject.nonNullSelf` and cannot be constructed.
input SomeInputObject {
nonNullSelf: SomeInputObject!
}
Loading
Loading