(* $Id: uq_ssl.ml 1445 2010-04-27 15:51:50Z gerd $ *)

(* Debug switch for this module; it is registered with Netlog below. *)
module Debug = struct
  let enable = ref false
end

let dlog = Netlog.Debug.mk_dlog "Uq_ssl" Debug.enable
let dlogr = Netlog.Debug.mk_dlogr "Uq_ssl" Debug.enable

let () =
  Netlog.Debug.register_module "Uq_ssl" Debug.enable

open Printf

(* Wraps an SSL library error code so that it can be passed to the
   when_done callbacks as an exception value. *)
exception Ssl_error of Ssl.ssl_error

type ssl_socket_state = [ `Unset | `Client | `Server | `Unclean | `Clean ]

(* Printable name of a socket state. *)
let string_of_socket_state =
  function
    | `Unset -> "Unset"
    | `Client -> "Client"
    | `Server -> "Server"
    | `Unclean -> "Unclean"
    | `Clean -> "Clean"

(* A multiplex controller extended by the SSL specific accessors and by
   the two handshake operations. *)
class type ssl_multiplex_controller =
object
  inherit Uq_engines.multiplex_controller
  method ssl_socket : Ssl.socket
  method ssl_socket_state : ssl_socket_state
  method ssl_connecting : bool
  method ssl_accepting : bool
  method start_ssl_connecting :
    when_done:(exn option -> unit) -> unit -> unit
  method start_ssl_accepting :
    when_done:(exn option -> unit) -> unit -> unit
end

(* Printable name of an operation tag (only used for debug logging). *)
let string_of_tag =
  function
    | `Connecting -> "Connecting"
    | `Accepting -> "Accepting"
    | `Reading -> "Reading"
    | `Writing -> "Writing"
    | `Shutting_down -> "Shutting_down"

(* Implementation of ssl_multiplex_controller on top of the file
   descriptor fd, the SSL socket ssl_sock and the event system esys.

   Every operation is given as a closure returning a triple
   (want_rd, want_wr, action):
    - if want_rd or want_wr is set, the closure is stored in [pending]
      and called again once fd is readable or writable, respectively;
    - [action] is always run after the closure returned. It is the place
      where the user callback is invoked.

   close_inactive_descr: whether #inactivate closes fd (default: true).
   preclose: called by #inactivate right before fd is closed.
 *)
class ssl_mplex_ctrl ?(close_inactive_descr=true)
                     ?(preclose = fun () -> ())
                     fd ssl_sock esys : ssl_multiplex_controller =
  let () = Unix.set_nonblock fd in
  let fdi = Netsys.int64_of_file_descr fd in   (* only for log messages *)
object(self)
  val mutable alive = true
    (* if false => state in { `Clean, `Unclean } *)
  val mutable read_eof = false
  val mutable wrote_eof = false
  val mutable state = (`Unset : ssl_socket_state)
  val mutable connecting = false
    (* true only in state `Unset *)
  val mutable accepting = false
    (* true only in state `Unset *)
  val mutable reading = None
    (* <> None only in states `Client/`Server *)
  val mutable writing = None
    (* <> None only in states `Client/`Server *)
  val mutable shutting_down = None
    (* <> None only in states `Client/`Server *)
  val mutable disconnecting = None
    (* the wait operation of the scheduled handler removal, if any *)
  val mutable have_handler = false
    (* whether handle_event is currently installed for [group] *)
  val mutable pending = []
    (* list of pending socket operations *)
  val mutable expecting_input = false
    (* whether Wait_in fd is currently registered as resource *)
  val mutable expecting_output = false
    (* whether Wait_out fd is currently registered as resource *)

  val group = Unixqueue.new_group esys

  method alive = alive
  method ssl_socket = ssl_sock
  method ssl_socket_state = state
  method ssl_connecting = connecting
  method ssl_accepting = accepting
  method reading = reading <> None
  method writing = writing <> None
  method shutting_down = shutting_down <> None
  method read_eof = read_eof
  method wrote_eof = wrote_eof

  method supports_half_open_connection = false

  method mem_supported = false

  (* Starts the client side handshake. when_done is called with None on
     success and with Some error otherwise.
     Fails if the state is not `Unset, or if a handshake is already in
     progress. *)
  method start_ssl_connecting ~when_done () =
    if state <> `Unset then
      failwith "#start_connecting: no longer possible in this state";
    if connecting || accepting then
      failwith "#start_connecting: handshake already in progress";
    dlogr (fun () -> sprintf "FD %Ld: start_ssl_connecting" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_ssl_connecting" fdi);
      when_done arg in
    self # nonblock_operation
      (ref false) `Connecting
      (fun () ->
         try
           Ssl.connect ssl_sock;
           state <- `Client;
           connecting <- false;
           (false, false, fun () -> when_done None)
         with
           | Ssl.Connection_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Connection_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Connection_error ssl_err ->
               state <- `Unclean;
               connecting <- false;
               (false, false,
                fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               connecting <- false;
               (false, false, fun () -> when_done (Some err))
      );
    connecting <- true

  (* Starts the server side handshake. when_done is called with None on
     success and with Some error otherwise.
     Fails if the state is not `Unset, or if a handshake is already in
     progress. *)
  method start_ssl_accepting ~when_done () =
    if state <> `Unset then
      failwith "#start_accepting: no longer possible in this state";
    if connecting || accepting then
      failwith "#start_accepting: handshake already in progress";
    dlogr (fun () -> sprintf "FD %Ld: start_ssl_accepting" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_ssl_accepting" fdi);
      when_done arg in
    self # nonblock_operation
      (ref false) `Accepting
      (fun () ->
         try
           Ssl.accept ssl_sock;
           state <- `Server;
           accepting <- false;
           (false, false, fun () -> when_done None)
         with
           | Ssl.Accept_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Accept_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Accept_error ssl_err ->
               state <- `Unclean;
               accepting <- false;
               (false, false,
                fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               accepting <- false;
               (false, false, fun () -> when_done (Some err))
      );
    accepting <- true

  (* Starts reading up to len bytes into s at position pos.
     when_done is called with (None, n) after n bytes have been read,
     with (Some End_of_file, 0) at EOF, and with (Some error, 0) on
     errors. The peek argument is accepted but not used.
     Raises Invalid_argument if pos/len do not denote a valid substring
     of s. Fails if the state is neither `Client nor `Server, or if a
     read or a shutdown is already in progress. *)
  method start_reading ?(peek = fun() -> ()) ~when_done s pos len =
    (* Written as subtraction so that huge pos/len cannot overflow *)
    if pos < 0 || len < 0 || pos > String.length s - len then
      invalid_arg "#start_reading";
    if state <> `Client && state <> `Server then
      failwith "#start_reading: bad state";
    if reading <> None then
      failwith "#start_reading: already reading";
    if shutting_down <> None then
      failwith "#start_reading: already shutting down";
    dlogr (fun () -> sprintf "FD %Ld: start_reading" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_reading" fdi);
      when_done arg in
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag `Reading
      (fun () ->
         try
           (* peek(); *)
           (* [peek] is used by auth-local. It does not work for SSL. *)
           let n = Ssl_exts.single_read ssl_sock s pos len in
           reading <- None;
           assert(n > 0);
           (false, false, fun () -> when_done None n)
         with
           | Ssl.Read_error Ssl.Error_zero_return ->
               (* Read EOF *)
               read_eof <- true;
               (* Note: read_eof should be consistent with Ssl.read *)
               (* The operation is finished, so clear the reading slot
                  like every other terminating branch does. Otherwise
                  #reading remains true after EOF, and a following
                  start_shutting_down or start_reading is refused. *)
               reading <- None;
               (false, false, fun () -> when_done (Some End_of_file) 0)
           | Ssl.Read_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Read_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Read_error ssl_err ->
               state <- `Unclean;
               reading <- None;
               (false, false,
                fun () -> when_done (Some (Ssl_error ssl_err)) 0)
           | err ->
               state <- `Unclean;
               reading <- None;
               (false, false, fun () -> when_done (Some err) 0)
      );
    reading <- Some (when_done, cancel_flag)

  (* Not supported: always raises Uq_engines.Mem_not_supported *)
  method start_mem_reading ?(peek = fun() -> ()) ~when_done m pos len =
    raise Uq_engines.Mem_not_supported

  (* Cancels the current read operation, if any. The callback of the
     cancelled operation is invoked with Uq_engines.Cancelled. *)
  method cancel_reading () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_reading" fdi);
    match reading with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Reading;
          cancel_flag := true;
          reading <- None;
          f_when_done (Some Uq_engines.Cancelled) 0

  (* Starts writing up to len bytes of s from position pos.
     when_done is called with (None, n) after n bytes have been
     written, and with (Some error, 0) on errors.
     Raises Invalid_argument if pos/len do not denote a valid substring
     of s. Fails if the state is neither `Client nor `Server, if a
     write or a shutdown is already in progress, or if wrote_eof is
     set. *)
  method start_writing ~when_done s pos len =
    (* Written as subtraction so that huge pos/len cannot overflow *)
    if pos < 0 || len < 0 || pos > String.length s - len then
      invalid_arg "#start_writing";
    if state <> `Client && state <> `Server then
      failwith "#start_writing: bad state";
    if writing <> None then
      failwith "#start_writing: already writing";
    if shutting_down <> None then
      failwith "#start_writing: already shutting down";
    if wrote_eof then
      failwith "#start_writing: already past EOF";
    dlogr (fun () -> sprintf "FD %Ld: start_writing" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_writing" fdi);
      when_done arg in
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag `Writing
      (fun () ->
         try
           let n = Ssl_exts.single_write ssl_sock s pos len in
           writing <- None;
           (false, false, fun () -> when_done None n)
         with
           | Ssl.Write_error Ssl.Error_zero_return ->
               (false, true, fun () -> ())
           | Ssl.Write_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Write_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Write_error ssl_err ->
               state <- `Unclean;
               writing <- None;
               (false, false,
                fun () -> when_done (Some (Ssl_error ssl_err)) 0)
           | err ->
               state <- `Unclean;
               writing <- None;
               (false, false, fun () -> when_done (Some err) 0)
      );
    writing <- Some (when_done, cancel_flag)

  (* Not supported: always raises Uq_engines.Mem_not_supported *)
  method start_mem_writing ~when_done m pos len =
    raise Uq_engines.Mem_not_supported

  (* Not supported (see supports_half_open_connection): always fails *)
  method start_writing_eof ~when_done () =
    failwith "#start_writing_eof: operation not supported"

  (* Cancels the current write operation, if any. The callback of the
     cancelled operation is invoked with Uq_engines.Cancelled. *)
  method cancel_writing () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_writing" fdi);
    match writing with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Writing;
          cancel_flag := true;
          writing <- None;
          f_when_done (Some Uq_engines.Cancelled) 0

  (* Starts the SSL shutdown. when_done is called with None or with
     Some error. The state becomes `Clean only if both shutdown
     directions are reported as done; it becomes `Unclean if this is
     not the case after the second successful shutdown call, or on
     errors. The linger argument is accepted but not used.
     Fails if the state is neither `Client nor `Server, or if a read,
     write or shutdown is in progress. *)
  method start_shutting_down ?(linger = 60.0) ~when_done () =
    if state <> `Client && state <> `Server then
      failwith "#start_shutting_down: bad state";
    if reading <> None || writing <> None then
      failwith "#start_shutting_down: still reading or writing";
    if shutting_down <> None then
      failwith "#start_shutting_down: already shutting down";
    dlogr (fun () -> sprintf "FD %Ld: start_shutting_down" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_shutting_down" fdi);
      when_done arg in
    let n = ref 0 in   (* counts the successful single_shutdown calls *)
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag `Shutting_down
      (fun () ->
         try
           Ssl_exts.single_shutdown ssl_sock;
           incr n;
           let (rcvd_shutdown, sent_shutdown) =
             Ssl_exts.get_shutdown ssl_sock in
           if rcvd_shutdown then read_eof <- true;
           if sent_shutdown then wrote_eof <- true;
           if !n=2 && not (rcvd_shutdown && sent_shutdown) then (
             (* Unclean crash *)
             shutting_down <- None;
             state <- `Unclean;
             if rcvd_shutdown || sent_shutdown then
               (false, false, fun () -> when_done None)
             else
               (false, false,
                fun () -> when_done(Some(Failure "Unclean SSL shutdown")))
           )
           else (
             match (rcvd_shutdown, sent_shutdown) with
               | (false, false) ->
                   (* strange *)
                   (false, true, fun () -> ())
               | (true, false) ->
                   (false, true, fun () -> ())
               | (false, true) ->
                   (true, false, fun () -> ())
               | (true, true) ->
                   shutting_down <- None;
                   state <- `Clean;
                   (false, false, fun () -> when_done None)
           )
         with
           | Ssl_exts.Shutdown_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl_exts.Shutdown_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl_exts.Shutdown_error ssl_err ->
               state <- `Unclean;
               shutting_down <- None;
               (false, false,
                fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               shutting_down <- None;
               (false, false, fun () -> when_done (Some err))
      );
    shutting_down <- Some(when_done, cancel_flag)

  (* Cancels the current shutdown operation, if any. The callback of
     the cancelled operation is invoked with Uq_engines.Cancelled. *)
  method cancel_shutting_down () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_shutting_down" fdi);
    match shutting_down with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Shutting_down;
          cancel_flag := true;
          shutting_down <- None;
          f_when_done (Some Uq_engines.Cancelled)

  (* Schedules the first attempt of operation f via Unixqueue.once.
     Nothing happens if cancel_flag has been set in the meantime.
     If f wants to be retried, it is added to [pending]. The action
     returned by f is run in any case, and the queue is set up again
     afterwards, also when the action raises an exception (which is
     then re-raised). *)
  method private nonblock_operation cancel_flag tag f =
    Unixqueue.once esys group 0.0
      (fun () ->
         if not !cancel_flag then (
           dlogr
             (fun () ->
                sprintf "FD %Ld: operation: %s" fdi (string_of_tag tag));
           let (want_rd, want_wr, action) = f() in
           dlogr
             (fun () ->
                sprintf "FD %Ld: returning from %s - \
                         want_rd=%b want_wr=%b %s"
                  fdi (string_of_tag tag) want_rd want_wr
                  (if want_rd || want_wr then
                     "- queuing op and retrying later"
                   else ""));
           if want_rd || want_wr then
             pending <- (tag, want_rd, want_wr, f) :: pending;
           ( try
               action();
               self # setup_queue()
             with
               | error ->
                   self # setup_queue();
                   raise error
           )
         )
      )

  (* Removes all pending operations with this tag, and updates the
     queue accordingly. *)
  method private cancel_operation tag =
    pending <-
      List.filter (fun (t, _, _, _) -> t <> tag) pending;
    self # setup_queue()

  (* Calls all pending operations again that wait for the signalled
     kind of readiness. Operations that still cannot proceed remain in
     [pending]. The actions of the retried operations are run after all
     retries, in the order of the retries. If several actions raise
     exceptions, only the first one is re-raised; the others are
     logged. *)
  method private retry_nonblock_operations can_read can_write =
    dlogr
      (fun () -> sprintf "FD %Ld: retry_nonblock_operations" fdi);
    let cur_pending = pending in
    pending <- [];    (* maybe new operations are added! *)
    let actions = ref [] in
    let pending' =
      List.flatten
        (List.map
           (fun (tag, want_rd, want_wr, f) ->
              if (want_rd && can_read) || (want_wr && can_write) then (
                dlogr
                  (fun () ->
                     sprintf "FD %Ld: retried operation: %s"
                       fdi (string_of_tag tag));
                let (want_rd', want_wr', action) = f() in
                (* must not fail! *)
                dlogr
                  (fun () ->
                     sprintf "FD %Ld: returning from %s - \
                              want_rd=%b want_wr=%b %s"
                       fdi (string_of_tag tag) want_rd' want_wr'
                       (if want_rd' || want_wr' then
                          "- queuing op and retrying later"
                        else ""));
                actions := action :: !actions;
                if want_rd' || want_wr' then
                  [ tag, want_rd', want_wr', f ]   (* try again later *)
                else
                  []
              )
              else
                [ tag, want_rd, want_wr, f ]       (* just keep *)
           )
           cur_pending
        ) in
    pending <- pending @ pending';
    (* Be careful: We can only return the first error *)
    let first_error = ref None in
    List.iter
      (fun f ->
         try f()
         with
           | e ->
               ( match !first_error with
                   | None ->
                       first_error := Some e
                   | Some _ ->
                       Netlog.logf `Crit
                         "Uq_ssl hidden exception: %s"
                         (Netexn.to_string e)
               )
      )
      (List.rev !actions);
    self # setup_queue();
    ( match !first_error with
        | None -> ()
        | Some e -> raise e
    )

  (* Brings the registrations in the event system in line with
     [pending]: installs the event handler when operations are waiting,
     schedules its removal when nothing is waiting and no operation is
     in progress, and adds/removes the Wait_in/Wait_out resources for fd
     when the need for them changes. Does nothing if not alive. *)
  method private setup_queue() =
    if alive then (
      let expecting_input' =
        List.exists (fun (_, want_rd, _, _) -> want_rd) pending in
      let expecting_output' =
        List.exists (fun (_, _, want_wr, _) -> want_wr) pending in

      if expecting_input' || expecting_output' then (
        if not have_handler then (
          Unixqueue.add_handler esys group (fun _ _ -> self # handle_event);
          have_handler <- true
        );
        disconnecting <- None
      )
      else
        if have_handler && disconnecting = None then (
          (* It makes only sense to disconnect if all callbacks are
             cancelled *)
          if not(accepting || connecting || reading <> None ||
                 writing <> None || shutting_down <> None)
          then (
            (* The handler removes itself when it sees this event *)
            let wid = Unixqueue.new_wait_id esys in
            let disconnector = Unixqueue.Wait wid in
            Unixqueue.add_event esys (Unixqueue.Timeout(group,disconnector));
            disconnecting <- Some disconnector
          )
        );

      ( match expecting_input, expecting_input' with
          | (false, true) ->
              Unixqueue.add_resource esys group (Unixqueue.Wait_in fd, (-1.0))
          | (true, false) ->
              Unixqueue.remove_resource esys group (Unixqueue.Wait_in fd)
          | _ ->
              ()
      );

      ( match expecting_output, expecting_output' with
          | (false, true) ->
              Unixqueue.add_resource esys group (Unixqueue.Wait_out fd, (-1.0))
          | (true, false) ->
              Unixqueue.remove_resource esys group (Unixqueue.Wait_out fd)
          | _ ->
              ()
      );

      expecting_input <- expecting_input';
      expecting_output <- expecting_output'
    )

  (* The event handler: retries the pending operations on input/output
     events of [group], and terminates itself when the timeout event
     scheduled by setup_queue arrives while it is still the current
     one. All other events are rejected. *)
  method private handle_event ev =
    match ev with
      | Unixqueue.Input_arrived(g, _) when g = group ->
          self # retry_nonblock_operations true false

      | Unixqueue.Output_readiness(g, _) when g = group ->
          self # retry_nonblock_operations false true

      | Unixqueue.Timeout (g, op) when g = group ->
          ( match disconnecting with
              | Some op' when op = op' ->
                  disconnecting <- None;
                  have_handler <- false;
                  raise Equeue.Terminate
              | _ ->
                  raise Equeue.Reject
                  (* Can also be a timeout event from a "once" handler *)
          )

      | _ ->
          raise Equeue.Reject

  (* Drops all pending operations and clears [group] in the event
     system. If close_inactive_descr is set, preclose is called and fd
     is closed. Has no effect when called again. *)
  method inactivate() =
    if alive then (
      alive <- false;
      pending <- [];
      disconnecting <- None;
      have_handler <- false;
      Unixqueue.clear esys group;
      if close_inactive_descr then (
        preclose();
        Unix.close fd
      )
    )

  method event_system = esys

end
;;


(* Creates a controller for fd: the descriptor is set to non-blocking
   mode and embedded into a new SSL socket for the context ctx. The
   SSL modes enable_partial_write and accept_moving_write_buffer are
   switched on for this socket. *)
let create_ssl_multiplex_controller ?close_inactive_descr ?preclose
                                    fd ctx esys =
  let () = Unix.set_nonblock fd in
  let s = Ssl.embed_socket fd ctx in
  let m = Ssl_exts.get_mode s in
  let () =
    Ssl_exts.set_mode s
      { m with
          Ssl_exts.enable_partial_write = true;
          accept_moving_write_buffer = true
      } in
  new ssl_mplex_ctrl ?close_inactive_descr ?preclose fd s esys
;;


(* Engine that starts the client handshake on mplex when it is
   created. The engine state becomes `Done() when the handshake
   callback reports success, and `Error e when it reports the
   exception e. Aborting a working engine inactivates mplex. *)
class ssl_connect_engine (mplex : ssl_multiplex_controller) =
object(self)
  inherit [ unit ] Uq_engines.engine_mixin (`Working 0) mplex#event_system

  initializer
    mplex # start_ssl_connecting
      ~when_done:(fun exn_opt ->
                    match exn_opt with
                      | None ->
                          self # set_state (`Done())
                      | Some err ->
                          self # set_state (`Error err)
                 )
      ()

  method event_system = mplex # event_system

  method abort() =
    match self#state with
      | `Working _ ->
          mplex # inactivate();
          self # set_state `Aborted
      | _ ->
          ()
end

let ssl_connect_engine = new ssl_connect_engine


(* Engine that starts the server handshake on mplex when it is
   created. The engine state becomes `Done() when the handshake
   callback reports success, and `Error e when it reports the
   exception e. Aborting a working engine inactivates mplex. *)
class ssl_accept_engine (mplex : ssl_multiplex_controller) =
object(self)
  inherit [ unit ] Uq_engines.engine_mixin (`Working 0) mplex#event_system

  initializer
    mplex # start_ssl_accepting
      ~when_done:(fun exn_opt ->
                    match exn_opt with
                      | None ->
                          self # set_state (`Done())
                      | Some err ->
                          self # set_state (`Error err)
                 )
      ()

  method event_system = mplex # event_system

  method abort() =
    match self#state with
      | `Working _ ->
          mplex # inactivate();
          self # set_state `Aborted
      | _ ->
          ()
end

let ssl_accept_engine = new ssl_accept_engine