Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -173,9 +174,10 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
}

// If this is part of the custom domain login flow, save that info in the cookie since we need that info when handling the auth callback.
customDomainFlow := r.URL.Query().Get("custom_domain_flow")
if b, err := strconv.ParseBool(customDomainFlow); err == nil && b {
sess.Values[cookieFieldCustomDomainFlow] = true
customDomainFlow := false
if b, err := strconv.ParseBool(r.URL.Query().Get("custom_domain_flow")); err == nil && b {
sess.Values[cookieFieldCustomDomainFlow] = b
customDomainFlow = b
}

// Save cookie
Expand All @@ -193,6 +195,12 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
return
}

err := a.validateRedirectURL(r.Context(), redirect, customDomainFlow)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Redirect to auth provider (canonical domain flow)
redirectURL := a.oauth2.AuthCodeURL(state)
if signup {
Expand Down Expand Up @@ -569,20 +577,35 @@ func (a *Authenticator) authLogout(w http.ResponseWriter, r *http.Request) {
return
}

// Extract custom redirect destination (if any)
// Extract custom redirect destination (if any).
redirect := r.URL.Query().Get("redirect")

// Redirect to authLogoutProvider (see its docstring below for details on why we do this).
http.Redirect(w, r, a.admin.URLs.AuthLogoutProvider(redirect), http.StatusTemporaryRedirect)
host := originalHost(r)
if a.admin.URLs.IsCustomDomain(host) {
http.Redirect(w, r, a.admin.URLs.AuthLogoutProvider(redirect, true), http.StatusTemporaryRedirect)
return
}
http.Redirect(w, r, a.admin.URLs.AuthLogoutProvider(redirect, false), http.StatusTemporaryRedirect)
}

// authLogoutProvider redirects to the auth provider's logout flow.
// This is separated from authLogout to support orgs with custom domains where the auth token cookie must be cleared from the custom domain,
// but the redirect destination must be set in a cookie on the primary domain because the auth provider will redirect to authLogoutCallback on the primary domain.
func (a *Authenticator) authLogoutProvider(w http.ResponseWriter, r *http.Request) {
// Set custom redirect destination in cookie for when the logout flow is over (if any)
// Validate and set custom redirect destination in cookie for when the logout flow is over (if any)
redirect := r.URL.Query().Get("redirect")
if redirect != "" {
customDomainFlow := false
if b, err := strconv.ParseBool(r.URL.Query().Get("custom_domain_flow")); err == nil && b {
customDomainFlow = b
}
err := a.validateRedirectURL(r.Context(), redirect, customDomainFlow)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Update cookie
sess := a.cookies.Get(r, cookieName)
sess.Values[cookieFieldRedirect] = redirect
Expand Down Expand Up @@ -709,6 +732,26 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
}
}

func (a *Authenticator) validateRedirectURL(ctx context.Context, redirect string, allowCustomDomains bool) error {
if a.admin.URLs.IsSafeRedirectURL(redirect) {
return nil
}
if !allowCustomDomains {
return fmt.Errorf("redirect to %q is not allowed", redirect)
}

parsed, err := url.Parse(redirect)
if err != nil {
return fmt.Errorf("fail to parse redirect URL: %w", err)
}

_, err = a.admin.DB.FindOrganizationByCustomDomain(ctx, parsed.Host)
if errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("redirect to %q is not allowed", redirect)
}
return err
}

func originalHost(r *http.Request) string {
if xfHost := r.Header.Get("Rill-Custom-Domain"); xfHost != "" {
return xfHost
Expand Down
40 changes: 38 additions & 2 deletions admin/urls.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ func (u *URLs) WithCustomDomain(domain string) *URLs {
}
}

// IsSafeRedirectURL reports whether redirect is safe to redirect to after an auth flow.
// A redirect is safe when it is:
// - empty (caller defaults to the frontend URL)
// - a relative path (no scheme, no host); protocol-relative "//evil.com" and
// scheme-only "javascript:" / "data:" forms are rejected
// - an absolute URL whose host matches the primary external URL host, the primary frontend URL host
func (u *URLs) IsSafeRedirectURL(redirect string) bool {
if redirect == "" {
return true
}
parsed, err := url.Parse(redirect)
if err != nil {
return false
}
// No host: safe only when there is also no scheme.
// This rejects javascript:, data:, mailto:, and //evil.com forms.
if parsed.Host == "" {
return parsed.Scheme == ""
}
// Absolute URL: host must match a trusted host.
externalURL, err := url.Parse(u.external)
if err == nil && strings.EqualFold(parsed.Host, externalURL.Host) {
return true
}
frontendURL, err := url.Parse(u.frontend)
if err == nil && strings.EqualFold(parsed.Host, frontendURL.Host) {
return true
}
return false
}

// WithCustomDomainFromRedirectURL attempts to infer a custom domain from a redirect URL.
// If it succeeds, it passes the custom domain to WithCustomDomain and returns the result.
// If it does not detect a custom domain in the redirect URL, or the redirect URL is invalid, it fails silently by returning itself unchanged.
Expand Down Expand Up @@ -170,11 +201,16 @@ func (u *URLs) AuthLogout() string {
}

// AuthLogoutProvider returns the URL that starts the logout redirects against the external auth provider.
func (u *URLs) AuthLogoutProvider(redirect string) string {
func (u *URLs) AuthLogoutProvider(redirect string, customDomainFlow bool) string {
res := urlutil.MustJoinURL(u.external, "/auth/logout/provider") // NOTE: Always using the primary external URL.
q := map[string]string{}
if redirect != "" {
res = urlutil.MustWithQuery(res, map[string]string{"redirect": redirect})
q["redirect"] = redirect
}
if customDomainFlow {
q["custom_domain_flow"] = "true"
}
res = urlutil.MustWithQuery(res, map[string]string{"redirect": redirect})
return res
}

Expand Down
Loading