(* Plasma GitLab Archive
   Projects Blog Knowledge *)

(* $Id: uq_ssl.ml 1445 2010-04-27 15:51:50Z gerd $ *)

(** Debugging switch for this module. The flag is registered with
    [Netlog.Debug] under the name ["Uq_ssl"]; it starts out disabled. *)
module Debug : sig
  val enable : bool ref
end = struct
  let enable = ref false
end

(* Debug loggers for this module, both gated by [Debug.enable].
   [dlog] takes a message string; [dlogr] takes a thunk producing the
   message (presumably only forced when debugging is on -- see Netlog). *)
let dlog, dlogr =
  let mod_name = "Uq_ssl" in
  let plain_logger = Netlog.Debug.mk_dlog mod_name Debug.enable in
  let lazy_logger = Netlog.Debug.mk_dlogr mod_name Debug.enable in
  (plain_logger, lazy_logger)

(* Make the flag switchable through Netlog's debug registry. *)
let () = Netlog.Debug.register_module "Uq_ssl" Debug.enable


open Printf

(* Wraps an [Ssl.ssl_error] reported by the SSL library. The multiplex
   controller below passes it to [when_done] callbacks when a handshake,
   read, write or shutdown fails with an error other than the
   want-read / want-write retry conditions. *)
exception Ssl_error of Ssl.ssl_error

(** State of the SSL layer of a controller:
    - [`Unset]: no handshake has completed yet
    - [`Client] / [`Server]: handshake completed as client / as server
    - [`Unclean]: an operation failed or the shutdown was not clean
    - [`Clean]: both sides completed the SSL shutdown *)
type ssl_socket_state = [ `Unset | `Client | `Server | `Unclean | `Clean ]

(** [string_of_socket_state st] is the constructor name of [st]
    without the backquote, e.g. ["Client"] for [`Client]. *)
let string_of_socket_state st =
  match st with
  | `Client -> "Client"
  | `Server -> "Server"
  | `Clean -> "Clean"
  | `Unclean -> "Unclean"
  | `Unset -> "Unset"


(** A multiplex controller whose I/O goes through an SSL session. On top
    of the generic [Uq_engines.multiplex_controller] methods it exposes
    the SSL socket, the SSL-layer state and the two handshake
    operations (client side and server side). *)
class type ssl_multiplex_controller =
object
  inherit Uq_engines.multiplex_controller

  (** The SSL socket used for all reads and writes. *)
  method ssl_socket : Ssl.socket

  (** Current state of the SSL layer (see [ssl_socket_state]). *)
  method ssl_socket_state : ssl_socket_state

  (** Starts the client-side handshake. [when_done] receives [None] on
      success or [Some e] on failure. *)
  method start_ssl_connecting :
    when_done:(exn option -> unit) -> unit -> unit

  (** Whether a client-side handshake is currently in progress. *)
  method ssl_connecting : bool

  (** Starts the server-side handshake. [when_done] is called as for
      [start_ssl_connecting]. *)
  method start_ssl_accepting :
    when_done:(exn option -> unit) -> unit -> unit

  (** Whether a server-side handshake is currently in progress. *)
  method ssl_accepting : bool
end


(* Printable name of an internal operation tag; used in debug output. *)
let string_of_tag tag =
  match tag with
  | `Reading -> "Reading"
  | `Writing -> "Writing"
  | `Connecting -> "Connecting"
  | `Accepting -> "Accepting"
  | `Shutting_down -> "Shutting_down"


(** Multiplex controller running an SSL session over the descriptor [fd].

    @param close_inactive_descr if [true] (the default), [#inactivate]
      calls [preclose ()] and then closes [fd]
    @param preclose hook run just before [#inactivate] closes [fd]
      (default: does nothing)

    [ssl_sock] is the SSL socket embedded on [fd] (as built by
    [create_ssl_multiplex_controller]); [esys] is the event system that
    drives every operation. [fd] is switched to non-blocking mode here.

    Each [start_*] method schedules one attempt of the SSL call from a
    zero-delay [Unixqueue.once] callback. When the SSL library answers
    "want read" / "want write", the attempt is kept in [pending] and
    retried when [esys] reports the matching readiness of [fd].
    Half-open connections ([start_writing_eof]) and memory buffers
    ([start_mem_reading] / [start_mem_writing]) are not supported. *)
class ssl_mplex_ctrl ?(close_inactive_descr=true)
                     ?(preclose = fun () -> ())
                     fd ssl_sock esys : ssl_multiplex_controller =
  let () = Unix.set_nonblock fd in
  (* Descriptor number; only used in debug messages. *)
  let fdi = Netsys.int64_of_file_descr fd in
object(self)
  val mutable alive = true
    (* set to false by #inactivate; setup_queue does nothing afterwards *)
  val mutable read_eof = false
  val mutable wrote_eof = false

  val mutable state = (`Unset : ssl_socket_state)

  val mutable connecting = false   (* true only in state `Unset *)
  val mutable accepting = false    (* true only in state `Unset *)
  val mutable reading = None       (* <> None only in states `Client/`Server *)
  val mutable writing = None       (* <> None only in states `Client/`Server *)
  val mutable shutting_down = None (* <> None only in states `Client/`Server *)
  val mutable disconnecting = None
    (* Wait event scheduled by setup_queue; when it arrives, handle_event
       removes itself from [esys] *)

  val mutable have_handler = false
    (* whether handle_event is currently installed for [group] *)

  val mutable pending = []
    (* Attempts waiting for readiness of [fd]:
       (tag, want_read, want_write, attempt function) *)

  val mutable expecting_input = false   (* Wait_in resource registered *)
  val mutable expecting_output = false  (* Wait_out resource registered *)

  val group = Unixqueue.new_group esys

  method alive = alive
  method ssl_socket = ssl_sock
  method ssl_socket_state = state

  method ssl_connecting = connecting
  method ssl_accepting = accepting
  method reading = reading <> None
  method writing = writing <> None
  method shutting_down = shutting_down <> None
  method read_eof = read_eof
  method wrote_eof = wrote_eof

  method supports_half_open_connection = false

  method mem_supported = false


  (* Client-side handshake. On success the state becomes `Client and
     [when_done None] is called; on failure the state becomes `Unclean
     and [when_done (Some e)] is called. *)
  method start_ssl_connecting ~when_done () =
    if state <> `Unset then
      failwith "#start_connecting: no longer possible in this state";
    if connecting || accepting then
      failwith "#start_connecting: handshake already in progress";
    dlogr (fun () -> sprintf "FD %Ld: start_ssl_connecting" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_ssl_connecting" fdi);
      when_done arg in
    self # nonblock_operation
      (ref false)   (* fresh flag: there is no method to cancel a handshake *)
      `Connecting
      (fun () ->
         try
           Ssl.connect ssl_sock;
           state <- `Client;
           connecting <- false;
           (false, false, fun () -> when_done None)
         with
           | Ssl.Connection_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Connection_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Connection_error ssl_err ->
               state <- `Unclean;
               connecting <- false;
               (false, false, fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               connecting <- false;
               (false, false, fun () -> when_done (Some err))
      );
    (* Safe to set after scheduling: the attempt runs later from [esys]. *)
    connecting <- true


  (* Server-side handshake; same contract as start_ssl_connecting, but
     the resulting state is `Server. *)
  method start_ssl_accepting ~when_done () =
    if state <> `Unset then
      failwith "#start_accepting: no longer possible in this state";
    if connecting || accepting then
      failwith "#start_accepting: handshake already in progress";
    dlogr (fun () -> sprintf "FD %Ld: start_ssl_accepting" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_ssl_accepting" fdi);
      when_done arg in
    self # nonblock_operation
      (ref false)   (* fresh flag: there is no method to cancel a handshake *)
      `Accepting
      (fun () ->
         try
           Ssl.accept ssl_sock;
           state <- `Server;
           accepting <- false;
           (false, false, fun () -> when_done None)
         with
           | Ssl.Accept_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Accept_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Accept_error ssl_err ->
               state <- `Unclean;
               accepting <- false;
               (false, false, fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               accepting <- false;
               (false, false, fun () -> when_done (Some err))
      );
    accepting <- true


  (* Reads at most [len] bytes into [s] at [pos]. [when_done None n] gets
     the byte count; at EOF it is [when_done (Some End_of_file) 0]. Only
     allowed after a successful handshake.
     @raise Invalid_argument if [pos]/[len] are out of range for [s]
     @raise Failure in a bad state or if a read/shutdown is active *)
  method start_reading ?(peek = fun() -> ()) ~when_done s pos len =
    if pos < 0 || len < 0 || pos + len > String.length s then
      invalid_arg "#start_reading";
    if state <> `Client && state <> `Server then
      failwith "#start_reading: bad state";
    if reading <> None then
      failwith "#start_reading: already reading";
    if shutting_down <> None then
      failwith "#start_reading: already shutting down";
    dlogr (fun () -> sprintf "FD %Ld: start_reading" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_reading" fdi);
      when_done arg in
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag
      `Reading
      (fun () ->
         try
           (* peek(); *)
           (* [peek] is used by auth-local. It does not work for SSL. *)
           let n = Ssl_exts.single_read ssl_sock s pos len in
           reading <- None;
           assert(n > 0);
           (false, false, fun () -> when_done None n)
         with
           | Ssl.Read_error Ssl.Error_zero_return ->
               (* Read EOF. The read request is finished, so clear
                  [reading] like every other terminal branch does;
                  otherwise #reading stays true, #start_shutting_down
                  fails and #cancel_reading would call [when_done] a
                  second time. *)
               read_eof <- true;
               reading <- None;
               (* Note: read_eof should be consistent with Ssl.read *)
               (false, false, fun () -> when_done (Some End_of_file) 0)
           | Ssl.Read_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Read_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Read_error ssl_err ->
               state <- `Unclean;
               reading <- None;
               (false, false, fun () -> when_done (Some (Ssl_error ssl_err)) 0)
           | err ->
               state <- `Unclean;
               reading <- None;
               (false, false, fun () -> when_done (Some err) 0)
      );
    reading <- Some (when_done, cancel_flag)

  method start_mem_reading ?(peek = fun() -> ()) ~when_done m pos len =
    raise Uq_engines.Mem_not_supported

  (* Cancels an active read (no-op if none); its [when_done] receives
     [Some Uq_engines.Cancelled] and 0. *)
  method cancel_reading () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_reading" fdi);
    match reading with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Reading;
          cancel_flag := true;
          reading <- None;
          f_when_done (Some Uq_engines.Cancelled) 0


  (* Writes at most [len] bytes of [s] from [pos]; [when_done None n]
     gets the number of bytes written. Only allowed after a successful
     handshake and before EOF was written.
     @raise Invalid_argument if [pos]/[len] are out of range for [s]
     @raise Failure in a bad state, while writing/shutting down, or
       after EOF *)
  method start_writing ~when_done s pos len =
    if pos < 0 || len < 0 || pos + len > String.length s then
      invalid_arg "#start_writing";
    if state <> `Client && state <> `Server then
      failwith "#start_writing: bad state";
    if writing <> None then
      failwith "#start_writing: already writing";
    if shutting_down <> None then
      failwith "#start_writing: already shutting down";
    if wrote_eof then
      failwith "#start_writing: already past EOF";
    dlogr (fun () -> sprintf "FD %Ld: start_writing" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_writing" fdi);
      when_done arg in
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag
      `Writing
      (fun () ->
         try
           let n = Ssl_exts.single_write ssl_sock s pos len in
           writing <- None;
           (false, false, fun () -> when_done None n)
         with
           | Ssl.Write_error Ssl.Error_zero_return ->
               (* NOTE(review): treated like want-write and retried. If
                  the peer has already closed the SSL session this may be
                  retried indefinitely -- confirm against ocaml-ssl /
                  SSL_get_error semantics. *)
               (false, true, fun () -> ())
           | Ssl.Write_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl.Write_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl.Write_error ssl_err ->
               state <- `Unclean;
               writing <- None;
               (false, false, fun () -> when_done (Some (Ssl_error ssl_err)) 0)
           | err ->
               state <- `Unclean;
               writing <- None;
               (false, false, fun () -> when_done (Some err) 0)
      );
    writing <- Some (when_done, cancel_flag)


  method start_mem_writing ~when_done m pos len =
    raise Uq_engines.Mem_not_supported


  (* Half-closing is not supported (see supports_half_open_connection). *)
  method start_writing_eof ~when_done () =
    failwith "#start_writing_eof: operation not supported"


  (* Cancels an active write (no-op if none); its [when_done] receives
     [Some Uq_engines.Cancelled] and 0. *)
  method cancel_writing () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_writing" fdi);
    match writing with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Writing;
          cancel_flag := true;
          writing <- None;
          f_when_done (Some Uq_engines.Cancelled) 0


  (* Performs the SSL shutdown. State becomes `Clean once both
     close_notify directions are done; after two successful shutdown
     calls without that, it becomes `Unclean. [linger] is accepted for
     interface compatibility but not used.
     @raise Failure in a bad state, while reading/writing, or when a
       shutdown is already active *)
  method start_shutting_down ?(linger = 60.0) ~when_done () =
    if state <> `Client && state <> `Server then
      failwith "#start_shutting_down: bad state";
    if reading <> None || writing <> None then
      failwith "#start_shutting_down: still reading or writing";
    if shutting_down <> None then
      failwith "#start_shutting_down: already shutting down";
    dlogr (fun () -> sprintf "FD %Ld: start_shutting_down" fdi);
    let when_done arg =
      dlogr (fun () -> sprintf "FD %Ld: done start_shutting_down" fdi);
      when_done arg in
    (* number of successful single_shutdown calls so far *)
    let n = ref 0 in
    let cancel_flag = ref false in
    self # nonblock_operation
      cancel_flag
      `Shutting_down
      (fun () ->
         try
           Ssl_exts.single_shutdown ssl_sock;
           incr n;

           let (rcvd_shutdown, sent_shutdown) =
             Ssl_exts.get_shutdown ssl_sock in
           if rcvd_shutdown then
             read_eof <- true;
           if sent_shutdown then
             wrote_eof <- true;

           if !n = 2 && not (rcvd_shutdown && sent_shutdown) then (
             (* Unclean crash: two shutdown calls did not complete the
                exchange in both directions. *)
             shutting_down <- None;
             state <- `Unclean;
             if rcvd_shutdown || sent_shutdown then
               (false, false, fun () -> when_done None)
             else
               (false, false,
                fun () -> when_done (Some (Failure "Unclean SSL shutdown")))
           )
           else (
             match (rcvd_shutdown, sent_shutdown) with
               | (false, false) ->
                   (* strange *)
                   (false, true, fun () -> ())
               | (true, false) ->
                   (false, true, fun () -> ())
               | (false, true) ->
                   (true, false, fun () -> ())
               | (true, true) ->
                   shutting_down <- None;
                   state <- `Clean;
                   (false, false, fun () -> when_done None)
           )
         with
           | Ssl_exts.Shutdown_error Ssl.Error_want_read ->
               (true, false, fun () -> ())
           | Ssl_exts.Shutdown_error Ssl.Error_want_write ->
               (false, true, fun () -> ())
           | Ssl_exts.Shutdown_error ssl_err ->
               state <- `Unclean;
               shutting_down <- None;
               (false, false, fun () -> when_done (Some (Ssl_error ssl_err)))
           | err ->
               state <- `Unclean;
               shutting_down <- None;
               (false, false, fun () -> when_done (Some err))
      );
    shutting_down <- Some (when_done, cancel_flag)

  (* Cancels an active shutdown (no-op if none); its [when_done]
     receives [Some Uq_engines.Cancelled]. *)
  method cancel_shutting_down () =
    dlogr (fun () -> sprintf "FD %Ld: cancel_shutting_down" fdi);
    match shutting_down with
      | None ->
          ()
      | Some (f_when_done, cancel_flag) ->
          assert(not !cancel_flag);
          self # cancel_operation `Shutting_down;
          cancel_flag := true;
          shutting_down <- None;
          f_when_done (Some Uq_engines.Cancelled)


  (* Runs [f] once from [esys] (zero delay) unless [!cancel_flag] is set
     by then. [f] makes one non-blocking attempt and returns
     [(want_rd, want_wr, action)]: if either flag is set, [f] is queued
     in [pending] for a retry on readiness. [action] is then run, and the
     event resources are re-synchronised even if [action] raises. *)
  method private nonblock_operation cancel_flag tag f =
    Unixqueue.once
      esys
      group
      0.0
      (fun () ->
         if not !cancel_flag then (
           dlogr
             (fun () ->
                sprintf "FD %Ld: operation: %s" fdi (string_of_tag tag));
           let (want_rd, want_wr, action) = f() in
           dlogr
             (fun () ->
                sprintf "FD %Ld: returning from %s - want_rd=%b want_wr=%b %s"
                  fdi (string_of_tag tag) want_rd want_wr
                  (if want_rd || want_wr then "- queuing op and retrying later"
                   else ""));
           if want_rd || want_wr then
             pending <- (tag, want_rd, want_wr, f) :: pending;
           ( try
               action();
               self # setup_queue();
             with
               | error ->
                   self # setup_queue(); raise error
           )
         )
      )


  (* Drops every pending attempt with tag [tag]. *)
  method private cancel_operation tag =
    pending <-
      List.filter (fun (t, _, _, _) -> t <> tag) pending;
    self # setup_queue()


  (* Re-runs every pending attempt whose wait condition is met by
     [can_read] / [can_write]. Attempts that still want I/O are re-queued.
     The collected actions then run in their original order; if several
     raise, only the first exception is re-raised and the others are
     logged at `Crit level. *)
  method private retry_nonblock_operations can_read can_write =
    dlogr (fun () -> sprintf "FD %Ld: retry_nonblock_operations" fdi);
    let cur_pending = pending in
    pending <- [];    (* maybe new operations are added! *)
    let actions = ref [] in
    let pending' =
      List.flatten
        (List.map
           (fun (tag, want_rd, want_wr, f) ->
              if (want_rd && can_read) || (want_wr && can_write) then (
                dlogr
                  (fun () ->
                     sprintf "FD %Ld: retried operation: %s"
                       fdi (string_of_tag tag));
                let (want_rd', want_wr', action) = f() in  (* must not fail! *)
                dlogr
                  (fun () ->
                     sprintf "FD %Ld: returning from %s - \
                              want_rd=%b want_wr=%b %s"
                       fdi (string_of_tag tag) want_rd' want_wr'
                       (if want_rd' || want_wr' then
                          "- queuing op and retrying later"
                        else ""));
                actions := action :: !actions;
                if want_rd' || want_wr' then
                  [ tag, want_rd', want_wr', f ]   (* try again later *)
                else
                  []
              )
              else
                [ tag, want_rd, want_wr, f ]   (* just keep *)
           )
           cur_pending
        ) in
    pending <- pending @ pending';

    (* Be careful: We can only return the first error *)
    let first_error = ref None in
    List.iter
      (fun f ->
         try f()
         with
           | e ->
               ( match !first_error with
                   | None -> first_error := Some e
                   | Some _ ->
                       Netlog.logf `Crit
                         "Uq_ssl hidden exception: %s"
                         (Netexn.to_string e)
               )
      )
      (List.rev !actions);

    self # setup_queue();

    ( match !first_error with
        | None -> ()
        | Some e -> raise e
    )


  (* Makes the Wait_in / Wait_out resources of [group] match what
     [pending] needs, installing handle_event on first need. Once nothing
     is pending and no operation is active, it schedules a Wait event on
     which handle_event removes itself. Does nothing after #inactivate. *)
  method private setup_queue() =
    if alive then (
      let expecting_input' =
        List.exists (fun (_, want_rd, _, _) -> want_rd) pending in
      let expecting_output' =
        List.exists (fun (_, _, want_wr, _) -> want_wr) pending in

      if expecting_input' || expecting_output' then (
        if not have_handler then (
          Unixqueue.add_handler esys group (fun _ _ -> self # handle_event);
          have_handler <- true;
        );
        (* Forget a scheduled disconnect; handle_event will reject that
           event because it no longer matches [disconnecting]. *)
        disconnecting <- None;
      )
      else
        if have_handler && disconnecting = None then (
          (* It makes only sense to disconnect if all callbacks are cancelled *)
          if not (accepting || connecting || reading <> None ||
                  writing <> None || shutting_down <> None) then (
            let wid = Unixqueue.new_wait_id esys in
            let disconnector = Unixqueue.Wait wid in
            Unixqueue.add_event esys (Unixqueue.Timeout (group, disconnector));
            disconnecting <- Some disconnector
          )
        );

      ( match expecting_input, expecting_input' with
          | (false, true) ->
              Unixqueue.add_resource esys group (Unixqueue.Wait_in fd, (-1.0))
          | (true, false) ->
              Unixqueue.remove_resource esys group (Unixqueue.Wait_in fd)
          | _ ->
              ()
      );

      ( match expecting_output, expecting_output' with
          | (false, true) ->
              Unixqueue.add_resource esys group (Unixqueue.Wait_out fd, (-1.0))
          | (true, false) ->
              Unixqueue.remove_resource esys group (Unixqueue.Wait_out fd)
          | _ ->
              ()
      );

      expecting_input  <- expecting_input';
      expecting_output <- expecting_output';
    )


  (* Event handler for [group]. It retries pending attempts on readiness
     and terminates itself on the disconnect event from setup_queue. All
     other events are rejected. *)
  method private handle_event ev =
    match ev with
      | Unixqueue.Input_arrived (g, _) when g = group ->
          self # retry_nonblock_operations true false

      | Unixqueue.Output_readiness (g, _) when g = group ->
          self # retry_nonblock_operations false true

      | Unixqueue.Timeout (g, op) when g = group ->
          ( match disconnecting with
              | Some op' when op = op' ->
                  disconnecting <- None;
                  have_handler <- false;
                  raise Equeue.Terminate

              | _ -> raise Equeue.Reject
                  (* Can also be a timeout event from a "once" handler *)
          )

      | _ ->
          raise Equeue.Reject


  (* Shuts the controller down; calling it again has no effect. It drops
     all pending attempts, clears every handler, event and resource of
     [group], and closes [fd] when [close_inactive_descr] is set. *)
  method inactivate() =
    if alive then (
      alive <- false;
      pending <- [];
      disconnecting <- None;
      have_handler <- false;
      Unixqueue.clear esys group;
      if close_inactive_descr then (
        preclose();
        Unix.close fd
      )
    )

  method event_system = esys

end
;;


(** [create_ssl_multiplex_controller ?close_inactive_descr ?preclose fd ctx esys]
    embeds an SSL socket for context [ctx] on [fd] and wraps it in a
    [ssl_mplex_ctrl]. [fd] is switched to non-blocking mode. The SSL mode
    is extended with partial writes and moving write buffers, which the
    controller's single-write operations rely on. The optional arguments
    are passed unchanged to [ssl_mplex_ctrl]. *)
let create_ssl_multiplex_controller
       ?close_inactive_descr ?preclose fd ctx esys =
  Unix.set_nonblock fd;
  let ssl_sock = Ssl.embed_socket fd ctx in
  let old_mode = Ssl_exts.get_mode ssl_sock in
  let new_mode =
    { old_mode with
        Ssl_exts.enable_partial_write = true;
        Ssl_exts.accept_moving_write_buffer = true } in
  Ssl_exts.set_mode ssl_sock new_mode;
  new ssl_mplex_ctrl ?close_inactive_descr ?preclose fd ssl_sock esys
;;


(** Engine that runs the client-side SSL handshake on [mplex]. It moves
    to [`Done ()] on success and to [`Error e] on failure. While it is
    still working, [abort] inactivates [mplex] and moves to [`Aborted]. *)
class ssl_connect_engine (mplex : ssl_multiplex_controller) =
object(self)
  inherit [ unit ] Uq_engines.engine_mixin (`Working 0) mplex#event_system

  initializer
    let finish = function
      | None -> self # set_state (`Done ())
      | Some err -> self # set_state (`Error err)
    in
    mplex # start_ssl_connecting ~when_done:finish ()

  method event_system = mplex # event_system

  method abort () =
    let in_progress =
      match self # state with
      | `Working _ -> true
      | _ -> false in
    if in_progress then begin
      mplex # inactivate ();
      self # set_state `Aborted
    end

end


let ssl_connect_engine = new ssl_connect_engine


(** Engine that runs the server-side SSL handshake on [mplex]. It moves
    to [`Done ()] on success and to [`Error e] on failure. While it is
    still working, [abort] inactivates [mplex] and moves to [`Aborted]. *)
class ssl_accept_engine (mplex : ssl_multiplex_controller) =
object(self)
  inherit [ unit ] Uq_engines.engine_mixin (`Working 0) mplex#event_system

  initializer
    let finish = function
      | None -> self # set_state (`Done ())
      | Some err -> self # set_state (`Error err)
    in
    mplex # start_ssl_accepting ~when_done:finish ()

  method event_system = mplex # event_system

  method abort () =
    let in_progress =
      match self # state with
      | `Working _ -> true
      | _ -> false in
    if in_progress then begin
      mplex # inactivate ();
      self # set_state `Aborted
    end

end


let ssl_accept_engine = new ssl_accept_engine

(* This web site is published by Informatikbüro Gerd Stolpmann
   Powered by Caml *)