diff --git a/src/ldfi/ldfi/__init__.py b/src/ldfi/ldfi/__init__.py index 0e28bbcd..cb785b6f 100644 --- a/src/ldfi/ldfi/__init__.py +++ b/src/ldfi/ldfi/__init__.py @@ -4,23 +4,67 @@ import os import sqlite3 import z3 -from typing import (List, Set, Tuple) +from typing import (List, Set) from pkg_resources import get_distribution -# Logging -logging.basicConfig(filename='/tmp/ldfi.log', - filemode='w', - level=logging.DEBUG) +class Config: + def __init__(self, + test_id: int, + run_ids: List[int], + eff: int, + max_crashes: int): + self.test_id = test_id + self.run_ids = run_ids + self.eff = eff + self.max_crashes = max_crashes + + # TODO(stevan): make logging configurable. + logging.basicConfig(level=logging.DEBUG) + +def create_config() -> Config: + parser = argparse.ArgumentParser(description='Lineage-driven fault injection.') + parser.add_argument('--eff', metavar='TIME', type=int, required=True, + help='the time when finite failures end') + parser.add_argument('--crashes', metavar='INT', type=int, required=True, + help='the max amount of node crashes') + parser.add_argument('--test-id', metavar='TEST_ID', type=int, required=True, + help='the test id') + parser.add_argument('--run-ids', metavar='RUN_ID', type=int, nargs='+', required=True, + help='the run ids') + parser.add_argument('--version', '-v', action='version', + version=get_distribution(__name__).version) + + args = parser.parse_args() + + return Config(args.test_id, args.run_ids, args.eff, args.crashes) + +class Data: + def __init__(self, + previous_faults: List[List[str]], + potential_faults: List[List[str]], + crashes: Set[str]): + self.previous_faults = previous_faults + self.potential_faults = potential_faults + self.crashes = crashes + +class Event: + def __init__(self, + test_id: int, + run_id: int, + faults: str, + version: str, + statistics: str): + self.test_id = test_id + self.run_id = run_id + self.faults = faults + self.version = version + self.statistics = statistics class Storage: - def load_previous_faults(self, test_id: int, run_ids: List[int]) -> List[List[str]]: - pass - def load_potential_faults(self, test_id: int, run_ids: List[int], - eff: int, max_crashes: int, - previous_faults: List[List[str]]) -> Tuple[List[List[str]], Set[str]]: + def load(self, config: Config) -> Data: pass - def store_faults(self, test_id: int, run_id: int, faults: str, - version: str, statistics: str): + + def store(self, event: Event): pass class SqliteStorage(Storage): @@ -31,16 +75,18 @@ def __init__(self): self.conn.row_factory = sqlite3.Row self.c = self.conn.cursor() - def load_previous_faults(self, test_id: int, run_ids: List[int]) -> List[List[str]]: + def load(self, config: Config) -> Data: previous_faults = [] + potential_faults = [] + crashes = set() - for run_id in run_ids: + for run_id in config.run_ids: prods = [] self.c.execute("""SELECT faults FROM faults WHERE test_id = (?) - AND run_id = (?)""", (test_id, run_id)) + AND run_id = (?)""", (config.test_id, run_id)) for r in self.c: - for fault in eval(r['faults']): + for fault in eval(r['faults'])['faults']: # NOTE: eval introduces a space after the colon in a # dict, we need to remove this otherwise the variables # of the SAT expression will differ. @@ -48,15 +94,6 @@ def load_previous_faults(self, test_id: int, run_ids: List[int]) -> List[List[st logging.debug("fault: '%s'", str(fault).replace(": ", ":")) previous_faults.append(prods) - return previous_faults - - def load_potential_faults(self, test_id: int, run_ids: List[int], - eff: int, max_crashes: int, - previous_faults: List[List[str]]) -> Tuple[List[List[str]], Set[str]]: - potential_faults = [] - crashes = set() - - for run_id in run_ids: sums = [] self.c.execute("""SELECT * FROM network_trace WHERE test_id = (?) @@ -64,14 +101,16 @@ def load_potential_faults(self, test_id: int, run_ids: List[int], AND kind <> 'timer' AND NOT (`from` LIKE 'client:%') AND NOT (`to` LIKE 'client:%')""", - (test_id, run_id)) + (config.test_id, run_id)) for r in self.c: - if r['at'] < eff: + logging.debug("network trace: {'kind': %s, 'from': %s, 'to': %s, 'at': %s}", + r['kind'], r['from'], r['to'], r['at']) + if r['at'] < config.eff: omission = ("{'kind':'omission', 'from':'%s', 'to':'%s', 'at':%d}" % (r['from'], r['to'], r['at'])) if omission not in previous_faults: sums.append(omission) - if max_crashes > 0: + if config.max_crashes > 0: crash = ("{'kind':'crash', 'from':'%s', 'at':%d}" % (r['from'], r['sent_logical_time'])) if crash not in previous_faults: @@ -79,113 +118,104 @@ def load_potential_faults(self, test_id: int, run_ids: List[int], crashes.add(crash) potential_faults.append(sums) - return (potential_faults, crashes) + return Data(previous_faults, potential_faults, crashes) - def store_faults(self, test_id: int, run_id: int, faults: str, - version: str, statistics: str): - self.c.execute("""INSERT INTO faults(test_id, run_id, faults, version, statistics) - VALUES(?, ?, ?, ?, ?)""", - (test_id, run_id, faults, version, statistics)) - - self.conn.commit() - self.c.close() - -def order(d): - return("%s %s %s %d" % (d['kind'], d['from'], d.get('to', ""), d['at'])) + def store(self, event: Event): + self.c.execute("""INSERT INTO faults(test_id, run_id, faults, version, statistics) + VALUES(?, ?, ?, ?, ?)""", + (event.test_id, event.run_id, event.faults, event.version, + event.statistics)) + self.conn.commit() -def main(): - # Command-line argument parsing. - parser = argparse.ArgumentParser(description='Lineage-driven fault injection.') - parser.add_argument('--eff', metavar='TIME', type=int, required=True, - help='the time when finite failures end') - parser.add_argument('--crashes', metavar='INT', type=int, required=True, - help='the max amount of node crashes') - parser.add_argument('--test-id', metavar='TEST_ID', type=int, required=True, - help='the test id') - parser.add_argument('--run-ids', metavar='RUN_ID', type=int, nargs='+', required=True, - help='the run ids') - parser.add_argument('--json', action='store_true', help='output in JSON format?') - parser.add_argument('--version', '-v', action='version', - version=get_distribution(__name__).version) +def sanity_check(data): + if data.previous_faults == [[]]: + len_previous_faults = 0 + else: + len_previous_faults = len(data.previous_faults) - args = parser.parse_args() + assert(len(data.potential_faults) == len_previous_faults + 1) - # Load network traces from the database. - storage = SqliteStorage() - previous_faults = storage.load_previous_faults(args.test_id, args.run_ids) - logging.debug(str(previous_faults)) - (products, crashes) = storage.load_potential_faults(args.test_id, args.run_ids, args.eff, - args.crashes, previous_faults) - - # Sanity check. - for i, run_id in enumerate(args.run_ids): - if not products[i] and not crashes: - print("Error: couldn't find a network trace for test id: %d, and run id: %d." % - (args.test_id, run_id)) - exit(1) - - # Create and solve SAT formula. - for i, sum in enumerate(products): +def create_sat_formula(config, data): + potential_faults = [] + for i, sum in enumerate(data.potential_faults): kept = z3.Bools(sum) - logging.debug("i: %d", i) logging.debug("kept: " + str(kept)) - if previous_faults[i]: - drop = z3.Bools(previous_faults[i]) + drop = [] + if data.previous_faults[i-1]: + drop = z3.Bools(data.previous_faults[i-1]) logging.debug("drop: " + str(drop)) + if drop: + potential_faults.append(z3.Or(z3.Or(kept), z3.Not(z3.And(drop)))) else: - drop = False - products[i] = z3.Or(z3.Or(kept), z3.Not(z3.And(drop))) + potential_faults.append(z3.Or(kept)) - crashes = z3.Bools(list(crashes)) + formula = z3.And(potential_faults) - s = z3.Solver() - s.add(z3.And(products)) + crashes = z3.Bools(list(data.crashes)) - # There can be at most --crashes many crashes. if crashes: - crashes.append(args.crashes) - s.add(z3.AtMost(crashes)) - r = s.check() - - # Output the result. - if r == z3.unsat: - if not(args.json): - print("No further faults can be injected at this point, the test case is") - print("certified for this particular failure specification!") - else: - print(json.dumps({"faults": []})) - elif r == z3.unknown: - print("Impossible: the SAT solver returned 'unknown'") - try: - print(s.model()) - except Z3Exception: - pass - finally: - exit(2) + crashes.append(config.max_crashes) + formula = z3.And(formula, z3.AtMost(crashes)) + logging.debug("formula: %s", str(formula)) + return formula + +def sat_solve(formula): + solver = z3.Solver() + solver.add(formula) + result = solver.check() + model = solver.model() + statistics = solver.statistics() + return (result, model, statistics) + +def order(d: dict) -> str: + return("%s %s %s %d" % (d['kind'], d['from'], d.get('to', ""), d['at'])) + +def create_log_event(config, result, model, statistics) -> Event: + statistics_dict = {} + for k, v in statistics: + statistics_dict[k] = v + + event = Event(config.test_id, config.run_ids[-1], json.dumps({"faults": []}), + get_distribution(__name__).version, str(statistics_dict)) + + if result == z3.unsat: + # No further faults can be injected at this point, the test case is + # certified for this particular failure specification! + return event + elif result == z3.unknown: + logging.critical("Impossible: the SAT solver returned 'unknown'") + try: + logging.critical(model) + except z3.Z3Exception: + pass + finally: + exit(2) else: - m = s.model() + faults = [] + for d in model.decls(): + if model[d]: + Dict = eval(d.name()) + faults.append(Dict) + faults = sorted(faults, key=order) + event.faults = json.dumps({"faults": faults}) - statistics = {} - for k, v in s.statistics(): - statistics[k] = v + return event - if not(args.json): - print(m) - print(statistics) - else: - faults = [] - for d in m.decls(): - if m[d]: - Dict = eval(d.name()) - faults.append(Dict) - faults = sorted(faults, key=order) - - storage.store_faults(args.test_id, args.run_ids[-1], json.dumps(faults), - get_distribution(__name__).version, str(statistics)) - - print(json.dumps({"faults": faults, - "statistics": statistics, - "version": get_distribution(__name__).version})) +def main(): + config = create_config() + storage = SqliteStorage() + + data = storage.load(config) + sanity_check(data) + + formula = create_sat_formula(config, data) + + (result, model, statistics) = sat_solve(formula) + + event = create_log_event(config, result, model, statistics) + storage.store(event) + logging.debug(event.faults) + print(event.faults) if __name__ == '__main__': main() diff --git a/src/ldfi/shell.nix b/src/ldfi/shell.nix index 23555734..e81084b4 100644 --- a/src/ldfi/shell.nix +++ b/src/ldfi/shell.nix @@ -3,11 +3,15 @@ }: with pkgs; -( let - inherit (import sources.gitignore {}) gitignoreSource; - ldfi = callPackage ./release.nix { - pythonPackages = python38Packages; - gitignoreSource = gitignoreSource; - }; - in python38.withPackages (ps: [ ldfi ]) -).env +let + inherit (import sources.gitignore {}) gitignoreSource; + ldfi = callPackage ./release.nix { + pythonPackages = python38Packages; + gitignoreSource = gitignoreSource; + }; + pythonEnv = python38.withPackages (ps: [ ldfi ]); +in + +mkShell { + buildInputs = [ pythonEnv mypy black ]; +} diff --git a/src/ldfi/tests/test_ldfi.py b/src/ldfi/tests/test_ldfi.py index 01024969..cfa25de1 100644 --- a/src/ldfi/tests/test_ldfi.py +++ b/src/ldfi/tests/test_ldfi.py @@ -1,5 +1,7 @@ import unittest -from ldfi import order +import ldfi +import z3 +from z3 import (And, Or, Not, Bool) class TestSortedFaults(unittest.TestCase): def test_sorted_faults(self): @@ -7,12 +9,46 @@ def test_sorted_faults(self): {"kind": "crash", "from": "frontend", "at": 1}, {"kind": "crash", "from": "frontend", "at": 0}, {"kind": "omission", "from": "frontend", "to": "register2", "at": 1}], - key=order) + key=ldfi.order) assert(l == [{"kind": "crash", "from": "frontend", "at": 0}, {"kind": "crash", "from": "frontend", "at": 1}, {"kind": "omission", "from": "frontend", "to": "register2", "at": 1}, {"kind": "omission", "from": "frontend", "to": "register2", "at": 2}]) +def o(f, t, at): + return ('{"kind": "omission", "from": "%s", "to": "%s", "at": %d}' % (f, t, at)) + +class TestCreateFormula(unittest.TestCase): + def test_create_formula(self): + config = ldfi.Config(-1, [-1], 2, 0) + previous_faults = [[]] + oab1 = o("A", "B", 1) + oac1 = o("A", "C", 1) + + # First run + potential_faults = [[oab1, oac1]] + crashes = set() + data = ldfi.Data(previous_faults, potential_faults, crashes) + ldfi.sanity_check(data) + formula = ldfi.create_sat_formula(config, data) + assert(formula == And(Or(Bool(oab1), Bool(oac1)))) + (result, model, statistics) = ldfi.sat_solve(formula) + assert result == z3.sat + event = ldfi.create_log_event(config, result, model, statistics) + assert event.faults == ('{"faults": [%s]}' % oab1) + + # Second run + previous_faults = [[oab1]] + potential_faults = [[oab1, oac1], [oac1]] + data = ldfi.Data(previous_faults, potential_faults, crashes) + ldfi.sanity_check(data) + formula = ldfi.create_sat_formula(config, data) + print(formula) + assert(formula == And(Or(Or(Bool(oab1), Bool(oac1)), + Not(And(Bool(oab1)))), + Or(Or(Bool(oac1)), + Not(And(Bool(oab1)))))) + if __name__ == '__main__': unittest.main() diff --git a/src/lib/ldfi.go b/src/lib/ldfi.go index 258df873..2b18ab29 100644 --- a/src/lib/ldfi.go +++ b/src/lib/ldfi.go @@ -3,6 +3,7 @@ package lib import ( "encoding/json" "log" + "os" "os/exec" "strconv" "strings" @@ -73,7 +74,6 @@ func (f *Fault) UnmarshalJSON(bs []byte) error { func Ldfi(testId TestId, runIds []RunId, fail FailSpec) Faults { start := time.Now() args := []string{ - "--json", "--eff", strconv.Itoa(fail.EFF), "--crashes", strconv.Itoa(fail.Crashes), "--test-id", strconv.Itoa(testId.TestId), @@ -83,17 +83,16 @@ func Ldfi(testId TestId, runIds []RunId, fail FailSpec) Faults { args = append(args, strconv.Itoa(runId.RunId)) } cmd := exec.Command("detsys-ldfi", args...) + cmd.Stderr = os.Stderr - out, err := cmd.CombinedOutput() + out, err := cmd.Output() if err != nil { log.Panicf("%s\n%s\n", err, out) } var result struct { - Faults []Fault `json:"faults"` - Statistics map[string]interface{} `json:"statistics"` - Version string `json:"version"` + Faults []Fault `json:"faults"` } err = json.Unmarshal(out, &result) @@ -103,8 +102,6 @@ func Ldfi(testId TestId, runIds []RunId, fail FailSpec) Faults { elapsed := time.Since(start) log.Printf("ldfi time: %v\n", elapsed) - log.Printf("z3 statistics: %+v\n", result.Statistics) - log.Printf("version: %+v\n", result.Version) return Faults{result.Faults} } diff --git a/src/sut/broadcast/broadcast_test.go b/src/sut/broadcast/broadcast_test.go index a386f232..053eed1c 100644 --- a/src/sut/broadcast/broadcast_test.go +++ b/src/sut/broadcast/broadcast_test.go @@ -58,7 +58,7 @@ func many(round Round, t *testing.T) { lib.Reset() lib.InjectFaults(lib.Faults{faults}) lib.SetTickFrequency(tickFrequency) - maxTime, err := time.ParseDuration("10s") + maxTime, err := time.ParseDuration("5s") if err != nil { panic(err) }