diff --git a/example.config.toml b/example.config.toml index b9293c270..901241cbb 100644 --- a/example.config.toml +++ b/example.config.toml @@ -49,9 +49,10 @@ concurrency = 8192 prefer-ip = "prefer-ipv6" # Public IP addresses of this server. Used by 'mtg access' to generate -# proxy links and by 'mtg doctor' to validate SNI-DNS match. -# If not set, mtg tries to detect them automatically via ifconfig.co. -# Set these if ifconfig.co is unreachable from your server. +# proxy links and by 'mtg doctor' / proxy startup to validate SNI-DNS match. +# If not set, mtg tries to detect them automatically by querying the public +# HTTPS endpoints listed in network.public-ip-endpoints (see below). +# Set these explicitly if those endpoints are unreachable from your server. # public-ipv4 = "1.2.3.4" # public-ipv6 = "2001:db8::1" @@ -200,6 +201,17 @@ proxies = [ # "socks5://user:password@host:port" ] +# HTTPS endpoints used to discover this server's public IPv4/IPv6 when +# public-ipv4 / public-ipv6 are not set. Each must return the client's public +# IP as a single address in the plain-text response body. mtg tries them in +# order and uses the first that succeeds. The default is shown below; setting +# this option overrides the default entirely. +# public-ip-endpoints = [ +# "https://ifconfig.co", +# "https://icanhazip.com", +# "https://ifconfig.me", +# ] + # network timeouts define different settings for timeouts. tcp timeout # define a global timeout on establishing of network connections. idle # means a timeout on pumping data between sockset when nothing is diff --git a/internal/cli/access.go b/internal/cli/access.go index c93c97b19..0ab0fd63b 100644 --- a/internal/cli/access.go +++ b/internal/cli/access.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "fmt" "net" @@ -8,6 +9,7 @@ import ( "os" "strconv" "sync" + "time" "github.com/9seconds/mtg/v2/internal/config" "github.com/9seconds/mtg/v2/internal/utils" @@ -54,6 +56,10 @@ func (a *Access) Run(cli *CLI, version string) error { return fmt.Errorf("cannot init network: %w", err) } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + endpoints := resolvePublicIPEndpoints(conf.Network.PublicIPEndpoints) wg := &sync.WaitGroup{} wg.Go(func() { @@ -62,7 +68,7 @@ func (a *Access) Run(cli *CLI, version string) error { ip = conf.PublicIPv4.Get(nil) } if ip == nil { - ip = getIP(ntw, "tcp4") + ip = getIP(ctx, ntw, "tcp4", endpoints) } if ip != nil { @@ -77,7 +83,7 @@ func (a *Access) Run(cli *CLI, version string) error { ip = conf.PublicIPv6.Get(nil) } if ip == nil { - ip = getIP(ntw, "tcp6") + ip = getIP(ctx, ntw, "tcp6", endpoints) } if ip != nil { diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index 48563bc1e..04fd3aa03 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -52,10 +52,13 @@ var ( ) tplODNSSNIMatch = template.Must( - template.New("").Parse(" ✅ IP address {{ .ip }} matches secret hostname {{ .hostname }}\n"), + template.New("").Parse(" ✅ Secret hostname {{ .hostname }} matches our public IP ({{ .our }}); resolved: {{ .resolved }}\n"), ) tplEDNSSNIMatch = template.Must( - template.New("").Parse(" ❌ Hostname {{ .hostname }} {{ if .resolved }}is resolved to {{ .resolved }} addresses, not {{ if .ip4 }}{{ .ip4 }}{{ else }}{{ .ip6 }}{{ end }}{{ else }}cannot be resolved to any host{{ end }}\n"), + template.New("").Parse(" ❌ Secret hostname {{ .hostname }} resolves to {{ .resolved }} but our public IP is {{ .our }}{{ if .families }} (mismatched families: {{ .families }}){{ end }}\n"), + ) + tplEDNSSNINoResolve = template.Must( + template.New("").Parse(" ❌ Secret hostname {{ .hostname }} cannot be resolved to any address\n"), ) tplOFrontingDomain = template.Must( @@ -329,26 +332,20 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool { } func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool { - addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host) - if err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + res := runSNICheck(ctx, resolver, d.conf, ntw) + + if res.ResolveErr != nil { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host), - "error": err, + "description": fmt.Sprintf("cannot resolve DNS name of %s", res.Host), + "error": res.ResolveErr, }) return false } - ourIP4 := d.conf.PublicIPv4.Get(nil) - if ourIP4 == nil { - ourIP4 = getIP(ntw, "tcp4") - } - - ourIP6 := d.conf.PublicIPv6.Get(nil) - if ourIP6 == nil { - ourIP6 = getIP(ntw, "tcp6") - } - - if ourIP4 == nil && ourIP6 == nil { + if !res.Known() { tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck "description": "cannot detect public IP address", "error": errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"), @@ -356,25 +353,55 @@ func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) boo return false } - strAddresses := []string{} - for _, value := range addresses { - if (ourIP4 != nil && value.IP.String() == ourIP4.String()) || - (ourIP6 != nil && value.IP.String() == ourIP6.String()) { - tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "ip": value.IP, - "hostname": d.conf.Secret.Host, - }) - return true + if len(res.Resolved) == 0 { + tplEDNSSNINoResolve.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "hostname": res.Host, + }) + return false + } + + resolved := make([]string, 0, len(res.Resolved)) + for _, ip := range res.Resolved { + resolved = append(resolved, `"`+ip.String()+`"`) + } + + our := "" + if res.OurIPv4 != nil { + our = res.OurIPv4.String() + } + + if res.OurIPv6 != nil { + if our != "" { + our += "/" } - strAddresses = append(strAddresses, `"`+value.IP.String()+`"`) + our += res.OurIPv6.String() + } + + if res.OK() { + tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck + "hostname": res.Host, + "resolved": strings.Join(resolved, ", "), + "our": our, + }) + return true + } + + mismatched := []string{} + + if res.OurIPv4 != nil && !res.IPv4Match { + mismatched = append(mismatched, "IPv4") + } + + if res.OurIPv6 != nil && !res.IPv6Match { + mismatched = append(mismatched, "IPv6") } tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck - "hostname": d.conf.Secret.Host, - "resolved": strings.Join(strAddresses, ", "), - "ip4": ourIP4, - "ip6": ourIP6, + "hostname": res.Host, + "resolved": strings.Join(resolved, ", "), + "our": our, + "families": strings.Join(mismatched, ", "), }) return false diff --git a/internal/cli/run_proxy.go b/internal/cli/run_proxy.go index 5d9e63e98..0733c7ae8 100644 --- a/internal/cli/run_proxy.go +++ b/internal/cli/run_proxy.go @@ -6,6 +6,7 @@ import ( "net" "os" "strings" + "time" "github.com/9seconds/mtg/v2/antireplay" "github.com/9seconds/mtg/v2/events" @@ -209,78 +210,64 @@ func makeEventStream(conf *config.Config, logger mtglib.Logger) (mtglib.EventStr } func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger) { - host := conf.Secret.Host - if host == "" { + if conf.Secret.Host == "" { return } - addresses, err := net.DefaultResolver.LookupIPAddr(context.Background(), host) - if err != nil { - log.BindStr("hostname", host). - WarningError("SNI-DNS check: cannot resolve secret hostname", err) - return - } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - ourIP4 := conf.PublicIPv4.Get(nil) - if ourIP4 == nil { - ourIP4 = getIP(ntw, "tcp4") - } + res := runSNICheck(ctx, net.DefaultResolver, conf, ntw) - ourIP6 := conf.PublicIPv6.Get(nil) - if ourIP6 == nil { - ourIP6 = getIP(ntw, "tcp6") + if res.ResolveErr != nil { + log.BindStr("hostname", res.Host). + WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr) + return } - if ourIP4 == nil && ourIP6 == nil { + if !res.Known() { log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'") return } - v4Match := ourIP4 == nil - v6Match := ourIP6 == nil - - for _, addr := range addresses { - if ourIP4 != nil && addr.IP.String() == ourIP4.String() { - v4Match = true - } - - if ourIP6 != nil && addr.IP.String() == ourIP6.String() { - v6Match = true - } + if len(res.Resolved) == 0 { + log.BindStr("hostname", res.Host). + Warning("SNI-DNS check: secret hostname does not resolve to any address") + return } - if v4Match && v6Match { + if res.OK() { return } - resolved := make([]string, 0, len(addresses)) - for _, addr := range addresses { - resolved = append(resolved, addr.IP.String()) + resolved := make([]string, 0, len(res.Resolved)) + for _, ip := range res.Resolved { + resolved = append(resolved, ip.String()) } our := "" - if ourIP4 != nil { - our = ourIP4.String() + if res.OurIPv4 != nil { + our = res.OurIPv4.String() } - if ourIP6 != nil { + if res.OurIPv6 != nil { if our != "" { our += "/" } - our += ourIP6.String() + our += res.OurIPv6.String() } - entry := log.BindStr("hostname", host). + entry := log.BindStr("hostname", res.Host). BindStr("resolved", strings.Join(resolved, ", ")). BindStr("public_ip", our) - if ourIP4 != nil { - entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match)) + if res.OurIPv4 != nil { + entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", res.IPv4Match)) } - if ourIP6 != nil { - entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match)) + if res.OurIPv6 != nil { + entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", res.IPv6Match)) } entry.Warning("SNI-DNS mismatch: secret hostname does not resolve to this server's public IP. " + diff --git a/internal/cli/sni_check.go b/internal/cli/sni_check.go new file mode 100644 index 000000000..a365cfd48 --- /dev/null +++ b/internal/cli/sni_check.go @@ -0,0 +1,112 @@ +package cli + +import ( + "context" + "net" + "sync" + + "github.com/9seconds/mtg/v2/internal/config" + "github.com/9seconds/mtg/v2/mtglib" +) + +// sniCheckResult captures the outcome of comparing the secret hostname's DNS +// records with this server's public IP addresses. +// +// IPv4Match/IPv6Match are true when either a matching record was found, or +// when the corresponding public IP could not be detected — in which case +// there is nothing to compare against. +type sniCheckResult struct { + Host string + Resolved []net.IP + OurIPv4 net.IP + OurIPv6 net.IP + IPv4Match bool + IPv6Match bool + ResolveErr error +} + +// Known reports whether at least one public IP family was detected. +func (r sniCheckResult) Known() bool { + return r.OurIPv4 != nil || r.OurIPv6 != nil +} + +// OK reports whether the check produced a clean result: the hostname was +// resolved, at least one public IP family is known, and every known family +// matches a resolved record. +func (r sniCheckResult) OK() bool { + if r.Host == "" { + return true + } + + if r.ResolveErr != nil || !r.Known() { + return false + } + + return r.IPv4Match && r.IPv6Match +} + +// runSNICheck resolves conf.Secret.Host and compares the result with the +// server's public IPv4 and IPv6. Public IPs come from config first and fall +// back to on-the-fly detection via ntw. IP detection for the two families +// runs concurrently and honors ctx — callers should supply a deadline, +// since the HTTP fallback can otherwise block startup indefinitely. +func runSNICheck(ctx context.Context, + resolver *net.Resolver, + conf *config.Config, + ntw mtglib.Network, +) sniCheckResult { + res := sniCheckResult{Host: conf.Secret.Host} + + if res.Host == "" { + res.IPv4Match = true + res.IPv6Match = true + + return res + } + + addrs, err := resolver.LookupIPAddr(ctx, res.Host) + if err != nil { + res.ResolveErr = err + + return res + } + + res.Resolved = make([]net.IP, 0, len(addrs)) + for _, a := range addrs { + res.Resolved = append(res.Resolved, a.IP) + } + + endpoints := resolvePublicIPEndpoints(conf.Network.PublicIPEndpoints) + wg := sync.WaitGroup{} + + wg.Go(func() { + res.OurIPv4 = conf.PublicIPv4.Get(nil) + if res.OurIPv4 == nil { + res.OurIPv4 = getIP(ctx, ntw, "tcp4", endpoints) + } + }) + + wg.Go(func() { + res.OurIPv6 = conf.PublicIPv6.Get(nil) + if res.OurIPv6 == nil { + res.OurIPv6 = getIP(ctx, ntw, "tcp6", endpoints) + } + }) + + wg.Wait() + + res.IPv4Match = res.OurIPv4 == nil + res.IPv6Match = res.OurIPv6 == nil + + for _, ip := range res.Resolved { + if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() { + res.IPv4Match = true + } + + if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() { + res.IPv6Match = true + } + } + + return res +} diff --git a/internal/cli/utils.go b/internal/cli/utils.go index db8af549b..450cab935 100644 --- a/internal/cli/utils.go +++ b/internal/cli/utils.go @@ -8,10 +8,41 @@ import ( "strings" "github.com/9seconds/mtg/v2/essentials" + "github.com/9seconds/mtg/v2/internal/config" "github.com/9seconds/mtg/v2/mtglib" ) -func getIP(ntw mtglib.Network, protocol string) net.IP { +// defaultPublicIPEndpoints is the fallback used when network.public-ip-endpoints +// is not set in config. Each endpoint must return the client's public IP as a +// single address in the plain-text response body. +var defaultPublicIPEndpoints = []string{ + "https://ifconfig.co", + "https://icanhazip.com", + "https://ifconfig.me", +} + +// resolvePublicIPEndpoints returns the configured endpoint list, falling back +// to defaultPublicIPEndpoints when none are configured. +func resolvePublicIPEndpoints(configured []config.TypeHttpsURL) []string { + if len(configured) == 0 { + return defaultPublicIPEndpoints + } + + out := make([]string, 0, len(configured)) + for _, u := range configured { + if v := u.Get(nil); v != nil { + out = append(out, v.String()) + } + } + + if len(out) == 0 { + return defaultPublicIPEndpoints + } + + return out +} + +func getIP(ctx context.Context, ntw mtglib.Network, protocol string, endpoints []string) net.IP { dialer := ntw.NativeDialer() client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) { conn, err := dialer.DialContext(ctx, protocol, address) @@ -21,19 +52,26 @@ func getIP(ntw mtglib.Network, protocol string) net.IP { return essentials.WrapNetConn(conn), err }) - req, err := http.NewRequest(http.MethodGet, "https://ifconfig.co", nil) //nolint: noctx - if err != nil { - panic(err) + for _, endpoint := range endpoints { + if ip := fetchPublicIP(ctx, client, endpoint); ip != nil { + return ip + } } - req.Header.Add("Accept", "text/plain") + return nil +} - resp, err := client.Do(req) +func fetchPublicIP(ctx context.Context, client *http.Client, endpoint string) net.IP { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil } - if resp.StatusCode != http.StatusOK { + req.Header.Set("Accept", "text/plain") + req.Header.Set("User-Agent", "curl/8") + + resp, err := client.Do(req) + if err != nil { return nil } @@ -42,6 +80,10 @@ func getIP(ntw mtglib.Network, protocol string) net.IP { resp.Body.Close() //nolint: errcheck }() + if resp.StatusCode != http.StatusOK { + return nil + } + data, err := io.ReadAll(resp.Body) if err != nil { return nil diff --git a/internal/config/config.go b/internal/config/config.go index 70e233f17..32b65e6c5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -71,9 +71,10 @@ type Config struct { Interval TypeDuration `json:"interval"` Count TypeConcurrency `json:"count"` } `json:"keepAlive"` - DOHIP TypeIP `json:"dohIp"` - DNS TypeDNSURI `json:"dns"` - Proxies []TypeProxyURL `json:"proxies"` + DOHIP TypeIP `json:"dohIp"` + DNS TypeDNSURI `json:"dns"` + Proxies []TypeProxyURL `json:"proxies"` + PublicIPEndpoints []TypeHttpsURL `json:"publicIpEndpoints"` } `json:"network"` Stats struct { StatsD struct { diff --git a/internal/config/parse.go b/internal/config/parse.go index bdc76162d..951ba650b 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -66,9 +66,10 @@ type tomlConfig struct { Interval string `toml:"interval" json:"interval,omitempty"` Count uint `toml:"count" json:"count,omitempty"` } `toml:"keep-alive" json:"keepAlive,omitempty"` - DOHIP string `toml:"doh-ip" json:"dohIp,omitempty"` - DNS string `toml:"dns" json:"dns,omitempty"` - Proxies []string `toml:"proxies" json:"proxies,omitempty"` + DOHIP string `toml:"doh-ip" json:"dohIp,omitempty"` + DNS string `toml:"dns" json:"dns,omitempty"` + Proxies []string `toml:"proxies" json:"proxies,omitempty"` + PublicIPEndpoints []string `toml:"public-ip-endpoints" json:"publicIpEndpoints,omitempty"` } `toml:"network" json:"network,omitempty"` Stats struct { StatsD struct {