Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ internal class ConnectionViewModel @VisibleForTesting constructor(

init {
viewModelScope.launch {
// Pre-set the mTLS flag before emitting the auth URL. If the phone is currently
// connected to an mTLS-protected instance whose certificate covers this host, the
// onboarding WebView will reuse the live TLS session (session resumption) and
// onReceivedClientCertRequest will never fire — pre-setting the flag ensures the
// Wear OS cert-selection screen is not silently skipped.
// Matching against the cert's SANs/CN rather than mere key presence avoids a false
// positive when the user has multiple servers where only one requires mTLS.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should rewrite this and remove the implementation details of the function that you already document on the function itself.

try {
webViewClient.preInitializeTLSClientAuthState(rawUrl.toHttpUrl().host)
} catch (_: Exception) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to rethrow the cancelation properly still since you catch Exception.

// Malformed URL: preInitializeTLSClientAuthState is a best-effort optimisation;
// buildAuthUrl below will surface the error to the user.
}
buildAuthUrl(rawUrl)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import android.webkit.WebViewClient
import androidx.annotation.VisibleForTesting
import io.homeassistant.companion.android.common.data.keychain.KeyChainRepository
import java.lang.ref.WeakReference
import java.net.InetAddress
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.security.auth.x500.X500Principal
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
Expand All @@ -39,6 +41,103 @@ open class TLSWebViewClient(private var keyChainRepository: KeyChainRepository)
private var key: PrivateKey? = null
private var chain: Array<X509Certificate>? = null

/**
* Pre-initializes [isTLSClientAuthNeeded] by verifying whether the currently loaded
* certificate chain covers [targetHost], to handle TLS session resumption.
*
* Normally [isTLSClientAuthNeeded] is set when [onReceivedClientCertRequest] fires during
* a full TLS handshake. However, when TLS session resumption occurs (the WebView reuses an
* existing session from the same process), the server does not issue a new
* `CertificateRequest`, so [onReceivedClientCertRequest] is never called — even if the
* server requires a client certificate.
*
* This is the root cause of the Wear OS onboarding mTLS failure: the main app WebView
* establishes a TLS session while the user is connected; the onboarding WebView immediately
* resumes it, bypassing the callback that would reveal the mTLS requirement to the
* navigation layer.
*
* The fix inspects the in-memory certificate chain (if any) and checks whether it covers
* [targetHost] via its Subject Alternative Names (SANs), or its Common Name (CN) as a
* fallback. This avoids a false positive when the user has multiple servers where only one
* requires mTLS: the loaded cert will not match the non-mTLS server's hostname.
*
* If the app was force-stopped first (clearing in-memory state) no TLS session can be
* resumed either, so [onReceivedClientCertRequest] will fire naturally on the fresh handshake.
*
* Must be called **before** the WebView starts loading (i.e. before the URL is emitted).
* Idempotent: if the flag is already `true` (set by a real handshake) this is a no-op.
*
* @param targetHost the hostname of the server being connected to (e.g. "myha.example.com")
*/
fun preInitializeTLSClientAuthState(targetHost: String) {
if (isTLSClientAuthNeeded) return
val cert = keyChainRepository.getCertificateChain()?.firstOrNull() ?: return
isTLSClientAuthNeeded = certCoversHost(cert, targetHost)
}

/**
* Returns `true` if [cert] is valid for [host].
*
* Checks Subject Alternative Names (SANs) first — both DNS names (with wildcard support)
* and IP addresses. Falls back to the Common Name (CN) in the Subject DN if no SANs are
* present, matching the behaviour of legacy TLS stacks.
*/
@VisibleForTesting
internal fun certCoversHost(cert: X509Certificate, host: String): Boolean {
val sans: Collection<List<*>>? = try {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep this private and test through the public API.

cert.subjectAlternativeNames
} catch (_: Exception) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
} catch (_: Exception) {
} catch (_: CertificateParsingException) {

null
}

return if (!sans.isNullOrEmpty()) {
sans.any { san ->
val type = san[0] as? Int ?: return@any false
Comment thread
smhc marked this conversation as resolved.
when (type) {
2 -> { // dNSName — returned as String
val value = san[1] as? String ?: return@any false
hostMatchesSan(host, value)
}
7 -> { // iPAddress — returned as ByteArray per the Java X.509 API contract
val ipBytes = san[1] as? ByteArray ?: return@any false
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to not have this magic numbers? 2 an 7 or at least link it to the proper documentation.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure it's always a ByteArray here and it never could be a String?

try {
host.equals(InetAddress.getByAddress(ipBytes).hostAddress, ignoreCase = true)
} catch (_: Exception) {
false
}
}
else -> false
}
}
} else {
// Fallback: extract CN from the Subject DN.
// getName(RFC2253) uses comma as AVA separator; commas inside values are escaped
// as \, which we don't need to handle because hostnames never contain commas.
val dn = cert.subjectX500Principal.getName(X500Principal.RFC2253)
val cn = dn.splitToSequence(",")
.map { it.trim() }
.firstOrNull { it.startsWith("CN=", ignoreCase = true) }
// Use substring-after-'=' so the extraction is case-insensitive (matches the
// startsWith check above) rather than a case-sensitive removePrefix("CN=").
?.let { it.substring(it.indexOf('=') + 1).trim() }
Comment thread
smhc marked this conversation as resolved.
cn != null && hostMatchesSan(host, cn)
}
}

/**
* Matches [host] against a SAN value that may contain a leading wildcard.
*
* A wildcard (`*.example.com`) covers any single label: `foo.example.com` matches but
* `foo.bar.example.com` and `example.com` do not (per RFC 2818 §3.1).
*/
private fun hostMatchesSan(host: String, san: String): Boolean {
if (!san.startsWith("*.")) return host.equals(san, ignoreCase = true)
val suffix = san.substring(1) // ".example.com"
if (!host.endsWith(suffix, ignoreCase = true)) return false
val wildcardLabel = host.substring(0, host.length - suffix.length)
return wildcardLabel.isNotEmpty() && !wildcardLabel.contains('.')
}

private fun getActivity(context: Context?): Activity? {
if (context == null) {
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.slot
import io.mockk.verify
import java.net.InetAddress
import java.net.URL
import java.security.cert.X509Certificate
import kotlin.reflect.KClass
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.MutableSharedFlow
Expand Down Expand Up @@ -417,4 +419,128 @@ class ConnectionViewModelTest {
assertEquals("class java.lang.UnsatisfiedLinkError", error.rawErrorType)
verify(exactly = 1) { connectivityCheckRepository.runChecks(rawUrl) }
}

// --- preInitializeTLSClientAuthState / cert-host matching tests ---

@Test
fun `Given cert with exact DNS SAN matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest {
val cert = mockk<X509Certificate> {
every { subjectAlternativeNames } returns listOf(listOf(2, "homeassistant.local"))
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"http://homeassistant.local:8123",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded)
}

@Test
fun `Given cert with wildcard DNS SAN matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest {
val cert = mockk<X509Certificate> {
every { subjectAlternativeNames } returns listOf(listOf(2, "*.example.com"))
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"https://ha.example.com",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded)
}

Comment thread
smhc marked this conversation as resolved.
@Test
fun `Given cert with DNS SAN for a different host when initializing then isTLSClientAuthNeeded remains false`() = runTest {
val cert = mockk<X509Certificate> {
every { subjectAlternativeNames } returns listOf(listOf(2, "other-server.example.com"))
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"http://homeassistant.local:8123",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded)
}

@Test
fun `Given no certificate chain in memory when initializing then isTLSClientAuthNeeded remains false`() = runTest {
every { keyChainRepository.getCertificateChain() } returns null

val viewModel = ConnectionViewModel(
"http://homeassistant.local:8123",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded)
}

@Test
fun `Given cert with CN matching target host and no SANs when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest {
val cert = mockk<X509Certificate> {
every { subjectAlternativeNames } returns null
every { subjectX500Principal } returns mockk {
every { getName("RFC2253") } returns "CN=homeassistant.local,O=Home Assistant"
}
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"http://homeassistant.local:8123",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded)
}

@Test
fun `Given cert with wildcard SAN that does not cover a multi-label subdomain when initializing then isTLSClientAuthNeeded remains false`() = runTest {
val cert = mockk<X509Certificate> {
// *.example.com covers foo.example.com but not foo.bar.example.com
every { subjectAlternativeNames } returns listOf(listOf(2, "*.example.com"))
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"https://foo.bar.example.com",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded)
}

@Test
fun `Given cert with IP address SAN matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest {
// getSubjectAlternativeNames() returns iPAddress (type 7) as a ByteArray, not a String
val ipBytes = InetAddress.getByName("192.168.1.100").address
val cert = mockk<X509Certificate> {
every { subjectAlternativeNames } returns listOf(listOf(7, ipBytes))
}
every { keyChainRepository.getCertificateChain() } returns arrayOf(cert)

val viewModel = ConnectionViewModel(
"https://192.168.1.100",
webViewClientFactory,
connectivityCheckRepository,
)
advanceUntilIdle()

assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded)
}
Comment thread
smhc marked this conversation as resolved.
}
Loading