Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions Sources/PartoutCore/Modules/DNSModule.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ public struct DNSModule: Module, BuildableType, Hashable, Codable {
case tls(hostname: String)
}

public enum DomainPolicy: Hashable, Codable, Sendable {
case search
case match
}

public static let moduleType = ModuleType("DNS")

public let id: UniqueID
Expand All @@ -24,6 +29,8 @@ public struct DNSModule: Module, BuildableType, Hashable, Codable {

public let searchDomains: [Address]?

public let domainPolicy: DomainPolicy?

public let routesThroughVPN: Bool?

fileprivate init(
Expand All @@ -32,13 +39,15 @@ public struct DNSModule: Module, BuildableType, Hashable, Codable {
servers: [Address],
domainName: Address?,
searchDomains: [Address]?,
domainPolicy: DomainPolicy?,
routesThroughVPN: Bool?
) {
self.id = id
self.protocolType = protocolType
self.servers = servers
self.domainName = domainName
self.searchDomains = searchDomains
self.domainPolicy = domainPolicy
self.routesThroughVPN = routesThroughVPN
}

Expand All @@ -50,17 +59,29 @@ public struct DNSModule: Module, BuildableType, Hashable, Codable {
switch protocolType {
case .cleartext:
break

case .https(let url):
builder.protocolType = .https
builder.dohURL = url.absoluteString

case .tls(let hostname):
builder.protocolType = .tls
builder.dotHostname = hostname
}
builder.domainName = domainName?.rawValue
builder.searchDomains = searchDomains?.map(\.rawValue)
if let domainName {
builder.isFirstDomainPrimary = true
if let searchDomains {
if domainName == searchDomains.first {
builder.domains = searchDomains.map(\.rawValue)
} else {
builder.domains = [domainName.rawValue] + searchDomains.map(\.rawValue)
}
} else {
builder.domains = [domainName.rawValue]
}
} else if let searchDomains {
builder.isFirstDomainPrimary = false
builder.domains = searchDomains.map(\.rawValue)
}
builder.domainPolicy = domainPolicy
builder.routesThroughVPN = routesThroughVPN
return builder
}
Expand All @@ -78,9 +99,11 @@ extension DNSModule {

public var dotHostname: String

public var domainName: String?
public var domains: [String]?

public var domainPolicy: DomainPolicy?

public var searchDomains: [String]?
public var isFirstDomainPrimary: Bool

public var routesThroughVPN: Bool?

Expand All @@ -94,17 +117,19 @@ extension DNSModule {
servers: [String] = [],
dohURL: String = "",
dotHostname: String = "",
domainName: String? = nil,
searchDomains: [String]? = nil,
domains: [String]? = nil,
domainPolicy: DomainPolicy? = nil,
isFirstDomainPrimary: Bool = false,
routesThroughVPN: Bool? = nil
) {
self.id = id
self.protocolType = protocolType
self.servers = servers
self.dohURL = dohURL
self.dotHostname = dotHostname
self.domainName = domainName
self.searchDomains = searchDomains
self.domains = domains
self.domainPolicy = domainPolicy
self.isFirstDomainPrimary = isFirstDomainPrimary
self.routesThroughVPN = routesThroughVPN
}

Expand All @@ -118,25 +143,15 @@ extension DNSModule {
}
return addr
}
let validDomainName = try domainName.flatMap {
let validDomains = try domains?.compactMap {
guard !$0.isEmpty else {
return nil as Address?
}
guard let addr = Address(rawValue: $0), !addr.isIPAddress else {
throw PartoutError.invalidFields(["domainName": $0])
throw PartoutError.invalidFields(["domains": $0])
}
return addr
}
let validSearchDomains = try searchDomains?.compactMap {
guard !$0.isEmpty else {
return nil as Address?
}
guard let addr = Address(rawValue: $0), !addr.isIPAddress else {
throw PartoutError.invalidFields(["searchDomains": $0])
}
return addr
}

let validProtocolType: ProtocolType
switch protocolType {
case .cleartext:
Expand All @@ -158,8 +173,9 @@ extension DNSModule {
id: id,
protocolType: validProtocolType,
servers: validServers,
domainName: validDomainName,
searchDomains: validSearchDomains,
domainName: isFirstDomainPrimary ? validDomains?.first : nil,
searchDomains: validDomains,
domainPolicy: domainPolicy,
routesThroughVPN: routesThroughVPN
)
}
Expand Down
76 changes: 53 additions & 23 deletions Sources/PartoutOS/AppleNE/Modules/DNSModule+NE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,82 @@ import NetworkExtension

extension DNSModule: NESettingsApplying {
public func apply(_ ctx: PartoutLoggerContext, to settings: inout NEPacketTunnelNetworkSettings) {
var dnsSettings: NEDNSSettings?
let dnsSettings: NEDNSSettings
let rawServers = servers.map(\.rawValue)

// Former DNS settings are always overridden, even with empty servers
switch protocolType {
case .cleartext:
if !rawServers.isEmpty {
dnsSettings = NEDNSSettings(servers: rawServers)
pp_log(ctx, .os, .info, "\t\tServers: \(servers.map { $0.asSensitiveAddress(ctx) })")
} else {
pp_log(ctx, .os, .info, "\t\tServers: empty")
guard !rawServers.isEmpty else {
pp_log(ctx, .os, .info, "\t\tSkip DNS settings, cleartext requires non-empty servers")
return
}

dnsSettings = NEDNSSettings(servers: rawServers)
pp_log(ctx, .os, .info, "\t\tServers: \(servers.map { $0.asSensitiveAddress(ctx) })")
case .https(let url):
let specificSettings = NEDNSOverHTTPSSettings(servers: rawServers)
specificSettings.serverURL = url
dnsSettings = specificSettings
pp_log(ctx, .os, .info, "\t\tServers: \(servers.map { $0.asSensitiveAddress(ctx) })")
pp_log(ctx, .os, .info, "\t\tDoH URL: \(url.absoluteString.asSensitiveAddress(ctx))")

case .tls(let hostname):
let specificSettings = NEDNSOverTLSSettings(servers: rawServers)
specificSettings.serverName = hostname
dnsSettings = specificSettings
pp_log(ctx, .os, .info, "\t\tServers: \(servers.map { $0.asSensitiveAddress(ctx) })")
pp_log(ctx, .os, .info, "\t\tDoT hostname: \(hostname.asSensitiveAddress(ctx))")

@unknown default:
break
}

if dnsSettings != nil {
domainName.map {
dnsSettings?.domainName = $0.rawValue
pp_log(ctx, .os, .info, "\t\tDomain: \($0.asSensitiveAddress(ctx))")
}
searchDomains.map {
guard !$0.isEmpty else {
return
}
dnsSettings?.searchDomains = $0.map(\.rawValue)
pp_log(ctx, .os, .info, "\t\tSearch domains: \($0.map { $0.asSensitiveAddress(ctx) })")
}
} else {
pp_log(ctx, .os, .info, "\t\tSkip DNS settings")
// Main domain (if set)
domainName.map {
dnsSettings.domainName = $0.rawValue
pp_log(ctx, .os, .info, "\t\tDomain: \($0.asSensitiveAddress(ctx))")
}

// Apply domains with the given policy
let domains = searchDomains ?? []
let domainsDescription = domains.map { $0.asSensitiveAddress(ctx) }
let searchDomains = domains.map(\.rawValue)
//
// Credit for .matchDomains:
// https://github.com/WireGuard/wireguard-apple/pull/11
//
switch domainPolicy {
case .search:
dnsSettings.searchDomains = searchDomains
// XXX: This works around a Network Extension bug. We add the
// search domains here because .searchDomains is ineffective when
// the VPN is not the default gateway
dnsSettings.matchDomains = [""] + searchDomains
dnsSettings.matchDomainsNoSearch = false
pp_log(ctx, .os, .info, "\t\tSearch-only domains: \(domainsDescription)")
case .match:
let matchDomains = !searchDomains.isEmpty ? searchDomains : [""]
dnsSettings.searchDomains = nil
dnsSettings.matchDomains = matchDomains
dnsSettings.matchDomainsNoSearch = true
pp_log(ctx, .os, .info, "\t\tMatch-only domains: \(domainsDescription)")
default:
let matchDomains = !searchDomains.isEmpty ? searchDomains : [""]
dnsSettings.searchDomains = searchDomains
dnsSettings.matchDomains = matchDomains
dnsSettings.matchDomainsNoSearch = false
pp_log(ctx, .os, .info, "\t\tMatch/Search domains: \(domainsDescription)")
}

//
// This is why we guard before committing .matchDomains:
// https://git.zx2c4.com/wireguard-apple/commit/?id=20bdf46792905de8862ae7641e50e0f9f99ec946
//
assert(dnsSettings.matchDomains != nil)
if dnsSettings.servers.isEmpty {
pp_log(ctx, .os, .error, "\t\tIgnoring match domains without bootstrap DNS servers")
dnsSettings.matchDomains = nil
}

// Commit to tunnel settings
settings.dnsSettings = dnsSettings
}
}
8 changes: 5 additions & 3 deletions Sources/PartoutOS/AppleNE/Modules/Profile+NE.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ extension Profile {
// 4. configure DNS for domain-based routing

if let dnsSettings = neSettings.dnsSettings {

// route DNS through VPN first unless no servers provided
if !dnsSettings.servers.isEmpty {
// Route DNS through VPN first unless:
// - No servers provided
// - .matchDomains is not configured
// This is a fallback as it *SHOULD* be accomplished by DNSModule+NE
if !dnsSettings.servers.isEmpty, dnsSettings.matchDomains == nil {
neSettings.dnsSettings?.matchDomains = [""]
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,23 @@ private extension NetworkSettingsBuilder {

if let domain = dnsDomain {
pp_log(ctx, .openvpn, .info, "\tDNS: Set domain: \(domain.asSensitiveAddress(ctx))")
dnsSettings.domainName = domain
dnsSettings.domains = [domain]
dnsSettings.isFirstDomainPrimary = true
} else {
dnsSettings.isFirstDomainPrimary = false
}

let searchDomains = allDNSSearchDomains
if !searchDomains.isEmpty {
pp_log(ctx, .openvpn, .info, "\tDNS: Set search domains: \(searchDomains.map { $0.asSensitiveAddress(ctx) })")
dnsSettings.searchDomains = searchDomains
// First domain is main domain
if var domains = dnsSettings.domains {
let otherDomains = searchDomains.filter { !domains.contains($0) }
domains.append(contentsOf: otherDomains)
dnsSettings.domains = domains
} else {
dnsSettings.domains = searchDomains
}
}

do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ extension WireGuard.Configuration {
}
}
interface.dns.servers = dnsServers.map(\.rawValue)
interface.dns.searchDomains = dnsSearch
interface.dns.domains = dnsSearch
}
if let mtuString = attributes["mtu"] {
guard let mtu = UInt16(mtuString) else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private extension WireGuard.Configuration.Builder {
if !interface.addresses.isEmpty {
lines.append("Address = \(interface.addresses.wgJoined)")
}
let dnsEntries = interface.dns.servers + (interface.dns.searchDomains ?? [])
let dnsEntries = interface.dns.servers + (interface.dns.domains ?? [])
if !dnsEntries.isEmpty {
lines.append("DNS = \(dnsEntries.wgJoined)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ extension WireGuard.LocalInterface {

var dnsBuilder = DNSModule.Builder()
dnsBuilder.servers = wg.dns.map(\.stringRepresentation)
dnsBuilder.searchDomains = wg.dnsSearch
dnsBuilder.domains = wg.dnsSearch
let dns = try dnsBuilder.build()

let mtu = wg.mtu
Expand Down
18 changes: 18 additions & 0 deletions Tests/PartoutCoreTests/DNSModuleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ struct DNSModuleTests {
#expect(sut == module.builder())
}

@Test(arguments: [true, false])
func givenDomains_whenRebuild_thenIsRestored(isFirstDomainPrimary: Bool) throws {
let sut = DNSModule.Builder(
protocolType: .cleartext,
servers: ["1.2.3.4"],
domains: ["primary.example.com", "search.example.com"],
isFirstDomainPrimary: isFirstDomainPrimary
)
let module = try sut.build()
let rebuilt = module.builder()

#expect(module.domainName?.rawValue == (isFirstDomainPrimary ? "primary.example.com" : nil))
#expect(module.searchDomains?.map(\.rawValue) == ["primary.example.com", "search.example.com"])
#expect(rebuilt.domains == sut.domains)
#expect(rebuilt.isFirstDomainPrimary == sut.isFirstDomainPrimary)
#expect(rebuilt == sut)
}

@Test
func givenHTTPSWithoutURL_whenBuild_thenFails() {
let sut = DNSModule.Builder(
Expand Down
12 changes: 7 additions & 5 deletions Tests/PartoutOSTests/AppleNE/NESettingsApplyingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,25 @@ struct NESettingsApplyingTests {
#expect(proxySettings.exceptionList == module.bypassDomains.map(\.rawValue))
}

@Test
func givenDNS_whenApply_thenUpdatesSettings() throws {
@Test(arguments: [true, false])
func givenDNS_whenApply_thenUpdatesSettings(isFirstDomainPrimary: Bool) throws {
let module = try DNSModule.Builder(
protocolType: .cleartext,
servers: ["1.1.1.1", "2.2.2.2"],
domainName: "domain.com",
searchDomains: ["one.com", "two.com"]
domains: ["domain.com", "one.com", "two.com"],
isFirstDomainPrimary: isFirstDomainPrimary
).build()

var sut = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "")
module.apply(.global, to: &sut)

let dnsSettings = try #require(sut.dnsSettings)
let expSearchDomains = module.searchDomains?.map(\.rawValue)
#expect(dnsSettings.dnsProtocol == .cleartext)
#expect(dnsSettings.servers == module.servers.map(\.rawValue))
#expect(dnsSettings.domainName == (isFirstDomainPrimary ? "domain.com" : nil))
#expect(dnsSettings.domainName == module.domainName?.rawValue)
#expect(dnsSettings.searchDomains == module.searchDomains?.map(\.rawValue))
#expect(dnsSettings.searchDomains == expSearchDomains)
}

@Test
Expand Down
Loading