aboutsummaryrefslogtreecommitdiff
path: root/tests/network-diagnostics/test_network_diagnostics.py
blob: 1a8073f4231bb1ab1b8059255c17e368327a6d1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Tests for run_network_diagnostics in the VM testing harness.

run_network_diagnostics is the VM install pre-flight network check. It
collects read-only facts (interfaces, default route, resolver) first and
unconditionally, then runs every reachability check -- DNS, HTTP egress,
TLS egress, Arch mirror, AUR -- accumulating failures and reporting them all
at the end. Facts are printed regardless of pass/fail, so a failed install
still leaves the evidence. Generic checks (DNS/egress/TLS) are kept separate
from Arch-specific checks (mirror/AUR) so a DNS failure is named as DNS, not
misattributed to the mirror. Returns 0 when all checks pass, non-zero
otherwise, preserving the caller's success/failure contract.

These tests exercise the REAL function body (sourced out of
network-diagnostics.sh, not a copy) with:
  - stub logging functions (section/step/info/success/error/warn) that just
    echo, so output is assertable;
  - a fake `sshpass` on PATH that dispatches on the remote command string and
    returns canned exit codes driven by FAKE_*_FAIL env vars. This is the
    system boundary -- the real function shells out through
    `sshpass ... ssh ... "<remote cmd>"`, and the fake stands in for the VM.

Run from repo root:
    python3 -m unittest tests.network-diagnostics.test_network_diagnostics
"""

import os
import shutil
import subprocess
import tempfile
import unittest


# Repository root: the directory containing this file is
# tests/network-diagnostics/, so the root is two levels above it.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
# Shell library whose run_network_diagnostics function is under test,
# relative to the repository root.
_NETDIAG_RELPATH = ("scripts", "testing", "lib", "network-diagnostics.sh")
NETDIAG = os.path.join(REPO_ROOT, *_NETDIAG_RELPATH)

# Fake `sshpass`, written into a temp bin dir that the tests put first on
# PATH. The real invocation is:
#   sshpass -p <pw> ssh <opts> -p <port> root@<host> "<remote cmd>"
# so the remote command is always the last argument ("${@: -1}"). The stub
# substring-matches that command and returns a canned result per check:
#   - fact commands (ip -brief addr / ip route show default / resolv.conf)
#     always exit 0 and print sample output, so the evidence-collection path
#     is exercised;
#   - each reachability check (getent / https / http / mirror / AUR) exits
#     non-zero only when its FAKE_<CHECK>_FAIL env var is exactly "1"
#     (presumably 2 and 7 were picked to resemble getent's "not found" and
#     curl's "couldn't connect" codes -- the tests only assert non-zero);
#   - any other command exits 0.
# bash `case` takes the first matching pattern, so branch order matters.
FAKE_SSHPASS = r"""#!/bin/bash
cmd="${@: -1}"
case "$cmd" in
    *"ip -brief addr"*)
        echo "lo    UNKNOWN 127.0.0.1/8"
        echo "eth0  UP      10.0.2.15/24"
        exit 0 ;;
    *"ip route show default"*)
        echo "default via 10.0.2.2 dev eth0"
        exit 0 ;;
    *"resolv.conf"*)
        echo "nameserver 10.0.2.3"
        exit 0 ;;
    *"getent hosts"*)
        [ "${FAKE_DNS_FAIL:-0}" = "1" ] && exit 2
        exit 0 ;;
    *"https://archlinux.org"*)
        [ "${FAKE_TLS_FAIL:-0}" = "1" ] && exit 7
        exit 0 ;;
    *"http://archlinux.org"*)
        [ "${FAKE_HTTP_FAIL:-0}" = "1" ] && exit 7
        exit 0 ;;
    *"geo.mirror.pkgbuild.com"*)
        [ "${FAKE_MIRROR_FAIL:-0}" = "1" ] && exit 1
        exit 0 ;;
    *"aur.archlinux.org"*)
        [ "${FAKE_AUR_FAIL:-0}" = "1" ] && exit 1
        exit 0 ;;
    *)
        exit 0 ;;
esac
"""

# Bash driver, run as `bash run.sh <path-to-network-diagnostics.sh>`.
# It defines echo-only stubs for the logging helpers the library calls:
# section/step/info/success print to stdout and warn/error print to stderr,
# so output is assertable. It then sources the real library from $1 and calls
# run_network_diagnostics. That call is the last command, so its return code
# becomes the script's exit status, which the tests assert on.
WRAPPER = r"""#!/bin/bash
section() { echo "=== $1 ==="; }
step()    { echo "  -> $1"; }
info()    { echo "[i] $1"; }
success() { echo "[OK] $1"; }
warn()    { echo "[!] $1" >&2; }
error()   { echo "[X] $1" >&2; }
source "$1"
run_network_diagnostics
"""


class NetworkDiagnosticsHarness(unittest.TestCase):
    """Run the real run_network_diagnostics against a fake VM.

    Each test gets its own temp dir containing a fake ``sshpass`` (put first
    on PATH) and the bash wrapper that stubs the logging helpers, sources
    network-diagnostics.sh and calls the function. Individual checks are made
    to fail by passing FAKE_<CHECK>_FAIL="1" to ``run_diag``.
    """

    def setUp(self):
        # If the library is missing, `source` fails, the call becomes
        # "command not found" (exit 127) and every "returns non-zero" test
        # would pass vacuously -- fail up front with a clear message instead.
        self.assertTrue(os.path.isfile(NETDIAG), "missing " + NETDIAG)
        self.tmp = tempfile.mkdtemp(prefix="netdiag-test-")
        # Registered before anything else can fail: tearDown is skipped when
        # setUp raises, but cleanups still run.
        self.addCleanup(shutil.rmtree, self.tmp, ignore_errors=True)
        self.fakebin = os.path.join(self.tmp, "bin")
        os.makedirs(self.fakebin)
        self._write_executable(os.path.join(self.fakebin, "sshpass"), FAKE_SSHPASS)
        self.wrapper = os.path.join(self.tmp, "run.sh")
        self._write_executable(self.wrapper, WRAPPER)

    @staticmethod
    def _write_executable(path, content):
        """Write *content* to *path* and mark it executable (0o755)."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        os.chmod(path, 0o755)

    @staticmethod
    def _output(result):
        """Return stdout and stderr of *result* concatenated (error/warn go to stderr)."""
        return result.stdout + result.stderr

    def run_diag(self, results_dir=None, **fail_flags):
        """Run the wrapper against NETDIAG and return the CompletedProcess.

        Args:
            results_dir: exported as TEST_RESULTS_DIR when not None; otherwise
                TEST_RESULTS_DIR is unset for the run.
            **fail_flags: extra environment variables, normally
                FAKE_<CHECK>_FAIL="1" to make the fake sshpass fail that
                check. Values are converted with str().

        Returns:
            subprocess.CompletedProcess with text-mode stdout/stderr.
        """
        # Start from the host environment minus everything this harness
        # controls: a stray FAKE_*_FAIL or TEST_RESULTS_DIR exported on the
        # host would otherwise flip checks or make the function write
        # netdiag-*.txt files into a real results directory.
        env = {
            key: value
            for key, value in os.environ.items()
            if not key.startswith("FAKE_") and key != "TEST_RESULTS_DIR"
        }
        env["PATH"] = self.fakebin + os.pathsep + env.get("PATH", os.defpath)
        # Keep the harness deterministic regardless of the host's SSH config.
        env["SSH_OPTS"] = "-o StrictHostKeyChecking=no"
        env["ROOT_PASSWORD"] = "archsetup"
        env["SSH_PORT"] = "22"
        env["VM_IP"] = "localhost"
        if results_dir is not None:
            env["TEST_RESULTS_DIR"] = results_dir
        env.update({key: str(value) for key, value in fail_flags.items()})
        return subprocess.run(
            ["bash", self.wrapper, NETDIAG],
            capture_output=True, text=True, timeout=20, env=env,
        )

    # --- Normal case: everything reachable -----------------------------

    def test_all_checks_pass_returns_zero(self):
        r = self.run_diag()
        self.assertEqual(r.returncode, 0, self._output(r))
        self.assertIn("all checks passed", r.stdout)

    def test_facts_collected_on_success(self):
        out = self._output(self.run_diag())
        self.assertIn("10.0.2.15/24", out)          # interface fact
        self.assertIn("default via 10.0.2.2", out)  # route fact
        self.assertIn("nameserver 10.0.2.3", out)   # resolver fact

    # --- DNS-failure case ----------------------------------------------

    def test_dns_failure_returns_nonzero(self):
        r = self.run_diag(FAKE_DNS_FAIL="1")
        self.assertNotEqual(r.returncode, 0, self._output(r))

    def test_dns_failure_names_dns_not_mirror(self):
        out = self._output(self.run_diag(FAKE_DNS_FAIL="1"))
        self.assertIn("DNS resolution failed", out)
        # A DNS failure must not be misreported as a mirror failure. With only
        # DNS failing, the mirror check still runs and passes.
        self.assertNotIn("Cannot reach Arch mirrors", out)

    def test_dns_failure_still_collects_evidence(self):
        # The whole point of the change: evidence is gathered before any check
        # can bail, so a DNS failure still leaves the facts in the output.
        out = self._output(self.run_diag(FAKE_DNS_FAIL="1"))
        self.assertIn("10.0.2.15/24", out)
        self.assertIn("default via 10.0.2.2", out)
        self.assertIn("nameserver 10.0.2.3", out)

    def test_dns_failure_summary_lists_the_failure(self):
        out = self._output(self.run_diag(FAKE_DNS_FAIL="1"))
        self.assertIn("found 1 failure", out)
        self.assertIn("getent hosts archlinux.org", out)

    # --- Mirror-only-failure case --------------------------------------

    def test_mirror_only_failure_returns_nonzero(self):
        r = self.run_diag(FAKE_MIRROR_FAIL="1")
        self.assertNotEqual(r.returncode, 0, self._output(r))

    def test_mirror_only_failure_generic_checks_pass(self):
        out = self._output(self.run_diag(FAKE_MIRROR_FAIL="1"))
        # Generic checks are healthy; only the Arch-specific mirror check fails.
        self.assertIn("DNS resolution OK", out)
        self.assertIn("HTTP egress OK", out)
        self.assertIn("TLS/HTTPS egress OK", out)
        self.assertIn("Cannot reach Arch mirrors", out)
        self.assertNotIn("DNS resolution failed", out)

    def test_mirror_only_failure_summary_names_mirror(self):
        out = self._output(self.run_diag(FAKE_MIRROR_FAIL="1"))
        # The hostname alone can also appear in per-check progress output, so
        # also require the failure summary (exactly one failing check).
        self.assertIn("found 1 failure", out)
        self.assertIn("geo.mirror.pkgbuild.com", out)

    # --- All checks run: multiple failures are all reported ------------

    def test_multiple_failures_all_reported(self):
        out = self._output(self.run_diag(FAKE_DNS_FAIL="1", FAKE_AUR_FAIL="1"))
        self.assertIn("found 2 failure", out)
        self.assertIn("getent hosts archlinux.org", out)
        self.assertIn("aur.archlinux.org", out)

    # --- Raw outputs saved to the results dir --------------------------

    def test_raw_facts_saved_to_results_dir(self):
        results = os.path.join(self.tmp, "results")
        os.makedirs(results)
        r = self.run_diag(results_dir=results)
        for slug, needle in (
            ("ip-addr", "10.0.2.15/24"),
            ("ip-route", "default via 10.0.2.2"),
            ("resolv-conf", "nameserver 10.0.2.3"),
        ):
            with self.subTest(slug=slug):
                path = os.path.join(results, f"netdiag-{slug}.txt")
                self.assertTrue(
                    os.path.exists(path), f"missing {path}\n{self._output(r)}"
                )
                with open(path, encoding="utf-8") as f:
                    self.assertIn(needle, f.read())


if __name__ == "__main__":
    # Allow running this file directly, in addition to
    # `python3 -m unittest ...` from the repo root.
    unittest.main()