diff --git a/.gitignore b/.gitignore index 56fe693dd..d20cfab6c 100644 --- a/.gitignore +++ b/.gitignore @@ -204,16 +204,20 @@ dmypy.json pyvenv.cfg pip-selfcheck.json +# End of https://www.gitignore.io/api/c,ocaml,python + # dont upload our real life zoo test/real_world_samples test/run_real_world_samples.sh -# End of https://www.gitignore.io/api/c,ocaml,python - .project .pydevproject -src/cwe_checker.plugin +# Plugin files (generated by bapbuild) +*.plugin + +# install files for opam packages (generated by dune) +*.install test/artificial_samples/dockcross* diff --git a/CHANGES.md b/CHANGES.md index 54667a8bf..5eaf6fcd1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,8 @@ - Added BAP recipe for standard cwe_checker run (PR #9) - Improved check for CWE-476 (NULL Pointer Dereference) using data flow analysis (PR #11) - Switched C build system from make to scons (PR #16) +- Added type inference pass (PR #14) +- Added unit tests to test suite (PR #14) 0.1 (2018-10-08) ===== diff --git a/Makefile b/Makefile index cd033cd7d..251a22262 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,25 @@ .PHONY: all clean test uninstall all: - cd src; bapbuild -r -Is checkers,utils -pkgs yojson,unix cwe_checker.plugin; bapbundle install cwe_checker.plugin; cd .. + dune build --profile release + dune install + cd plugins/cwe_checker; make all; cd ../.. + cd plugins/cwe_checker_type_inference; make all; cd ../.. + cd plugins/cwe_checker_type_inference_print; make all; cd ../.. test: + dune runtest --profile release # TODO: correct all dune linter warnings so that we can remove --profile release pytest -v clean: + dune clean bapbuild -clean + cd test/unit; make clean; cd ../.. + cd plugins/cwe_checker; make clean; cd ../.. + cd plugins/cwe_checker_type_inference; make clean; cd ../.. + cd plugins/cwe_checker_type_inference_print; make clean; cd ../.. uninstall: - bapbundle remove cwe_checker.plugin + dune uninstall + cd plugins/cwe_checker; make uninstall; cd ../.. + cd plugins/cwe_checker_type_inference; make uninstall; cd ../.. + cd plugins/cwe_checker_type_inference_print; make uninstall; cd ../.. diff --git a/dune-project b/dune-project new file mode 100644 index 000000000..0cd3eb62a --- /dev/null +++ b/dune-project @@ -0,0 +1,2 @@ +(lang dune 1.6) +(name cwe_checker) diff --git a/plugins/cwe_checker/Makefile b/plugins/cwe_checker/Makefile new file mode 100644 index 000000000..7d936e0ee --- /dev/null +++ b/plugins/cwe_checker/Makefile @@ -0,0 +1,9 @@ +all: + bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker.plugin + bapbundle install cwe_checker.plugin + +clean: + bapbuild -clean + +uninstall: + bapbundle remove cwe_checker.plugin diff --git a/src/cwe_checker.ml b/plugins/cwe_checker/cwe_checker.ml similarity index 99% rename from src/cwe_checker.ml rename to plugins/cwe_checker/cwe_checker.ml index 349a70470..10139d92e 100644 --- a/src/cwe_checker.ml +++ b/plugins/cwe_checker/cwe_checker.ml @@ -1,8 +1,9 @@ -open Core_kernel.Std +open Core_kernel open Bap.Std open Graphlib.Std open Format open Yojson.Basic.Util +open Cwe_checker_core include Self() diff --git a/plugins/cwe_checker_type_inference/Makefile b/plugins/cwe_checker_type_inference/Makefile new file mode 100644 index 000000000..1019e6473 --- /dev/null +++ b/plugins/cwe_checker_type_inference/Makefile @@ -0,0 +1,9 @@ +all: + bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker_type_inference.plugin + bapbundle install cwe_checker_type_inference.plugin + +clean: + bapbuild -clean + +uninstall: + bapbundle remove cwe_checker_type_inference.plugin diff --git a/plugins/cwe_checker_type_inference/cwe_checker_type_inference.ml b/plugins/cwe_checker_type_inference/cwe_checker_type_inference.ml new file mode 100644 index 000000000..802015112 --- /dev/null +++ b/plugins/cwe_checker_type_inference/cwe_checker_type_inference.ml @@ -0,0 +1,5 @@ +open Bap.Std +open Core_kernel +open Cwe_checker_core + +let () = Project.register_pass Type_inference.compute_pointer_register diff --git a/plugins/cwe_checker_type_inference_print/Makefile b/plugins/cwe_checker_type_inference_print/Makefile new file mode 100644 index 000000000..616f24d06 --- /dev/null +++ b/plugins/cwe_checker_type_inference_print/Makefile @@ -0,0 +1,9 @@ +all: + bapbuild -pkgs yojson,unix,ppx_jane,cwe_checker_core cwe_checker_type_inference_print.plugin + bapbundle install cwe_checker_type_inference_print.plugin + +clean: + bapbuild -clean + +uninstall: + bapbundle remove cwe_checker_type_inference_print.plugin diff --git a/plugins/cwe_checker_type_inference_print/cwe_checker_type_inference_print.ml b/plugins/cwe_checker_type_inference_print/cwe_checker_type_inference_print.ml new file mode 100644 index 000000000..63b074cbd --- /dev/null +++ b/plugins/cwe_checker_type_inference_print/cwe_checker_type_inference_print.ml @@ -0,0 +1,14 @@ +open Bap.Std +open Core_kernel +open Cwe_checker_core + +let main project = + Log_utils.set_log_level Log_utils.DEBUG; + Log_utils.set_output stdout; + Log_utils.color_on (); + + let program = Project.program project in + let tid_map = Address_translation.generate_tid_map program in + Type_inference.print_type_info_tags project tid_map + +let () = Project.register_pass' main ~deps:["cwe-checker-type-inference"] diff --git a/recipes/type_inference/descr b/recipes/type_inference/descr new file mode 100644 index 000000000..b41ca8445 --- /dev/null +++ b/recipes/type_inference/descr @@ -0,0 +1 @@ +Prints the results of the type inference pass of each block. diff --git a/recipes/type_inference/recipe.scm b/recipes/type_inference/recipe.scm new file mode 100644 index 000000000..10c1dafd6 --- /dev/null +++ b/recipes/type_inference/recipe.scm @@ -0,0 +1 @@ +(option pass cwe-type-inference-print) diff --git a/src/analysis/mem_region.ml b/src/analysis/mem_region.ml new file mode 100644 index 000000000..f426c347c --- /dev/null +++ b/src/analysis/mem_region.ml @@ -0,0 +1,164 @@ +open Bap.Std +open Core_kernel + + +let (+), (-) = Bitvector.(+), Bitvector.(-) + +let (>) x y = Bitvector.(>) (Bitvector.signed x) (Bitvector.signed y) +let (<) x y = Bitvector.(<) (Bitvector.signed x) (Bitvector.signed y) +let (>=) x y = Bitvector.(>=) (Bitvector.signed x) (Bitvector.signed y) +let (<=) x y = Bitvector.(<=) (Bitvector.signed x) (Bitvector.signed y) +let (=) x y = Bitvector.(=) x y + +type 'a mem_node = { + pos: Bitvector.t; (* address of the element *) + size: Bitvector.t; (* size (in bytes) of the element *) + data: ('a, unit) Result.t; +} [@@deriving bin_io, compare, sexp] + +type 'a t = 'a mem_node list [@@deriving bin_io, compare, sexp] + + +let empty () : 'a t = + [] + +(** Return an error mem_node at the given position with the given size. *) +let error_elem ~pos ~size = + { pos = pos; + size = size; + data = Error ();} + + +let rec add mem_region elem ~pos ~size = + let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in + let new_node = { + pos=pos; + size=size; + data=Ok(elem); + } in + match mem_region with + | [] -> new_node :: [] + | head :: tail -> + if head.pos + head.size <= pos then + head :: (add tail elem ~pos ~size) + else if pos + size <= head.pos then + new_node :: mem_region + else begin (* head and new node intersect => at the intersection, head gets overwritten and the rest of head gets marked as error. *) + let tail = if head.pos + head.size > pos + size then (* mark the right end of head as error *) + let err = error_elem ~pos:(pos + size) ~size:(head.pos + head.size - (pos + size)) in + err :: tail + else + tail in + let tail = add tail elem ~pos ~size in (* add the new element*) + let tail = if head.pos < pos then (* mark the left end of head as error *) + let err = error_elem ~pos:(head.pos) ~size:(pos - head.pos) in + err :: tail + else + tail in + tail + end + + +let rec get mem_region pos = + match mem_region with + | [] -> None + | head :: tail -> + if head.pos > pos then + None + else if head.pos = pos then + match head.data with + | Ok(x) -> Some(Ok(x, head.size)) + | Error(_) -> Some(Error(())) + else if head.pos + head.size <= pos then + get tail pos + else + Some(Error(())) (* pos intersects some data, but does not equal its starting address*) + +(* Helper function. Removes all elements with position <= pos. *) +let rec remove_until mem_region pos = + match mem_region with + | [] -> [] + | hd :: tl -> + if hd.pos <= pos then + remove_until tl pos + else + mem_region + + +let rec remove mem_region ~pos ~size = + let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in + match mem_region with + | [] -> [] + | hd :: tl -> + if hd.pos + hd.size <= pos then + hd :: remove tl pos size + else if pos + size <= hd.pos then + mem_region + else + let mem_region = remove tl pos size in + let mem_region = + if hd.pos + hd.size > pos + size then + error_elem ~pos:(pos + size) ~size:(hd.pos + hd.size - (pos + size)) :: mem_region + else + mem_region in + let mem_region = + if hd.pos < pos then + error_elem ~pos:hd.pos ~size:(pos - hd.pos) :: mem_region + else + mem_region in + mem_region + +let rec mark_error mem_region ~pos ~size = + let () = if pos + size < pos then failwith "[CWE-checker] element out of bounds for mem_region" in + match mem_region with + | [] -> (error_elem pos size) :: [] + | hd :: tl -> + if hd.pos + hd.size <= pos then + hd :: (mark_error tl pos size) + else if pos + size <= hd.pos then + (error_elem pos size) :: mem_region + else + let start_pos = min pos hd.pos in + let end_pos_plus_one = max (pos + size) (hd.pos + hd.size) in + mark_error tl ~pos:start_pos ~size:(end_pos_plus_one - start_pos) + + +(* TODO: This is probably a very inefficient implementation in some cases. Write a faster implementation if necessary. *) +let rec merge mem_region1 mem_region2 ~data_merge = + match (mem_region1, mem_region2) with + | (value, []) + | ([], value) -> value + | (hd1 :: tl1, hd2 :: tl2) -> + if hd1.pos + hd1.size <= hd2.pos then + hd1 :: merge tl1 mem_region2 data_merge + else if hd2.pos + hd2.size <= hd1.pos then + hd2 :: merge mem_region1 tl2 data_merge + else if hd1.pos = hd2.pos && hd1.size = hd2.size then + match (hd1.data, hd2.data) with + | (Ok(data1), Ok(data2)) -> begin + match data_merge data1 data2 with + | Some(Ok(value)) -> { hd1 with data = Ok(value) } :: merge tl1 tl2 ~data_merge + | Some(Error(_)) -> {hd1 with data = Error(())} :: merge tl1 tl2 ~data_merge + | None -> merge tl1 tl2 data_merge + end + | _ -> { hd1 with data = Error(()) } :: merge tl1 tl2 ~data_merge + else + let start_pos = min hd1.pos hd2.pos in + let end_pos_plus_one = max (hd1.pos + hd1.size) (hd2.pos + hd2.size) in + let mem_region = merge tl1 tl2 data_merge in + mark_error mem_region ~pos:start_pos ~size:(end_pos_plus_one - start_pos) + + +let rec equal (mem_region1:'a t) (mem_region2:'a t) ~data_equal : bool = + match (mem_region1, mem_region2) with + | ([], []) -> true + | (hd1 :: tl1, hd2 :: tl2) -> + if hd1.pos = hd2.pos && hd1.size = hd2.size then + match (hd1.data, hd2.data) with + | (Ok(data1), Ok(data2)) when data_equal data1 data2 -> + equal tl1 tl2 data_equal + | (Error(()), Error(())) -> equal tl1 tl2 data_equal + | _ -> false + else + false + | _ -> false diff --git a/src/analysis/mem_region.mli b/src/analysis/mem_region.mli new file mode 100644 index 000000000..f2006282a --- /dev/null +++ b/src/analysis/mem_region.mli @@ -0,0 +1,39 @@ +(** contains an abstract memory region data type where you can assign arbitrary data to locations + inside the memory regions. A memory region has no fixed size, so it can be used + for memory regions of variable size like arrays or stacks. + + TODO: Right now this data structure is unsuited for elements that get only partially loaded. *) + +open Bap.Std +open Core_kernel + +type 'a t [@@deriving bin_io, compare, sexp] + + +(** Get an empty memory region- *) +val empty: unit -> 'a t + + +(** Add an element to the memory region. If the element intersects existing elements, + the non-overwritten part gets marked as Error *) +val add: 'a t -> 'a -> pos:Bitvector.t -> size:Bitvector.t -> 'a t + +(** Mark the memory region between pos (included) and pos+size (excluded) as empty. + If elements get partially removed, mark the non-removed parts as Error *) +val remove: 'a t -> pos:Bitvector.t -> size:Bitvector.t -> 'a t + +(** Returns the element and its size at position pos or None, when there is no element at that position. + If pos intersects an element but does not match its starting position, it returns Some(Error(())). *) +val get: 'a t -> Bitvector.t -> (('a * Bitvector.t), unit) Result.t Option.t + +(** Merge two memory regions. Elements with the same position and size get merged using + data_merge, other intersecting elements get marked as Error. Note that data_merge + may return None (to remove the elements from the memory region) or Some(Error(_)) to + mark the merged element as error. *) +val merge: 'a t -> 'a t -> data_merge:('a -> 'a -> ('a, 'b) result option) -> 'a t + +(** Check whether two memory regions are equal. *) +val equal: 'a t -> 'a t -> data_equal:('a -> 'a -> bool) -> bool + +(** Mark an area in the mem_region as containing errors. *) +val mark_error: 'a t -> pos:Bitvector.t -> size:Bitvector.t -> 'a t diff --git a/src/analysis/type_inference.ml b/src/analysis/type_inference.ml new file mode 100644 index 000000000..c1fd31f44 --- /dev/null +++ b/src/analysis/type_inference.ml @@ -0,0 +1,459 @@ +open Bap.Std +open Core_kernel + +(** TODO: + interprocedural analysis + backward analysis to recognize which constants are pointers and which not. + extend to track FunctionPointer, DataPointer + extend to track PointerTargets + TODO: There are no checks yet if a value from the stack of the calling function + is accessed. Maybe this should be part of another analysis. + TODO: tracking for PointerTargets should also track if another register other + than the stack register is used to access values on the stack of the current + function. +*) + +let name = "Type Inference" +let version = "0.1" + +module Register = struct + type t = + | Pointer + | Data + [@@deriving bin_io, compare, sexp] + + let merge reg1 reg2 = + if reg1 = reg2 then Some(Ok(reg1)) else Some(Error(())) + + let equal reg1 reg2 = + reg1 = reg2 +end + + +module TypeInfo = struct + type reg_state = (Register.t, unit) Result.t Var.Map.t [@@deriving bin_io, compare, sexp] + type t = { + stack: Register.t Mem_region.t; + stack_offset: (Bitvector.t, unit) Result.t Option.t; (* If we don't know the offset, this is None, if we have conflicting values for the offset, this is Some(Error()) *) + reg: reg_state; + } [@@deriving bin_io, compare, sexp] + + let merge state1 state2 = + let stack = Mem_region.merge state1.stack state2.stack ~data_merge:Register.merge in + let stack_offset = match (state1.stack_offset, state2.stack_offset) with + | (Some(Ok(x)), Some(Ok(y))) when x = y -> Some(Ok(x)) + | (Some(x), None) + | (None, Some(x)) -> Some(x) + | (None, None) -> None + | _ -> Some(Error(())) in + let reg = Map.merge state1.reg state2.reg ~f:(fun ~key values -> + match values with + | `Left(reg) + | `Right(reg) -> Some(reg) + | `Both(Ok(reg1), Ok(reg2)) -> Register.merge reg1 reg2 + | `Both(_, _) -> Some(Error(())) + ) in + { stack = stack; + stack_offset = stack_offset; + reg = reg } + + let equal state1 state2 = + if state1.stack_offset = state2.stack_offset && (Mem_region.equal state1.stack state2.stack ~data_equal:Register.equal) then + Map.equal (fun reg1 reg2 -> reg1 = reg2) state1.reg state2.reg + else + false + + (** Get an empty state. *) + let empty () = + let module VarMap = Var.Map in + { stack = Mem_region.empty (); + stack_offset = None; + reg = VarMap.empty; + } + + (** Returns a register list with only the stack pointer as pointer register and + only the flag registers as data registers. *) + let get_stack_pointer_and_flags project = + let stack_pointer = Symbol_utils.stack_register project in + let reg = Map.set Var.Map.empty ~key:stack_pointer ~data:(Ok(Register.Pointer)) in + let flags = Symbol_utils.flag_register_list project in + List.fold flags ~init:reg ~f:(fun state register -> + Map.set state register (Ok(Register.Data)) ) + + (** create a new state with stack pointer as known pointer register and all flag + registers as known data registers. The stack itself is empty and the offset + is 0. (TODO for interprocedural analysis: Ensure that the return address is + marked as a pointer!) *) + let function_start_state project = + let module VarMap = Var.Map in + let reg = get_stack_pointer_and_flags project in + { stack = Mem_region.empty (); + stack_offset = Some(Ok(Bitvector.of_int 0 ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8))); + reg = reg; + } + + let remove_virtual_registers state = + { state with reg = Map.filter_keys state.reg ~f:(fun var -> Var.is_physical var) } + + let stack_offset_add state (value:Bitvector.t) = + match state.stack_offset with + | Some(Ok(x)) -> { state with stack_offset = Some(Ok(Bitvector.(+) x value)) } + | _ -> state + + (** if the addr_exp is a (computable) stack offset, return the offset *) + let compute_stack_offset state addr_exp ~project = + let (register, offset) = match addr_exp with + | Bil.Var(var) -> (Some(var), Bitvector.of_int 0 ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8)) + | Bil.BinOp(Bil.PLUS, Bil.Var(var), Bil.Int(num)) -> (Some(var), num) + | Bil.BinOp(Bil.MINUS, Bil.Var(var), Bil.Int(num)) -> (Some(var), Bitvector.neg (Bitvector.signed num)) + | _ -> (None, Bitvector.of_int 0 ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + match (register, state.stack_offset) with + | (Some(var), Some(Ok(base_offset))) when var = (Symbol_utils.stack_register project) -> Some(Bitvector.(+) base_offset offset) + | _ -> None + + (* Pretty printer that just prints the sexp. Needed for the creation of type_info_tag. *) + let pp ppf elem = + Format.fprintf ppf "%s" (Sexp.to_string (sexp_of_t elem)) + +end (* module *) + +(* Create a tag for TypeInfo *) +let type_info_tag = Value.Tag.register (module TypeInfo) + ~name:"type_info" + ~uuid:"7a537f19-2dd1-49b6-b343-35b4b1d04c0b" + +(** returns a list with all nested expressions. If expr1 is contained in expr2, then + expr1 will be included after expr2 in the list. *) +let rec nested_exp_list exp : Exp.t list = + let nested_exp = match exp with + | Bil.Load(exp1, exp2, _, _) -> exp :: (nested_exp_list exp1) @ (nested_exp_list exp2) + | Bil.Store(exp1, exp2, exp3, _, _) -> nested_exp_list exp1 @ nested_exp_list exp2 @ nested_exp_list exp3 + | Bil.BinOp(op, exp1, exp2) -> nested_exp_list exp1 @ nested_exp_list exp2 + | Bil.UnOp(op, exp1) -> nested_exp_list exp1 + | Bil.Var(_) -> [] + | Bil.Int(_) -> [] + | Bil.Cast(_, _, exp1) -> nested_exp_list exp1 + | Bil.Let(_, exp1, exp2) -> nested_exp_list exp1 @ nested_exp_list exp2 + | Bil.Unknown(_) -> [] + | Bil.Ite(exp1, exp2, exp3) -> nested_exp_list exp1 @ nested_exp_list exp2 @ nested_exp_list exp3 + | Bil.Extract(_, _, exp1) -> nested_exp_list exp1 + | Bil.Concat(exp1, exp2) -> nested_exp_list exp1 @ nested_exp_list exp2 in + exp :: nested_exp + + +(** If exp is a load from the stack, return the corresponding element. + TODO: Bil.AND and Bil.OR are ignored, because we do not track alignment yet. *) +let get_stack_elem state exp ~project = + match exp with + | Bil.Load(_, addr, endian, size) -> begin (* TODO: add a test for correct endianess *) + match TypeInfo.compute_stack_offset state addr project with + | Some(offset) -> begin + match Mem_region.get state.TypeInfo.stack offset with + | Some(Ok(elem, elem_size)) -> + if Bitvector.to_int elem_size = Ok(Size.in_bytes size) then + Some(Ok(elem)) + else + Some(Error()) + | Some(Error()) -> Some(Error()) + | None -> None + end + | None -> None + end + | _ -> None + + + + +let rec type_of_exp exp (state: TypeInfo.t) ~project = + let open Register in + match exp with + | Bil.Load(_) -> (* TODO: Right now only the stack is tracked for type infos. *) + get_stack_elem state exp ~project + | Bil.Store(_) -> None (* Stores are handled in another function. *) + | Bil.BinOp(binop, exp1, exp2) -> begin + match (binop, type_of_exp exp1 state project, type_of_exp exp2 state project) with + (* pointer arithmetics *) + | (Bil.PLUS, Some(Ok(Pointer)), Some(Ok(Pointer))) -> Some(Error(())) + | (Bil.PLUS, Some(Ok(Pointer)), other) + | (Bil.PLUS, other, Some(Ok(Pointer))) -> Some(Ok(Pointer)) + | (Bil.PLUS, Some(Ok(Data)), Some(Ok(Data))) -> Some(Ok(Data)) + | (Bil.PLUS, _, _) -> None + | (Bil.MINUS, Some(Ok(Pointer)), Some(Ok(Pointer))) -> Some(Ok(Data)) (* Pointer subtraction to determine offset is CWE-469, this should be logged. *) + | (Bil.MINUS, Some(Ok(Pointer)), other) -> Some(Ok(Pointer)) (* We assume that other is not a pointer. This can only generate errors in the presence of CWE-469 *) + | (Bil.MINUS, Some(Ok(Data)), Some(Ok(Data))) -> Some(Ok(Data)) + | (Bil.MINUS, _, _) -> None + (* bitwise AND and OR can be used as addition and subtraction if some alignment of the pointer is known *) + | (Bil.AND, Some(Ok(Pointer)), Some(Ok(Pointer))) -> Some(Error(())) (* TODO: This could be a pointer, but is there any case where this is used in practice? *) + | (Bil.AND, Some(Ok(Pointer)), other) + | (Bil.AND, other, Some(Ok(Pointer))) -> Some(Ok(Pointer)) + | (Bil.AND, Some(Ok(Data)), Some(Ok(Data))) -> Some(Ok(Data)) + | (Bil.AND, _, _) -> None + | (Bil.OR, Some(Ok(Pointer)), Some(Ok(Pointer))) -> Some(Error(())) (* TODO: This could be a pointer, but is there any case where this is used in practice? *) + | (Bil.OR, Some(Ok(Pointer)), other) + | (Bil.OR, other, Some(Ok(Pointer))) -> Some(Ok(Pointer)) + | (Bil.OR, Some(Ok(Data)), Some(Ok(Data))) -> Some(Ok(Data)) + | (Bil.OR, _, _) -> None + | _ -> Some(Ok(Data)) (* every other operation should not yield valid pointers *) + end + | Bil.UnOp(_) -> Some(Ok(Data)) + | Bil.Var(var) -> Map.find state.TypeInfo.reg var + | Bil.Int(_) -> None (* TODO: For non-relocateable binaries this could be a pointer to a function/global variable *) + | Bil.Cast(Bil.SIGNED, _, _) -> Some(Ok(Data)) + | Bil.Cast(_, size, exp) -> + if size = (Symbol_utils.arch_pointer_size_in_bytes project * 8) then type_of_exp exp state project else Some(Ok(Data)) (* TODO: There is probably a special case when 64bit addresses are converted to 32bit addresses here, which can yield pointers *) + | Bil.Let(_) -> None + | Bil.Unknown(_) -> None + | Bil.Ite(if_, then_, else_) -> begin + match (type_of_exp then_ state project, type_of_exp else_ state project) with + | (Some(value1), Some(value2)) -> if value1 = value2 then Some(value1) else None + | _ -> None + end + | Bil.Extract(_) -> Some(Ok(Data)) (* TODO: Similar to cast: Are there cases of 32bit-64bit-address-conversions here? *) + | Bil.Concat(_) -> Some(Ok(Data)) (* TODO: If alignment of the pointer is known, it could be used like AND and OR *) + +let pointer_size_as_bitvector project = + let psize = Symbol_utils.arch_pointer_size_in_bytes project in + Bitvector.of_int psize ~width:(psize * 8) + +(* If exp is a store to the stack, add the corresponding value to the stack. If the + we cannot determine the value, delete the corresponding data on the stack. *) +let set_stack_elem state exp ~project = + match exp with + | Bil.Store(_, addr_exp, value_exp, endian, size) -> + begin + match (TypeInfo.compute_stack_offset state addr_exp project, type_of_exp value_exp state ~project) with + | (Some(offset), Some(Ok(value))) when Size.in_bytes size = (Symbol_utils.arch_pointer_size_in_bytes project) -> + let stack = Mem_region.add state.TypeInfo.stack value ~pos:offset ~size:(pointer_size_as_bitvector project) in + { state with TypeInfo.stack = stack} + | (Some(offset), Some(Ok(value))) when Size.in_bytes size <> (Symbol_utils.arch_pointer_size_in_bytes project) -> + let stack = Mem_region.add state.TypeInfo.stack Register.Data ~pos:offset ~size:(Bitvector.of_int (Size.in_bytes size) ~width:(Symbol_utils.arch_pointer_size_in_bytes project)) in + { state with TypeInfo.stack = stack} + | (Some(offset), Some(Error(_))) -> + let stack = Mem_region.mark_error state.TypeInfo.stack ~pos:offset ~size:(Bitvector.of_int (Size.in_bytes size) ~width:(Symbol_utils.arch_pointer_size_in_bytes project)) in + { state with TypeInfo.stack = stack} + | (Some(offset), None) -> + let stack = Mem_region.remove state.TypeInfo.stack ~pos:offset ~size:(Bitvector.of_int (Size.in_bytes size) ~width:(Symbol_utils.arch_pointer_size_in_bytes project)) in + { state with TypeInfo.stack = stack} + | _ -> state + end + | _ -> state + +let add_mem_address_registers state exp ~project = + let exp_list = nested_exp_list exp in + List.fold exp_list ~init:state ~f:(fun state exp -> + match exp with + | Bil.Load(_, addr_exp, _, _) + | Bil.Store(_, addr_exp, _, _, _) -> begin + match addr_exp with + | Bil.Var(addr) + | Bil.BinOp(Bil.PLUS, Bil.Var(addr), Bil.Int(_)) + | Bil.BinOp(Bil.MINUS, Bil.Var(addr), Bil.Int(_)) + | Bil.BinOp(Bil.AND, Bil.Var(addr), Bil.Int(_)) + | Bil.BinOp(Bil.OR, Bil.Var(addr), Bil.Int(_)) -> + { state with TypeInfo.reg = Map.set state.TypeInfo.reg addr (Ok(Register.Pointer)) } (* TODO: there are some false positives here for indices in global data arrays, where the immediate is the pointer. Maybe remove all cases with potential false positives? *) + | Bil.BinOp(Bil.PLUS, Bil.Var(addr), exp2) + | Bil.BinOp(Bil.MINUS, Bil.Var(addr), exp2) + | Bil.BinOp(Bil.AND, Bil.Var(addr), exp2) + | Bil.BinOp(Bil.OR, Bil.Var(addr), exp2) -> + if type_of_exp exp2 state project = Some(Ok(Register.Data)) then + { state with TypeInfo.reg = Map.set state.TypeInfo.reg addr (Ok(Register.Pointer)) } + else + state + | _ -> state + end + | _ -> state + ) + + +(* updates the stack offset if a definition changes the stack pointer value. + TODO: Bil.AND, Bil.OR are ignored because we do not track alignment yet. *) +let update_stack_offset state def ~project = + let stack_register = Symbol_utils.stack_register project in + if Def.lhs def = stack_register && Option.is_some state.TypeInfo.stack_offset then + match Def.rhs def with + | Bil.BinOp(Bil.PLUS, Bil.Var(var), Bil.Int(value)) -> + if var = stack_register then + TypeInfo.stack_offset_add state value + else + { state with TypeInfo.stack_offset = None } + | Bil.BinOp(Bil.MINUS, Bil.Var(var), Bil.Int(value)) -> + if var = stack_register then + TypeInfo.stack_offset_add state (Bitvector.neg (Bitvector.signed value)) + else + { state with TypeInfo.stack_offset = None } + | _ -> { state with TypeInfo.stack_offset = None } + else + state + +(* Remove any knowledge of the stack (except the stack_offset) and the registers (except stack and flag registers) from the state. *) +let keep_only_stack_offset state ~project = + let empty_state = TypeInfo.empty() in + { empty_state with + TypeInfo.stack_offset = state.TypeInfo.stack_offset; + TypeInfo.reg = TypeInfo.get_stack_pointer_and_flags project } + +let update_state_def state def ~project = + (* add all registers that are used as address registers in load/store expressions to the state *) + let state = add_mem_address_registers state (Def.rhs def) project in + let state = match type_of_exp (Def.rhs def) state project with + | Some(value) -> + let reg = Map.set state.TypeInfo.reg (Def.lhs def) value in + { state with TypeInfo.reg = reg } + | None -> (* We don't know the type of the new value *) + let reg = Map.remove state.TypeInfo.reg (Def.lhs def) in + { state with TypeInfo.reg = reg } in + (* update stack offset and maybe write something to the stack *) + let state = update_stack_offset state def ~project in + let state = set_stack_elem state (Def.rhs def) ~project in + state + +let update_state_jmp state jmp ~project = + match Jmp.kind jmp with + | Call(call) -> begin match Call.target call with + | Direct(tid) -> + let func_name = match String.lsplit2 (Tid.name tid) ~on:'@' with + | Some(_left, right) -> right + | None -> Tid.name tid in + if String.Set.mem (Cconv.parse_dyn_syms project) func_name then + let empty_state = TypeInfo.empty () in (* TODO: to preserve stack information we need to be sure that the callee does not write on the stack => needs pointer source tracking! *) + { empty_state with + TypeInfo.stack_offset = state.TypeInfo.stack_offset; + TypeInfo.reg = Var.Map.filter_keys state.TypeInfo.reg ~f:(fun var -> Cconv.is_callee_saved var project) } + else + keep_only_stack_offset state project (* TODO: add interprocedural analysis here. *) + | Indirect(_) -> keep_only_stack_offset state project (* TODO: when we have value tracking and interprocedural analysis, we can add indirect calls to the regular analysis. *) + end + | Int(_, _) -> (* TODO: We need stubs and/or interprocedural analysis here *) + keep_only_stack_offset state project + | Goto(Indirect(Bil.Var(var))) (* TODO: warn when jumping to something that is marked as data. *) + | Ret(Indirect(Bil.Var(var))) -> + let reg = Map.set state.TypeInfo.reg var (Ok(Register.Pointer)) in + { state with TypeInfo.reg = reg } + | Goto(_) + | Ret(_) -> state + +(* This is public for unit test purposes. *) +let update_type_info block_elem state ~project = + match block_elem with + | `Def def -> update_state_def state def ~project + | `Phi phi -> state (* We ignore phi terms for this analysis. *) + | `Jmp jmp -> update_state_jmp state jmp ~project + +(** updates a block analysis. *) +let update_block_analysis block register_state ~project = + (* get all elements (Defs, Jumps, Phi-nodes) in the correct order *) + let elements = Blk.elts block in + let register_state = Seq.fold elements ~init:register_state ~f:(fun state element -> + update_type_info element state ~project + ) in + TypeInfo.remove_virtual_registers register_state (* virtual registers should not be accessed outside of the block where they are defined. *) + + +let intraprocedural_fixpoint func ~project = + let cfg = Sub.to_cfg func in + (* default state for nodes *) + let only_sp = { (TypeInfo.empty ()) with TypeInfo.reg = TypeInfo.get_stack_pointer_and_flags project } in + try + (* Create a starting solution where only the first block of a function knows the stack_offset. *) + let fn_start_state = TypeInfo.function_start_state project in + let fn_start_block = Option.value_exn (Term.first blk_t func) in + let fn_start_state = update_block_analysis fn_start_block fn_start_state ~project in + let fn_start_node = Seq.find_exn (Graphs.Ir.nodes cfg) ~f:(fun node -> (Term.tid fn_start_block) = (Term.tid (Graphs.Ir.Node.label node))) in + let empty = Map.empty (module Graphs.Ir.Node) in + let with_start_node = Map.set empty fn_start_node fn_start_state in + let init = Graphlib.Std.Solution.create with_start_node only_sp in + let equal = TypeInfo.equal in + let merge = TypeInfo.merge in + let f = (fun node state -> + let block = Graphs.Ir.Node.label node in + update_block_analysis block state ~project + ) in + Graphlib.Std.Graphlib.fixpoint (module Graphs.Ir) cfg ~steps:100 ~rev:false ~init:init ~equal:equal ~merge:merge ~f:f + with + | _ -> (* An exception will be thrown if the function does not contain any blocks. In this case we can simply return an empty solution. *) + Graphlib.Std.Solution.create (Map.empty (module Graphs.Ir.Node)) only_sp + + +(** Extract the starting state of a node. *) +let extract_start_state node ~cfg ~solution ~project = + let predecessors = Graphs.Ir.Node.preds node cfg in + if Seq.is_empty predecessors then + TypeInfo.function_start_state project (* This should be the first block of a function. Maybe add a test for when there is more than one such block in a function? *) + else + let only_sp = { (TypeInfo.empty ()) with TypeInfo.reg = TypeInfo.get_stack_pointer_and_flags project } in + Seq.fold predecessors ~init:only_sp ~f:(fun state node -> + TypeInfo.merge state (Graphlib.Std.Solution.get solution node) + ) + +(** Returns a list of pairs (tid, state) for each def in a (blk_t-)node. The state +is the state _after execution of the node. *) +let state_list_def node ~cfg ~solution ~project = + let input_state = extract_start_state node ~cfg ~solution ~project in + let block = Graphs.Ir.Node.label node in + let defs = Term.enum def_t block in + let (output, _) = Seq.fold defs ~init:([], input_state) ~f:(fun (list_, state) def -> + let state = update_state_def state def project in + ( (Term.tid def, state) :: list_, state) + ) in + output + + +let compute_pointer_register project = + let program = Project.program project in + let program_with_tags = Term.map sub_t program ~f:(fun func -> + let cfg = Sub.to_cfg func in + let solution = intraprocedural_fixpoint func project in + Seq.fold (Graphs.Ir.nodes cfg) ~init:func ~f:(fun func node -> + let block = Graphs.Ir.Node.label node in + let start_state = extract_start_state node cfg solution project in + let tagged_block = Term.set_attr block type_info_tag start_state in + Term.update blk_t func tagged_block + ) + ) in + Project.with_program project program_with_tags + +(** Prints type info to debug. *) +let print_type_info_to_debug state block_tid ~tid_map = + let register_list = Map.fold state.TypeInfo.reg ~init:[] ~f:(fun ~key:var ~data:reg str_list -> + match reg with + | Ok(Register.Pointer) -> (Var.name var ^ ":Pointer, ") :: str_list + | Ok(Register.Data) -> (Var.name var ^ ":Data, ") :: str_list + | Error(_) -> (Var.name var ^ ":Error, ") :: str_list ) in + let register_string = String.concat register_list in + let stack_offset_str = match state.TypeInfo.stack_offset with + | Some(Ok(x)) -> begin + match Bitvector.to_int (Bitvector.signed x) with + | Ok(number) -> string_of_int number + | _ -> "NaN" + end + | _ -> "Unknown" in + Log_utils.debug + "[%s] {%s} TypeInfo at %s:\nRegister: %s\nStackOffset: %s" + name + version + (Address_translation.translate_tid_to_assembler_address_string block_tid tid_map) + register_string + stack_offset_str + +let print_type_info_tags ~project ~tid_map = + let program = Project.program project in + let functions = Term.enum sub_t program in + Seq.iter functions ~f:(fun func -> + let blocks = Term.enum blk_t func in + Seq.iter blocks ~f:(fun block -> + match Term.get_attr block type_info_tag with + | Some(start_state) -> print_type_info_to_debug start_state (Term.tid block) ~tid_map + | None -> (* block has no type info tag, which should not happen *) + Log_utils.error + "[%s] {%s} Block has no TypeInfo at %s (block TID %s)" + name + version + (Address_translation.translate_tid_to_assembler_address_string (Term.tid block) tid_map) + (Tid.name (Term.tid block)) + ) + ) + +(* Functions made public for unit tests *) +module Test = struct + let update_block_analysis = update_block_analysis +end diff --git a/src/analysis/type_inference.mli b/src/analysis/type_inference.mli new file mode 100644 index 000000000..77bc789dc --- /dev/null +++ b/src/analysis/type_inference.mli @@ -0,0 +1,47 @@ +(* This file contains analysis passes for type recognition *) + +open Bap.Std +open Core_kernel + + +(** The register type. *) +module Register : sig + type t = + | Pointer + | Data + [@@deriving bin_io, compare, sexp] + +end + + +module TypeInfo : sig + type reg_state = (Register.t, unit) Result.t Var.Map.t [@@deriving bin_io, compare, sexp] + type t = { + stack: Register.t Mem_region.t; + stack_offset: (Bitvector.t, unit) Result.t Option.t; + reg: reg_state; + } [@@deriving bin_io, compare, sexp] + + (* Pretty Printer. At the moment, the output is not pretty at all. *) + val pp: Format.formatter -> t -> unit +end + +val type_info_tag: TypeInfo.t Value.tag + +(** Computes TypeInfo for the given project. Adds tags to each block containing the + TypeInfo at the start of the block. *) +val compute_pointer_register: Project.t -> Project.t + +(** Print type info tags. TODO: If this should be used for more than debug purposes, + then the output format should be refactored accordingly. *) +val print_type_info_tags: project:Project.t -> tid_map:word Tid.Map.t -> unit + +(** Updates the type info for a single element (Phi/Def/Jmp) of a block. Input + is the type info before execution of the element, output is the type info + after execution of the element. *) +val update_type_info: Blk.elt -> TypeInfo.t -> project:Project.t -> TypeInfo.t + +(* functions made public for unit tests: *) +module Test : sig + val update_block_analysis: Blk.t -> TypeInfo.t -> project:Project.t -> TypeInfo.t +end diff --git a/src/checkers/cwe_476.ml b/src/checkers/cwe_476.ml index 55ebc3957..23c6a62f1 100644 --- a/src/checkers/cwe_476.ml +++ b/src/checkers/cwe_476.ml @@ -239,13 +239,13 @@ let print_hit tid ~sub ~function_names ~tid_map = match Jmp.kind jmp with | Call(call) -> begin match Call.target call with - | Direct(call_tid) -> Option.is_some (List.find function_names ~f:(fun name -> - if name = (Tid.name call_tid) then begin + | Direct(call_tid) -> Option.is_some (List.find function_names ~f:(fun fn_name -> + if fn_name = (Tid.name call_tid) then begin Log_utils.warn "[%s] {%s} (NULL Pointer Dereference) There is no check if the return value is NULL at %s (%s)." name version (Address_translation.translate_tid_to_assembler_address_string tid tid_map) - name; + fn_name; true end else false @@ -273,7 +273,7 @@ let check_cwe prog proj tid_map symbol_names parameters = Seq.iter subfunctions ~f:(fun subfn -> let cfg = Sub.to_cfg subfn in let cwe_hits = ref [] in - let empty = Map.empty Graphs.Ir.Node.comparator in + let empty = Map.empty (module Graphs.Ir.Node) in let init = Graphlib.Std.Solution.create empty [] in let equal = State.equal in let merge = State.union in diff --git a/src/cwe_checker_core.opam b/src/cwe_checker_core.opam new file mode 100644 index 000000000..e69de29bb diff --git a/src/dune b/src/dune new file mode 100644 index 000000000..033894ae1 --- /dev/null +++ b/src/dune @@ -0,0 +1,12 @@ +(library + (name cwe_checker_core) + (public_name cwe_checker_core) + (libraries + yojson + bap + core_kernel + core) + (preprocess (pps ppx_jane)) +) + +(include_subdirs unqualified) ; Include all subdirs when looking for source files diff --git a/src/dune-project b/src/dune-project new file mode 100644 index 000000000..a26d6e273 --- /dev/null +++ b/src/dune-project @@ -0,0 +1 @@ +(lang dune 1.6) diff --git a/src/utils/address_translation.ml b/src/utils/address_translation.ml index 097cf30bd..83dabca02 100644 --- a/src/utils/address_translation.ml +++ b/src/utils/address_translation.ml @@ -11,5 +11,5 @@ let generate_tid_map prog = inherit [addr Tid.Map.t] Term.visitor method enter_term _ t addrs = match Term.get_attr t address with | None -> addrs - | Some addr -> Map.add addrs ~key:(Term.tid t) ~data:addr + | Some addr -> Map.add_exn addrs ~key:(Term.tid t) ~data:addr end)#run prog Tid.Map.empty diff --git a/src/utils/cconv.ml b/src/utils/cconv.ml new file mode 100644 index 000000000..d7f95aa6e --- /dev/null +++ b/src/utils/cconv.ml @@ -0,0 +1,84 @@ + +open Bap.Std +open Core_kernel + +let dyn_syms = ref None + +let callee_saved_registers = ref None + +(** Return a list of registers that are callee-saved. + TODO: At least ARMv7 and PPC have floating point registers that are callee saved. Check their names in bap and then add them. *) +let callee_saved_register_list project = + let arch = Project.arch project in + match arch with + | `x86_64 -> (* System V ABI *) + "RBX" :: "RSP" :: "RBP" :: "R12" :: "R13" :: "R14" :: "R15" :: [] + | `x86_64 -> (* Microsoft x64 calling convention *) (* TODO: How to distinguish from System V? For the time being, only use the System V ABI, since it saves less registers. *) + "RBX" :: "RBP" :: "RDI" :: "RSI" :: "RSP" :: "R12" :: "R13" :: "R14" :: "R15" :: [] + | `x86 -> (* Both Windows and Linux save the same registers *) + "EBX" :: "ESI" :: "EDI" :: "EBP" :: [] + | `armv4 | `armv5 | `armv6 | `armv7 + | `armv4eb | `armv5eb | `armv6eb | `armv7eb + | `thumbv4 | `thumbv5 | `thumbv6 | `thumbv7 + | `thumbv4eb | `thumbv5eb | `thumbv6eb | `thumbv7eb -> (* ARM 32bit. R13 and SP are both names for the stack pointer. *) + "R4" :: "R5" :: "R6" :: "R7" :: "R8" :: "R9" :: "R10" :: "R11" :: "R13" :: "SP" :: [] + | `aarch64 | `aarch64_be -> (* ARM 64bit *) (* TODO: This architecture is not contained in the acceptance tests yet? *) + "X19" :: "X20" :: "X21" :: "X22" :: "X23" :: "X24" :: "X25" :: "X26" :: "X27" :: "X28" :: "X29" :: "SP" :: [] + | `ppc (* 32bit PowerPC *) (* TODO: add floating point registers. *) (* TODO: add CR2, CR3, CR4. Test their representation in bap first. *) + | `ppc64 | `ppc64le -> (* 64bit PowerPC *) + "R14" :: "R15" :: "R16" :: "R17" :: "R18" :: "R19" :: "R20" :: "R21" :: "R22" :: "R23" :: + "R24" :: "R25" :: "R26" :: "R27" :: "R28" :: "R29" :: "R30" :: "R31" :: "R1" :: "R2" :: [] + | `mips | `mips64 | `mips64el | `mipsel -> (* S8 and FP are the same register. bap uses FP, S8 is left there just in case. *) + "S0" :: "S1" :: "S2" :: "S3" :: "S4" :: "S5" :: "S6" :: "S7" :: "S8" :: "GP" :: "SP" :: "FP" :: [] + | _ -> failwith "No calling convention implemented for the given architecture." + +let is_callee_saved var project = + match !callee_saved_registers with + | Some(register_set) -> String.Set.mem register_set (Var.name var) + | None -> + callee_saved_registers := Some(String.Set.of_list (callee_saved_register_list project)); + String.Set.mem (Option.value_exn !callee_saved_registers) (Var.name var) + +(** Parse a line from the dyn-syms output table of readelf. Return the name of a symbol if the symbol is an extern function name. *) +let parse_dyn_sym_line line = + let line = ref (String.strip line) in + let str_list = ref [] in + while Option.is_some (String.rsplit2 !line ~on:' ') do + let (left, right) = Option.value_exn (String.rsplit2 !line ~on:' ') in + line := String.strip left; + str_list := right :: !str_list; + done; + match !str_list with + | _ :: value :: _ :: "FUNC" :: _ :: _ :: _ :: name :: [] -> begin + match ( String.strip ~drop:(fun x -> x = '0') value, String.lsplit2 name ~on:'@') with + | ("", Some(left, _)) -> Some(left) + | ("", None) -> Some(name) + | _ -> None (* The symbol has a nonzero value, so we assume that it is not an extern function symbol. *) + end + | _ -> None + +let parse_dyn_syms project = + match !dyn_syms with + | Some(symbol_set) -> symbol_set + | None -> + match Project.get project filename with + | None -> failwith "[CWE-checker] Project has no file name." + | Some(fname) -> begin + let cmd = Format.sprintf "readelf --dyn-syms %s" fname in + try + let in_chan = Unix.open_process_in cmd in + let lines = In_channel.input_lines in_chan in + let () = In_channel.close in_chan in begin + match lines with + | _ :: _ :: _ :: tail -> (* The first three lines are not part of the table *) + let symbol_set = String.Set.of_list (List.filter_map tail ~f:parse_dyn_sym_line) in + dyn_syms := Some(symbol_set); + symbol_set + | _ -> + dyn_syms := Some(String.Set.empty); + String.Set.empty (* *) + end + with + Unix.Unix_error (e,fm,argm) -> + failwith (Format.sprintf "[CWE-checker] Parsing of dynamic symbols failed: %s %s %s" (Unix.error_message e) fm argm) + end diff --git a/src/utils/cconv.mli b/src/utils/cconv.mli new file mode 100644 index 000000000..ce24c31d1 --- /dev/null +++ b/src/utils/cconv.mli @@ -0,0 +1,13 @@ +open Bap.Std +open Core_kernel + +(** Returns whether a variable is callee saved according to the calling convention + of the target architecture. Should only used for calls to functions outside + of the program, not for calls between functions inside the program. *) +val is_callee_saved: Var.t -> Project.t -> bool + + +(** Returns a list of those function names that are extern symbols. + TODO: Since we do not do name demangling here, check whether bap name demangling + yields different function names for the symbols. *) +val parse_dyn_syms: Project.t -> String.Set.t diff --git a/src/utils/symbol_utils.ml b/src/utils/symbol_utils.ml index bdc27a199..1e5b10432 100644 --- a/src/utils/symbol_utils.ml +++ b/src/utils/symbol_utils.ml @@ -147,3 +147,17 @@ let get_program_entry_points program = let entry_points = Seq.filter subfunctions ~f:(fun subfn -> Term.has_attr subfn Sub.entry_point) in let main_fn = Seq.filter subfunctions ~f:(fun subfn -> "@main" = Tid.name (Term.tid subfn)) in Seq.append main_fn entry_points + +let stack_register project = + let arch = Project.arch project in + let module Target = (val target_of_arch arch) in + Target.CPU.sp + +let flag_register_list project = + let arch = Project.arch project in + let module Target = (val target_of_arch arch) in + Target.CPU.zf :: Target.CPU.cf :: Target.CPU.vf :: Target.CPU.nf :: [] + +let arch_pointer_size_in_bytes project : int = + let arch = Project.arch project in + Size.in_bytes (Arch.addr_size arch) diff --git a/src/utils/symbol_utils.mli b/src/utils/symbol_utils.mli index 80ed2791f..b806b7138 100644 --- a/src/utils/symbol_utils.mli +++ b/src/utils/symbol_utils.mli @@ -68,3 +68,15 @@ val extract_direct_call_tid_from_block : Bap.Std.blk Bap.Std.term -> Bap.Std.tid TODO: The _start entry point usually calls a libc-function which then calls the main function. Since right now only direct calls are tracked, our graph traversal may never find the main function. For now, we add it by hand to the entry points. *) val get_program_entry_points : Bap.Std.program Bap.Std.term -> Bap.Std.sub Bap.Std.term Bap.Std.Seq.t + +(** Returns the stack register on the architecture of the given project. *) +val stack_register: Bap.Std.Project.t -> Bap.Std.Var.t + +(** Returns a list of the known flag registers on the architecture of the given project. + TODO: Right now it only returns flag registers that exist on all architectures. + We should add known architecture dependend flag registers, too. *) +val flag_register_list: Bap.Std.Project.t -> Bap.Std.Var.t list + + +(** Returns the pointer size in bytes on the architecture of the given project. *) +val arch_pointer_size_in_bytes: Bap.Std.Project.t -> int diff --git a/test/unit/Makefile b/test/unit/Makefile new file mode 100644 index 000000000..c31dcf7b4 --- /dev/null +++ b/test/unit/Makefile @@ -0,0 +1,9 @@ +all: + bapbundle remove cwe_checker_unit_tests.plugin + bapbuild -r -Is analysis cwe_checker_unit_tests.plugin -pkgs core,alcotest,yojson,unix,ppx_jane,cwe_checker_core + bapbundle install cwe_checker_unit_tests.plugin + bap ../artificial_samples/build/arrays_x64.out --pass=cwe-checker-unit-tests + bapbundle remove cwe_checker_unit_tests.plugin + +clean: + bapbuild -clean diff --git a/test/unit/analysis/mem_region_test.ml b/test/unit/analysis/mem_region_test.ml new file mode 100644 index 000000000..f57160f04 --- /dev/null +++ b/test/unit/analysis/mem_region_test.ml @@ -0,0 +1,98 @@ +open Bap.Std +open Core_kernel +open Cwe_checker_core + +let check msg x = Alcotest.(check bool) msg true x + +let test_add () : unit = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "Five" ~pos:(bv 3) ~size:(bv 5) in + let x = Mem_region.add x "Seven" ~pos:(bv 9) ~size:(bv 7) in + let x = Mem_region.add x "Three" ~pos:(bv 0) ~size:(bv 3) in + check "add_ok" (Some(Ok("Five", bv 5)) = (Mem_region.get x (bv 3))); + check "add_err" (Some(Error(())) = (Mem_region.get x (bv 1))); + check "add_none" (None = (Mem_region.get x (bv 8))) + +let test_minus () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv (-8)) ~size:(bv 8) in + check "negative_index" (Some(Ok("One", bv 8)) = Mem_region.get x (Bitvector.unsigned (bv (-8)))) + +let test_remove () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 11) in + let x = Mem_region.remove x ~pos:(bv 5) ~size:(bv 20) in + check "remove_error_before" (Some(Error()) = Mem_region.get x (bv 4)); + check "remove_none1" (None = Mem_region.get x (bv 5)); + check "remove_none2" (None = Mem_region.get x (bv 24)); + check "remove_error_after1" (Some(Error()) = Mem_region.get x (bv 25)); + check "remove_error_after2" (None = Mem_region.get x (bv 26)) + +let test_mark_error () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.mark_error x ~pos:(bv 5) ~size:(bv 10) in + check "mark_error1" (Some(Error()) = Mem_region.get x (bv 0)); + check "mark_error2" (Some(Error()) = Mem_region.get x (bv 14)); + check "mark_error3" (None = Mem_region.get x (bv 15)) + +let test_merge () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 5) in + let x = Mem_region.add x "Three" ~pos:(bv 25) ~size:(bv 5) in + let y = Mem_region.empty () in + let y = Mem_region.add y "One" ~pos:(bv 1) ~size:(bv 10) in + let y = Mem_region.add y "Two" ~pos:(bv 15) ~size:(bv 5) in + let y = Mem_region.add y "Four" ~pos:(bv 25) ~size:(bv 5) in + let merge_fn a b = if a = b then Some(Ok(a)) else Some(Error()) in + let z = Mem_region.merge x y ~data_merge:merge_fn in + check "merge_intersect" (Some(Error()) = Mem_region.get z (bv 0)); + check "merge_match_ok" (Some(Ok("Two", bv 5)) = Mem_region.get z (bv 15)); + check "merge_match_error" (Some(Error()) = Mem_region.get z (bv 25)) + +let test_equal () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv 15) ~size:(bv 5) in + let y = Mem_region.empty () in + let y = Mem_region.add y "Two" ~pos:(bv 15) ~size:(bv 5) in + check "equal_no" (false = (Mem_region.equal x y ~data_equal:(fun x y -> x = y))); + let y = Mem_region.add y "One" ~pos:(bv 0) ~size:(bv 10) in + check "equal_yes" (Mem_region.equal x y ~data_equal:(fun x y -> x = y)) + +let test_around_zero () = + let bv num = Bitvector.of_int num ~width:32 in + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv (-5)) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv 0) ~size:(bv 10) in + check "around_zero1" (Some(Error()) = Mem_region.get x (bv (-5))); + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv (-5)) ~size:(bv 10) in + check "around_zero2" (Some(Error()) = Mem_region.get x (bv 0)); + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv (-5)) ~size:(bv 20) in + let x = Mem_region.add x "Two" ~pos:(bv 0) ~size:(bv 10) in + check "around_zero3" (Some(Error()) = Mem_region.get x (bv (-5))); + let x = Mem_region.empty () in + let x = Mem_region.add x "One" ~pos:(bv 0) ~size:(bv 10) in + let x = Mem_region.add x "Two" ~pos:(bv (-5)) ~size:(bv 20) in + check "around_zero2" (Some(Error()) = Mem_region.get x (bv 0)) + +let tests = [ + "Add", `Quick, test_add; + "Negative Indices", `Quick, test_minus; + "Remove", `Quick, test_remove; + "Mark_error", `Quick, test_mark_error; + "Merge", `Quick, test_merge; + "Equal", `Quick, test_equal; + "Around Zero", `Quick, test_around_zero; +] diff --git a/test/unit/analysis/mem_region_test.mli b/test/unit/analysis/mem_region_test.mli new file mode 100644 index 000000000..dca4bab74 --- /dev/null +++ b/test/unit/analysis/mem_region_test.mli @@ -0,0 +1,4 @@ +open Bap.Std +open Core_kernel + +val tests: unit Alcotest.test_case list diff --git a/test/unit/analysis/type_inference_test.ml b/test/unit/analysis/type_inference_test.ml new file mode 100644 index 000000000..76a33535c --- /dev/null +++ b/test/unit/analysis/type_inference_test.ml @@ -0,0 +1,81 @@ +open Bap.Std +open Core_kernel +open Cwe_checker_core + +open Type_inference +open Type_inference.Test + +let check msg x = Alcotest.(check bool) msg true x + +let example_project = ref None + +(* TODO: As soon as more pointers than stack pointer are tracked, add more tests! *) + +let create_block_from_defs def_list = + let block = Blk.Builder.create () in + let () = List.iter def_list ~f:(fun def -> Blk.Builder.add_def block def) in + Blk.Builder.result block + +let start_state stack_register project = + let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in + let start_reg = Var.Map.empty in + let start_reg = Map.add_exn start_reg ~key:stack_register ~data:(Ok(Register.Pointer)) in + { TypeInfo.stack = Mem_region.empty (); + TypeInfo.stack_offset = Some (Ok(bv 0)); + TypeInfo.reg = start_reg; + } + +let test_update_stack_offset () = + let project = Option.value_exn !example_project in + let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in + let stack_register = Symbol_utils.stack_register project in + let fn_start_state = start_state stack_register project in + let def1 = Def.create stack_register (Bil.binop Bil.plus (Bil.var stack_register) (Bil.int (bv 8))) in + let def2 = Def.create stack_register (Bil.binop Bil.minus (Bil.var stack_register) (Bil.int (bv 16))) in + let block = create_block_from_defs [def1; def2] in + let state = update_block_analysis block fn_start_state project in + let () = check "update_stack_offset" (state.TypeInfo.stack_offset = Some(Ok(Bitvector.unsigned (bv (-8))))) in + () + +let test_update_reg () = + let project = Option.value_exn !example_project in + let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in + let stack_register = Symbol_utils.stack_register project in + let fn_start_state = start_state stack_register project in + let register1 = Var.create "Register1" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + let register2 = Var.create "Register2" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + let def1 = Def.create register1 (Bil.binop Bil.AND (Bil.var stack_register) (Bil.int (bv 8))) in + let def2 = Def.create register2 (Bil.binop Bil.XOR (Bil.var register1) (Bil.var stack_register)) in + let block = create_block_from_defs [def1; def2] in + let state = update_block_analysis block fn_start_state project in + let () = check "update_pointer_register" (Var.Map.find state.TypeInfo.reg register1 = Some(Ok(Pointer))) in + let () = check "update_data_register" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Data))) in + let def1 = Def.create register1 (Bil.Load (Bil.var register1, Bil.var register2, Bitvector.LittleEndian, `r64) ) in + let block = create_block_from_defs [def1;] in + let state = update_block_analysis block fn_start_state project in + let () = check "add_mem_address_registers" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Pointer))) in + () + +let test_update_stack () = + let project = Option.value_exn !example_project in + let bv x = Bitvector.of_int x ~width:(Symbol_utils.arch_pointer_size_in_bytes project * 8) in + let stack_register = Symbol_utils.stack_register project in + let fn_start_state = start_state stack_register project in + let register1 = Var.create "Register1" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + let register2 = Var.create "Register2" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + let mem_reg = Var.create "Mem_reg" (Bil.Imm (Symbol_utils.arch_pointer_size_in_bytes project * 8)) in + let def1 = Def.create register1 (Bil.binop Bil.AND (Bil.var stack_register) (Bil.int (bv 8))) in + let def2 = Def.create mem_reg (Bil.Store ((Bil.var mem_reg), (Bil.binop Bil.PLUS (Bil.var stack_register) (Bil.int (bv (-8)))), (Bil.var stack_register), Bitvector.LittleEndian, `r64)) in + let def3 = Def.create register2 (Bil.Load (Bil.var register2, (Bil.binop Bil.MINUS (Bil.var stack_register) (Bil.int (bv 8))), Bitvector.LittleEndian, `r64) ) in + let block = create_block_from_defs [def1; def2; def3;] in + let state = update_block_analysis block fn_start_state project in + let () = check "write_to_stack" ((Mem_region.get state.TypeInfo.stack (bv (-8))) = Some(Ok(Pointer, bv (Symbol_utils.arch_pointer_size_in_bytes project)))) in + let () = check "load_from_stack" (Var.Map.find state.TypeInfo.reg register2 = Some(Ok(Pointer))) in + () + + +let tests = [ + "Update Stack Offset", `Quick, test_update_stack_offset; + "Update Register", `Quick, test_update_reg; + "Update Stack", `Quick, test_update_stack; +] diff --git a/test/unit/analysis/type_inference_test.mli b/test/unit/analysis/type_inference_test.mli new file mode 100644 index 000000000..5e5173a93 --- /dev/null +++ b/test/unit/analysis/type_inference_test.mli @@ -0,0 +1,7 @@ +open Bap.Std +open Core_kernel + + +val example_project: Project.t option ref + +val tests: unit Alcotest.test_case list diff --git a/test/unit/cwe_checker_unit_tests.ml b/test/unit/cwe_checker_unit_tests.ml new file mode 100644 index 000000000..d16989d2b --- /dev/null +++ b/test/unit/cwe_checker_unit_tests.ml @@ -0,0 +1,21 @@ +open Bap.Std +open Core_kernel +open Cwe_checker_core + +let run_tests project = + Type_inference_test.example_project := Some(project); + Alcotest.run "Unit tests" ~argv:[|"DoNotComplainWhenRunAsABapPlugin";"--color=always";|] [ + "Mem_region_tests", Mem_region_test.tests; + "Type_inference_tests", Type_inference_test.tests; + ] + +let () = + (* Check whether this file is run as an executable (via dune runtest) or + as a bap plugin *) + if Sys.argv.(0) = "bap" then + (* The file was run as a bap plugin. *) + Project.register_pass' run_tests + else + (* The file was run as a standalone executable. Use make to build and run the unit test plugin *) + let () = Sys.chdir (Sys.getenv "PWD" ^ "/test/unit") in + exit (Sys.command "make all") diff --git a/test/unit/dune b/test/unit/dune new file mode 100644 index 000000000..14f6fc8a5 --- /dev/null +++ b/test/unit/dune @@ -0,0 +1,18 @@ +(executable + (name cwe_checker_unit_tests) + (libraries + alcotest + yojson + bap + cwe_checker_core + core_kernel + core) + (preprocess (pps ppx_jane)) +) + +(include_subdirs unqualified) ; Include all subdirs when looking for source files + +(alias + (name runtest) + (deps cwe_checker_unit_tests.exe) +(action (run %{deps} --color=always)))