1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
(* SPDX-License-Identifier: AGPL-3.0-or-later *)
(* Copyright © 2021-2024 OCamlPro *)
(* Written by the Owi programmers *)

(* Multicore is based on several layers of monad transformers. The module as a whole provides a monad to explore different possibilities in parallel, with a notion of priority. *)

(* Add a notion of fallibility to the evaluation, in "transformer without
   module functor" style: a computation runs against a [Thread.t] state (via
   [State_monad]) and produces either [Ok] a value or [Error] a [Bug.t]. *)
type 'a t = (('a, Bug.t) result, Thread.t) State_monad.t

(* ================================================
   Functions to operate on the three monads layers.
   ================================================ *)

(* [return x] succeeds with [x] and leaves the thread untouched. *)
let[@inline] return x : _ t = State_monad.return (Ok x)

(* [lift m] turns an infallible state computation into a successful one. *)
let[@inline] lift x = State_monad.( let+ ) x (fun v -> Ok v)

(* [bind mx f] runs [mx]. On [Ok x] it continues with [f x]. On [Error _] it
   short-circuits and propagates the error unchanged. *)
let[@inline] bind (mx : _ t) f : _ t =
  State_monad.( let* ) mx (function
    | Ok x -> f x
    | Error _ as err -> State_monad.return err )

let[@inline] ( let* ) m k = bind m k

(* [map mx f] applies [f] to a successful result; errors pass through. *)
let[@inline] map mx f =
  State_monad.( let+ ) mx (function Ok x -> Ok (f x) | Error _ as err -> err)

let[@inline] ( let+ ) m k = map m k

(* Embed a scheduler operation into the monad as a successful computation. *)
let[@inline] lift_schedulable (v : 'a Scheduler.Schedulable.t) : 'a t =
  lift (State_monad.lift v)

(* [with_thread f] reads the current thread, keeps it unchanged and succeeds
   with [f thread]. *)
let[@inline] with_thread (f : Thread.t -> 'a) : 'a t =
  lift (State_monad.with_state (fun current -> (f current, current)))

(* The current thread. *)
let thread = with_thread (fun current -> current)

(* [modify_thread f] updates the current thread with [f] through
   [State_monad.modify_state]. *)
let[@inline] modify_thread f = lift (State_monad.modify_state f)

(* [set_thread th] replaces the current thread with [th]. *)
let[@inline] set_thread th = modify_thread (fun _ -> th)

(* Solver configuration. It must hold [Some _] before [solver] is first
   called on any domain, otherwise the key's initialiser fails. *)
let solver_to_use = ref None

(* One solver per domain, created from [!solver_to_use] by the DLS
   initialiser the first time the key is read on that domain. *)
let solver_dls_key =
  Domain.DLS.new_key (fun () ->
    match !solver_to_use with
    | Some config -> Solver.fresh config ()
    | None -> assert false )

(* The solver owned by the calling domain. *)
let[@inline] solver () = Domain.DLS.get solver_dls_key

(* [choose a b] creates two new branches running [a] and [b]. Neither branch
   yields by itself, so callers must insert a [yield] manually where one is
   needed. *)
let[@inline] choose a b = State_monad.liftF2 Scheduler.Schedulable.choose a b

(* [yield prio] yields the current branch, i.e. puts it back into the work
   queue (tagged with [prio]) so that it gets executed later. *)
let yield prio = Scheduler.Schedulable.yield prio |> lift_schedulable

(* [fork ~parent ~child:(prio, c)] splits execution. [parent] continues
   directly without yielding. [c] runs in a new branch that first yields
   with [prio]. *)
let[@inline] fork ~(parent : 'a t) ~child : 'a t =
  let prio, body = child in
  let deferred_child =
    let* () = yield prio in
    body
  in
  choose parent deferred_child

(* Stop the current branch. Used below whenever a path turns out to be
   infeasible. *)
let stop = lift_schedulable Scheduler.Schedulable.stop

(* ============================================
   Now this is actual symbolic execution stuff!
   ============================================ *)

(* [add_pc c] adds [c] to the current thread's path condition after
   simplification. If [c] simplifies to [true], nothing is added. If it
   simplifies to [false], the branch is stopped. *)
let add_pc (c : Symbolic_boolean.t) =
  let c = Smtml.Typed.simplify c in
  match Smtml.Typed.view c with
  | Val False -> stop
  | Val True -> return ()
  | _ -> modify_thread (fun current -> Thread.add_pc current c)
[@@inline]

(* [get_pc ()] returns the current path condition as one set: the union of
   every slice produced by [Symbolic_path_condition.slice]. *)
let get_pc () =
  let+ current = thread in
  current.pc
  |> Symbolic_path_condition.slice
  |> List.fold_left Smtml.Expr.Set.union Smtml.Expr.Set.empty

(* Thin wrappers that apply the matching [Thread] update to the current
   thread. *)
let add_breadcrumb crumb =
  modify_thread (fun current -> Thread.add_breadcrumb current crumb)

let add_label label =
  modify_thread (fun current -> Thread.add_label current label)

let open_scope scope =
  modify_thread (fun current -> Thread.open_scope current scope)

let close_scope () = modify_thread (fun current -> Thread.close_scope current)

(* [with_new_invisible_symbol ty f] bumps the thread's symbol counter and
   passes [f] a fresh symbol of type [ty] named [symbol_invisible_<n>]. Unlike
   [with_new_symbol], the symbol is not registered with [Thread.add_symbol]. *)
let with_new_invisible_symbol ty f =
  let* current = thread in
  let index = current.num_symbols in
  let+ () = modify_thread Thread.incr_num_symbols in
  f (Fmt.kstr (Smtml.Symbol.make ty) "symbol_invisible_%i" index)

(* [with_new_symbol ty f] creates a fresh symbol of type [ty] named
   [symbol_<n>]. It records the symbol in the thread, bumps the symbol
   counter, and passes the symbol to [f]. *)
let with_new_symbol ty f =
  let* current = thread in
  let sym = Fmt.kstr (Smtml.Symbol.make ty) "symbol_%d" current.num_symbols in
  let register th = Thread.incr_num_symbols (Thread.add_symbol th sym) in
  let+ () = modify_thread register in
  f sym

(* [check_reachability v] asks this domain's solver whether the slice of the
   current path condition relevant to [v] is satisfiable. The call is timed
   into the thread's [solver_sat_time]. The solver's answer is returned
   unchanged. *)
let check_reachability v =
  let+ current = thread in
  let solver = solver () in
  let pc = Symbolic_path_condition.slice_on_condition v current.pc in
  Benchmark.handle_time_span current.bench_stats.solver_sat_time (fun () ->
    Solver.check solver pc )

(* [get_model_or_stop symbol] asks the solver for a model of the part of the
   path condition that mentions [symbol], and returns the value the model
   gives to [symbol]. If no model is obtained, the branch stops. *)
let get_model_or_stop symbol =
  let* current = thread in
  let solver = solver () in
  let set = Symbolic_path_condition.slice_on_symbol symbol current.pc in
  let symbol_scopes = Symbol_scope.of_symbol symbol in
  let outcome =
    Benchmark.handle_time_span
      current.bench_stats.solver_intermediate_model_time (fun () ->
      Solver.model_of_set solver ~symbol_scopes ~set )
  in
  match outcome with
  | `Model model -> (
    match Smtml.Model.evaluate model symbol with
    | Some value -> return value
    | None ->
      (* a model exists, so the symbol must evaluate *)
      assert false )
  | `Unsat | `Unknown ->
    (* `Unknown can happen when the solver is interrupted.
       TODO: once https://github.com/formalsec/smtml/pull/479 is merged
       if solver was interrupted then stop else assert false *)
    stop

(* [select_inner ~with_breadcrumbs cond ~instr_counter_true
   ~instr_counter_false] decides the boolean value of [cond].
   - If [cond] simplifies to a constant, that constant is returned without
     branching.
   - Otherwise execution is split with [choose] into a [true] branch, where
     [cond] is added to the path condition, and a [false] branch, where
     [not cond] is added. A branch whose reachability check answers `Unsat
     or `Unknown is stopped.
   When [with_breadcrumbs] is set, each branch records breadcrumb [1] (true)
   or [0] (false). [instr_counter_true] and [instr_counter_false] override
   the instruction counter used in each branch's priority. When they are
   [None], the current thread's counter is reused. *)
let select_inner ~with_breadcrumbs (cond : Symbolic_boolean.t)
  ~instr_counter_true ~instr_counter_false =
  let cond = Smtml.Typed.simplify cond in
  match Smtml.Typed.view cond with
  | Val True -> return true
  | Val False -> return false
  | _ ->
    (* Shared by both branches. It is set once one branch has been found
       unsat, so the other branch can skip its own solver call. *)
    let is_other_branch_unsat = Atomic.make false in
    (* [branch condition final_value priority] builds one side of the split.
       It assumes [condition], optionally records a breadcrumb, and returns
       [final_value] if the side is reachable. *)
    let branch condition final_value priority =
      let* () = add_pc condition in
      let* () =
        if with_breadcrumbs then add_breadcrumb (if final_value then 1 else 0)
        else return ()
      in
      (* This optimisation relies on the PC always being SAT (i.e. we perform
         eager pruning). If one branch is unsat, the condition of the other
         branch (its negation) is necessarily SAT, so we don't have to check
         its reachability. *)
      if Atomic.get is_other_branch_unsat then begin
        Log.debug (fun m ->
          m "The SMT call for the %b branch was optimized away" final_value );
        (* the other branch is unsat, so this one must be SAT: no reachability check needed *)
        return final_value
      end
      else begin
        (* the other branch is SAT (or has not been checked yet), so we have to check reachability *)
        let* () = yield priority in
        let* satisfiability = check_reachability condition in
        begin match satisfiability with
        | `Sat ->
          let* () = modify_thread (Thread.set_priority priority) in
          return final_value
        | `Unsat ->
          (* let the sibling branch skip its own check *)
          Atomic.set is_other_branch_unsat true;
          stop
        | `Unknown ->
          (* It can happen when the solver is interrupted *)
          (* TODO: once https://github.com/formalsec/smtml/pull/479 is merged
                                   if solver was interrupted then stop else assert false *)
          stop
        end
      end
    in

    let* thread in

    (* both priorities reuse the current depth and have no distance to an
       unreachable point *)
    let prio_true =
      let instr_counter =
        match instr_counter_true with
        | None -> thread.priority.instr_counter
        | Some instr_counter -> instr_counter
      in
      Prio.v ~instr_counter ~distance_to_unreachable:None ~depth:thread.depth
    in
    let true_branch = branch cond true prio_true in

    let prio_false =
      let instr_counter =
        match instr_counter_false with
        | None -> thread.priority.instr_counter
        | Some instr_counter -> instr_counter
      in
      Prio.v ~instr_counter ~distance_to_unreachable:None ~depth:thread.depth
    in
    let false_branch = branch (Symbolic_boolean.not cond) false prio_false in
    (* one call to Thread.incr_path_count per split *)
    Thread.incr_path_count thread;

    choose true_branch false_branch
[@@inline]

(* [select] is [select_inner] with breadcrumbs enabled: the chosen branch
   records [1] for [true] and [0] for [false]. *)
let select (cond : Symbolic_boolean.t) ~instr_counter_true ~instr_counter_false
    =
  select_inner ~with_breadcrumbs:true ~instr_counter_true ~instr_counter_false
    cond
[@@inline]

(* [summary_symbol e] returns a symbol standing for [e]. If [e] already is a
   symbol, it is reused and no constraint is produced. Otherwise a fresh
   [choice_i32_<n>] symbol [s] is created, and the constraint [s = e] is
   returned with it for the caller to assume. *)
let summary_symbol (e : Smtml.Typed.Bitv32.t) :
  (Smtml.Typed.Bool.t option * Smtml.Symbol.t) t =
  let* current = thread in
  match Smtml.Typed.view e with
  | Symbol sym -> return (None, sym)
  | _ ->
    let index = current.num_symbols in
    let+ () = modify_thread Thread.incr_num_symbols in
    (* TODO: having to build two times the symbol this way is not really elegant... *)
    let sym =
      Fmt.str "choice_i32_%i" index
      |> Smtml.Symbol.make Smtml.Typed.Types.(to_ty bitv32)
    in
    let constraint_ = Smtml.Typed.Bitv32.(eq (symbol sym) e) in
    (Some constraint_, sym)

(* [select_i32 i] concretises the 32-bit value [i]. A constant is returned
   directly. Otherwise the solver picks a feasible value [n]. The current
   branch continues with [n] (assuming [i = n]), and a forked child branch
   assumes [i <> n] and enumerates the remaining values the same way. *)
let select_i32 (i : Symbolic_i32.t) : int32 t =
  match Smtml.Typed.view i with
  | Val (Bitv bv) when Smtml.Bitvector.numbits bv <= 32 ->
    return (Smtml.Bitvector.to_int32 bv)
  | _ ->
    let* assign, symbol = summary_symbol i in
    let* () = Option.fold ~none:(return ()) ~some:add_pc assign in
    (* the model must give a bitvector of at most 32 bits *)
    let decode = function
      | Smtml.Value.Bitv bv ->
        assert (Smtml.Bitvector.numbits bv <= 32);
        Smtml.Bitvector.to_int32 bv
      | _ -> assert false
    in
    let rec enumerate () =
      let* model_value = get_model_or_stop symbol in
      let n = decode model_value in
      (* TODO: everything which follows looks like select_inner and could probably be simplified by calling it directly! *)
      let is_n =
        Symbolic_i32.eq_concrete (Smtml.Typed.Bitv32.symbol symbol) n
      in
      let take_n =
        let* () = add_breadcrumb (Int32.to_int n) in
        let* () = add_pc is_n in
        return n
      in
      let skip_n =
        let* () = add_pc (Symbolic_boolean.not is_n) in
        enumerate ()
      in
      let* current = thread in
      Thread.incr_path_count current;
      (* TODO: better prio here? *)
      let prio =
        Prio.v ~instr_counter:current.priority.instr_counter
          ~distance_to_unreachable:None ~depth:current.depth
      in
      fork ~parent:take_n ~child:(prio, skip_n)
    in
    enumerate ()

(* [bug kind] ends the current branch with an [Error] that carries [kind], a
   model of the whole path condition, and the current thread. The model
   search is timed into [solver_final_model_time]. If no model is found, the
   branch is stopped instead. *)
let bug kind =
  let* current = thread in
  let find_model () =
    let solver = solver () in
    let path_condition = current.pc in
    match Solver.model_of_path_condition solver ~path_condition with
    | Some model -> return model
    | None ->
      (* It can happen when the solver is interrupted *)
      (* TODO: once https://github.com/formalsec/smtml/pull/479 is merged
             if solver was interrupted then stop else assert false *)
      stop
  in
  let* model =
    Benchmark.handle_time_span current.bench_stats.solver_final_model_time
      find_model
  in
  State_monad.return (Error { Bug.kind; model; thread = current })

(* Report a trap as a bug. *)
let trap reason = bug (`Trap reason)

(* [assertion c] explores both outcomes of [c] without recording breadcrumbs.
   The branch where [c] fails reports an [`Assertion] bug. *)
let assertion (c : Symbolic_boolean.t) =
  (* TODO: better prio here ? *)
  let* holds =
    select_inner ~with_breadcrumbs:false ~instr_counter_true:None
      ~instr_counter_false:None c
  in
  if not holds then bug (`Assertion c) else return ()

(* [ite c ~if_true ~if_false] builds a symbolic if-then-else. Numeric values
   are merged into a single [Symbolic_boolean.ite] expression. References
   cannot be merged, so execution branches on [c] instead. Mismatched value
   kinds are a programming error. *)
let ite (c : Symbolic_boolean.t) ~(if_true : Symbolic_value.t)
  ~(if_false : Symbolic_value.t) : Symbolic_value.t t =
  match (if_true, if_false) with
  | I32 a, I32 b -> return (Symbolic_value.I32 (Symbolic_boolean.ite c a b))
  | I64 a, I64 b -> return (Symbolic_value.I64 (Symbolic_boolean.ite c a b))
  | F32 a, F32 b -> return (Symbolic_value.F32 (Symbolic_boolean.ite c a b))
  | F64 a, F64 b -> return (Symbolic_value.F64 (Symbolic_boolean.ite c a b))
  | Ref _, Ref _ ->
    (* TODO: better prio here *)
    let+ take_true =
      select c ~instr_counter_true:None ~instr_counter_false:None
    in
    if take_true then if_true else if_false
  | _, _ -> assert false

(* [assume condition] restricts the current branch to paths where
   [condition] holds. It adds [condition] to the path condition, then stops
   the branch unless the solver answers `Sat. *)
let assume condition =
  let condition = Smtml.Typed.simplify condition in
  match Smtml.Typed.view condition with
  | Val True -> return ()
  | Val False -> stop
  | _ ->
    let continue_if_sat = function
      | `Sat -> return ()
      | `Unsat -> stop
      | `Unknown ->
        (* It can happen when the solver is interrupted *)
        (* TODO: once https://github.com/formalsec/smtml/pull/479 is merged
                         if solver was interrupted then stop else assert false *)
        stop
    in
    let* () = add_pc condition in
    bind (check_reachability condition) continue_if_sat