Plasma GitLab Archive
Projects Blog Knowledge

(* $Id: netamqp_transport.ml 53347 2011-03-01 00:38:28Z gerd $
 * ----------------------------------------------------------------------
 *
 * This module is derived from rpc_transport.ml, which is published
 * as part of Ocamlnet, and whose copyright holder is Gerd Stolpmann.
 *
 *)

open Netamqp_types
open Printf


(* Outcome of an asynchronous operation, passed to [when_done] callbacks:
 * `Ok with the result value, or `Error with the exception that occurred.
 *)
type 't result =
    [ `Ok of 't
    | `Error of exn
    ]

(* Outcome of a read operation: like [result], plus `End_of_file, which
 * start_reading reports when the stream ends between two frames.
 *)
type 't result_eof =
    [ 't result
    | `End_of_file
    ]


(* A socket address. `Implied stands for an address that is not available,
 * e.g. because getsockname/getpeername failed for the descriptor.
 *)
type sockaddr =
    [ `Implied
    | `Sockaddr of Unix.sockaddr
    ]

(* Human-readable form of a [sockaddr]: real addresses are formatted by
 * Netsys.string_of_sockaddr, a missing address becomes "<implied>".
 *)
let string_of_sockaddr addr =
  match addr with
    | `Sockaddr sa -> Netsys.string_of_sockaddr sa
    | `Implied -> "<implied>"

(* Transport-level error. It is passed to callbacks as `Error for
 * malformed input (bad frame header, bad frame end, oversized frame, EOF
 * inside a frame), and it is raised by start_writing for invalid frames.
 *)
exception Error of string


(* Interface of a controller that transfers whole AMQP frames over a
 * connection, using callbacks ([when_done]) to report completion.
 * The TCP implementation in this file asserts that no read is already
 * in progress when start_reading is called, and no write when
 * start_writing is called.
 *)
class type amqp_multiplex_controller =
object
  (* Liveness as reported by the underlying multiplex controller *)
  method alive : bool
  (* The event system the controller is attached to *)
  method event_system : Unixqueue.event_system
  (* Local address, or `Implied if it could not be determined *)
  method getsockname : sockaddr
  (* Peer address, or `Implied if it could not be determined *)
  method getpeername : sockaddr
  (* Kind of transport (`TCP for the TCP implementation) *)
  method transport_type : transport_type
  (* Sets the maximum frame size, counting the 7-byte frame header and
     the frame-end byte. The TCP implementation fails for values below
     255 and caps values above a platform limit. *)
  method set_max_frame_size : int -> unit
  (* The maximum frame size currently in effect *)
  method eff_max_frame_size : int
  (* Whether a read operation is in progress *)
  method reading : bool
  (* Whether the read side has seen EOF (as reported by the underlying
     multiplex controller) *)
  method read_eof : bool
  (* Reads the next frame. [when_done] receives `Ok frame, `Error e, or
     `End_of_file when the stream ends between two frames. *)
  method start_reading : 
    when_done:( frame result_eof -> unit) -> unit -> unit
  (* Cancels a read in progress, if any *)
  method cancel_rd_polling : unit -> unit
  (* Cancels reading and writing; afterwards the read/write callbacks
     are no longer invoked *)
  method abort_rw : unit -> unit
  (* Whether a write operation is in progress *)
  method writing : bool
  (* Writes one frame. [when_done] receives `Ok () once the frame has
     been written completely, or `Error e. *)
  method start_writing :
    when_done:(unit result -> unit) -> frame -> unit
  (* Starts shutting down the connection; [when_done] reports the
     outcome *)
  method start_shutting_down :
    when_done:(unit result -> unit) -> unit -> unit
  (* Cancels a shutdown in progress *)
  method cancel_shutting_down : unit -> unit
  (* Sets a timeout in seconds. [notify] is meant to be called when a
     read, write or shutdown is in progress and no operation has started
     or finished within that time. NOTE(review): check the
     implementation for whether the timer is actually armed. *)
  method set_timeout : notify:(unit -> unit) -> float -> unit
  (* Stops the timeout timer and inactivates the underlying controller *)
  method inactivate : unit -> unit
end

(* Debug switch for this module. dlog/dlogr below are created with this
 * flag, and the flag is registered with Netlog.Debug under the name
 * "Netamqp_transport".
 *)
module Debug = struct
  let enable = ref false
end

(* Debug loggers tagged "Netamqp_transport". dlogr takes a thunk that
 * builds the message (see its uses in this file). dlog is not used in
 * this file; presumably it is kept for external use - verify against
 * the interface before removing it.
 *)
let dlog = Netlog.Debug.mk_dlog "Netamqp_transport" Debug.enable
let dlogr = Netlog.Debug.mk_dlogr "Netamqp_transport" Debug.enable

(* Module-initialisation side effect: make the switch known to Netlog *)
let () =
  Netlog.Debug.register_module "Netamqp_transport" Debug.enable


(* Block size of the default Netsys_mem pool; the read buffer of each
   controller is created with Netpagebuffer.create using this size
   (for allocated bigarrays). *)
let mem_size =
  let pool = Netsys_mem.default_pool in
  Netsys_mem.pool_block_size pool

let fallback_size = 0x4000  (* 16 KiB string buffer for I/O via Unix *)

(* Allocates one block of memory from the default Netsys_mem pool *)
let mem_alloc () =
  let open Netsys_mem in
  pool_alloc_memory default_pool


(* Returns a fresh char bigarray of length 0 (C layout) *)
let mem_dummy () =
  let open Bigarray in
  Array1.create char c_layout 0

(* Wraps the complete string [s] as a string-based managed string *)
let mk_mstring s =
  let factory = Xdr_mstring.string_based_mstrings in
  let len = String.length s in
  factory # create_from_string s 0 len false
  
(* Internal control flow of the frame decoder: the function carried by
 * this exception is called after the try-block has been left, so the
 * continuation runs in tail position and outside the exception handler.
 *)
exception Continue of (unit -> unit)


(* AMQP multiplex controller on top of a connected stream socket.
 *
 * [mplex] does the byte-level I/O. This class adds:
 * - frame decoding on reading (start_reading),
 * - frame encoding on writing (start_writing),
 * - an optional activity timeout (set_timeout).
 * [sockname] and [peername] are only reported back by
 * getsockname/getpeername. [esys] is the event system, also used for
 * the timeout timer.
 *
 * Frame layout handled here:
 * - a 7-byte header (type byte, 2-byte channel, 4-byte payload size),
 *   then the payload, then the frame-end byte 0xCE;
 * - or, only at the beginning of the stream, the 8-byte protocol header
 *   "AMQP\000" followed by 3 version bytes.
 *)
class tcp_amqp_multiplex_controller sockname peername 
        (mplex : Uq_engines.multiplex_controller) esys 
      : amqp_multiplex_controller =
  let () = 
    dlogr (fun () ->
             sprintf "new tcp_amqp_multiplex_controller mplex=%d"
               (Oo.id mplex))
  in
  (* Upper bound for max_frame_size:
     - on 64-bit platforms, the largest value of the 4-byte size field;
     - on 32-bit platforms, the maximum string length, because payloads
       are copied into strings. *)
  let lim_frame_size =
    if Sys.word_size = 64 then
      Int64.to_int 0xffff_ffffL
    else
      Sys.max_string_length in
object(self)
  (* Received bytes not yet consumed by the frame decoder *)
  val mutable rd_buffer = Netpagebuffer.create mem_size
  (* String buffer for reading. It is used only when mplex cannot read
     into bigarrays; otherwise it is the empty string. *)
  val mutable rd_buffer_nomem = 
    if mplex#mem_supported then "" else String.create fallback_size

  (* Decoder state:
     - `Frame_header _ while waiting for a complete 7-byte header;
     - `Payload(type, channel, size, payload_start) once the header has
       been decoded. *)
  val mutable rd_mode = `Frame_header 0
  (* true after the decoder has run (set in process). When true, the next
     start_reading first tries to decode frames that are already
     buffered, before reading more data. *)
  val mutable rd_processing = false
  (* The protocol header is only accepted before the first frame *)
  val mutable rd_stream_at_beginning = true

  (* Maximum frame size, including the 7-byte header and frame-end byte *)
  val mutable max_frame_size = lim_frame_size

  method alive = mplex # alive
  method event_system = esys
  method getsockname = sockname
  method getpeername = peername
  method transport_type = `TCP

  (* Fails for sizes below 255; sizes above lim_frame_size are capped *)
  method set_max_frame_size size =
    if size < 255 then
      failwith "Netamqp_transport.set_max_frame_size: too low";
    max_frame_size <- min lim_frame_size size

  method eff_max_frame_size = max_frame_size

  method reading = mplex # reading
  method read_eof = mplex # read_eof
  method writing = mplex # writing

  (* Set by abort_rw: no further read/write callbacks are invoked *)
  val mutable aborted = false

  (* Reads the next frame and passes it to [when_done].
     - Frames already in rd_buffer are returned without new I/O.
     - `End_of_file is reported only if EOF occurs between frames.
     - EOF inside a frame is reported as
       `Error(Error "EOF within message").
   *)
  method start_reading ~when_done () =
    assert(not mplex#reading);

    (* Starts one read on mplex; the received data is decoded by
       [process] *)
    let rec est_reading() =
      let mplex_when_done exn_opt n =
        self # timer_event `Stop `R;
        match exn_opt with
          | None ->
              process ()
          | Some End_of_file ->
              if rd_mode = `Frame_header 0 && Netpagebuffer.length rd_buffer=0
              then
                return_eof()   (* EOF between messages *)
              else
                return_error (Error "EOF within message")
          | Some Uq_engines.Cancelled ->
              ()   (* Ignore *)
          | Some error ->
              return_error error 
      in
      
      rd_processing <- false;
      if mplex#mem_supported then (
        (* Read directly into the memory rd_buffer provides for additions *)
        let (b, start, len) = Netpagebuffer.page_for_additions rd_buffer in
        mplex # start_mem_reading 
          ~when_done:(fun exn_opt n ->
                        dlogr (fun () ->
                                 sprintf "Reading [mem]: %s%s"
                                   (Rpc_util.hex_dump_m b start (min n 200))
                                   (if n > 200 then "..." else "")
                              );
                        Netpagebuffer.advance rd_buffer n;
                        mplex_when_done exn_opt n
                     )
          b
          start
          len
      )
      else (
        (* Read into the string buffer, then append to rd_buffer *)
        mplex # start_reading
          ~when_done:(fun exn_opt n ->
                        dlogr (fun () ->
                                 sprintf "Reading [str]: %s%s"
                                   (Rpc_util.hex_dump_s 
                                      rd_buffer_nomem 0 (min n 200))
                                   (if n > 200 then "..." else "")
                              );
                        Netpagebuffer.add_sub_string 
                          rd_buffer rd_buffer_nomem 0 n;
                        mplex_when_done exn_opt n
                     )
          rd_buffer_nomem
          0
          (String.length rd_buffer_nomem)
      );
      self # timer_event `Start `R

    (* Tries to decode the next frame from rd_buffer. If the buffer does
       not yet contain enough data, reads more via est_reading. *)
    and process () =
      rd_processing <- true;
      let len = Netpagebuffer.length rd_buffer in
      match rd_mode with
        | `Frame_header _ ->
            let n' = min len 7 in
            rd_mode <- `Frame_header n';
            if n' = 7 then (
              (* Decode the header. At the beginning of the stream, a
                 protocol header (8 bytes) may appear instead.
               *)
              try
                let s = Netpagebuffer.sub rd_buffer 0 7 in
                let frame_type =
                  match s.[0] with
                    | '\001' -> `Method
                    | '\002' -> `Header
                    | '\003' -> `Body
                    | '\008' -> `Heartbeat
                    | '\065' when rd_stream_at_beginning -> `Proto_header
                    | _ -> raise(Error "Bad frame header") in
                if frame_type = `Proto_header then (
                  if String.sub s 0 5 = "AMQP\000" then (
                    if len >= 8 then (
                      let p = Netpagebuffer.sub rd_buffer 5 3 in
                      let frame =
                        { frame_type = `Proto_header;
                          frame_channel = 0;
                          frame_payload = [mk_mstring p]
                        } in
                      Netpagebuffer.delete_hd rd_buffer 8;
                      (* The header is consumed: expect a fresh frame *)
                      rd_mode <- `Frame_header 0;
                      raise (Continue (fun () -> return_msg frame))
                    )
                    else raise (Continue est_reading)
                  )
                  else
                    raise(Error "Bad frame header")
                )
                else (
                  let channel =
                    Netamqp_rtypes.read_uint2_unsafe s 1 in
                  let size =
                    Rtypes.read_uint4_unsafe s 3 in
                  (* The payload may use max_frame_size minus the
                     7-byte header and the frame-end byte *)
                  let max_size =
                    Rtypes.uint4_of_int (max_frame_size-8) in
                  if Rtypes.gt_uint4 size max_size then
                    raise(Error "Frame too long");
                  let size =
                    Rtypes.int_of_uint4 size in
                  rd_mode <- `Payload(frame_type, channel, size, 7)
                );
                raise (Continue process)
              with
                | Continue f -> (* call f at tail position *)
                    f()
                | error ->
                    return_error error
            )
            else est_reading()
        | `Payload(frame_type, channel, size, payload_start) ->
            if len >= size+payload_start+1 then (
              let trailer =
                Netpagebuffer.sub rd_buffer (payload_start+size) 1 in
              if trailer = "\xCE" then (
                let data =
                  Netpagebuffer.sub rd_buffer payload_start size in
                let ms =
                  mk_mstring data in
                let frame = 
                  { frame_type = frame_type;
                    frame_channel = channel;
                    frame_payload = [ms]
                  } in
                Netpagebuffer.delete_hd rd_buffer (payload_start+size+1);
                rd_mode <- `Frame_header 0;
                return_msg frame
              )
              else return_error (Error "Bad frame end")
            )
            else est_reading()

    (* rd_processing stays true here: more frames may be buffered *)
    and return_msg msg =
      rd_stream_at_beginning <- false;
      if not aborted then
        when_done (`Ok msg)

    and return_error e =
      rd_processing <- false;
      if not aborted then
        when_done (`Error e)

    and return_eof () =
      rd_processing <- false;
      if not aborted then
        when_done `End_of_file 

    in
    if rd_processing then
      process ()
    else
      est_reading()
            

  (* Writes [frame] and calls [when_done] with `Ok () once all bytes have
     been written, or with `Error e if writing fails.
     Invalid frames raise Error immediately, before anything is written:
     - a `Proto_header frame whose payload is not exactly 3 bytes;
     - a frame whose header, payload and frame-end byte together exceed
       max_frame_size.
   *)
  method start_writing ~when_done frame =

    assert(not mplex#writing);

    (* Items still to be written:
       - `String(s,p,l): We have still to write s[p] to s[p+l-1]
       - `Memory(m,p,l,ms,q): We have still to write
          m[p] to m[p+l-1], followed by ms[q] to end of ms
          (where ms is the managed string)
     *)

    let item_of_mstring ms r =
      (* Create the item for ms, starting at offset r *)
      let l = ms#length in
      assert(r <= l);
      match ms # preferred with
        | `String ->
            let (s,pos) = ms#as_string in (* usually only r=0 *)
            `String(s,pos+r,l-r)
        | `Memory ->
            if mplex#mem_supported then (
              let (m,pos) = ms#as_memory in
              `Memory(m, pos+r, l-r, ms, l)
            )
            else
              let (s,pos) = ms#as_string in
              `String(s,pos+r,l-r) in

    let rec optimize_items items =
      (* Merge adjacent short items (only for strings) to save
         write operations *)
      match items with
        | (`String(s1,p1,l1) as i1) :: (`String(s2,p2,l2) as i2) :: items' ->
            if l1 < 256 && l2 < 256 then (
              let b = Buffer.create (l1+l2) in
              Buffer.add_substring b s1 p1 l1;
              Buffer.add_substring b s2 p2 l2;
              gather_items b items'
            )
            else
              i1 :: optimize_items (i2 :: items')
        | other :: items' ->
            other :: optimize_items items'
        | [] ->
            []

    (* Append further short strings to b, then continue optimizing *)
    and gather_items b items =
      match items with
        | `String(s,p,l) :: items' when l < 256 ->
            Buffer.add_substring b s p l;
            gather_items b items'
        | _ ->
            `String(Buffer.contents b, 0, Buffer.length b) :: 
              optimize_items items in


    let item_is_empty =
      function
        | `String(_,_,l) -> l=0
        | `Memory(_,_,l,ms,q) -> l=0 && ms#length=q in

    let rec est_writing item remaining =
      (* [item] is the current buffer to write; and [remaining] need to be
         written after that
       *)
      let mplex_when_done exn_opt n = (* n bytes written *)
        self # timer_event `Stop `W;
        match exn_opt with
          | None ->
              ( match item with
                  | `Memory(m,p,l,ms,q) ->
                      let l' = l-n in
                      if l' > 0 then
                        est_writing (`Memory(m,p+n,l',ms,q)) remaining
                      else (
                        let mlen = ms#length in
                        if q < mlen then
                          let item' = item_of_mstring ms q in
                          est_writing item' remaining
                        else
                          est_writing_next remaining
                      )
                  | `String(s,p,l) ->
                      let l' = l-n in
                      if l' > 0 then
                        est_writing (`String(s,p+n,l')) remaining
                      else 
                        est_writing_next remaining
              )
          | Some Uq_engines.Cancelled ->
              ()  (* ignore *)
          | Some error ->
              if not aborted then
                when_done (`Error error)
      in

      ( match item with
          | `Memory(m,p,l,_,_) ->
              dlogr (fun () ->
                       sprintf "Writing [mem]: %s%s" 
                         (Rpc_util.hex_dump_m m p (min l 200))
                         (if l > 200 then "..." else "")
                    );
              mplex # start_mem_writing
                ~when_done:mplex_when_done m p l
          | `String(s,p,l) ->
              dlogr (fun () ->
                       sprintf "Writing [str]: %s%s" 
                         (Rpc_util.hex_dump_s s p (min l 200))
                         (if l > 200 then "..." else "")
                    );
              mplex # start_writing
                ~when_done:mplex_when_done s p l
      );
      self # timer_event `Start `W

    (* Write the next non-empty item, or report success when none left *)
    and  est_writing_next remaining =
      match remaining with
        | item :: remaining' ->
            if item_is_empty item then
              est_writing_next remaining'
            else
              est_writing item remaining'
        | [] ->
            if not aborted then
              when_done (`Ok ())
    in

    let write mstrings =
      est_writing_next
        (optimize_items
           (List.map (fun ms -> item_of_mstring ms 0) mstrings)) in

    match frame.frame_type with
      | `Proto_header ->
          let s = Xdr_mstring.concat_mstrings frame.frame_payload in
          if String.length s <> 3 then
            raise(Error "The `Proto_header frame requires a 3-byte payload");
          let u = "AMQP\000" ^ s in
          write [mk_mstring u]
      | _ ->
          (* Create frame header and frame end mstrings: *)
          let l = Xdr_mstring.length_mstrings frame.frame_payload in
          (* The frame adds 8 bytes to the payload (7-byte header and the
             frame-end byte). The whole frame must fit into
             max_frame_size, which is the limit the reader enforces. *)
          if l > max_frame_size - 8 then (
            dlogr 
              (fun () -> sprintf "l=%d max_frame_size=%d" l max_frame_size);
            raise(Error "The frame is too large")
          );
          let s = String.create 7 in
          let c0 = 
            match frame.frame_type with
              | `Method -> '\001'
              | `Header -> '\002'
              | `Body -> '\003'
              | `Heartbeat -> '\008'
              | `Proto_header -> assert false in
          s.[0] <- c0;
          Netamqp_rtypes.write_uint2 s 1 frame.frame_channel;
          Rtypes.write_uint4 s 3 (Rtypes.uint4_of_int l);
          let header = mk_mstring s in
          let trailer = mk_mstring "\xCE" in
          write (header :: (frame.frame_payload @ [trailer]))

  (* Cancels a read in progress; its Cancelled notification is ignored *)
  method cancel_rd_polling () =
    if mplex#reading then
      mplex # cancel_reading()

  (* Cancels reading and writing; read/write callbacks are suppressed
     from now on *)
  method abort_rw () =
    aborted <- true;
    mplex # cancel_reading();
    mplex # cancel_writing()
    
  method start_shutting_down ~when_done () =
    dlogr (fun () ->
             sprintf "start_shutting_down mplex=%d"
               (Oo.id mplex));
    mplex # start_shutting_down
      ~when_done:(fun exn_opt ->
                    dlogr (fun () ->
                             sprintf "done shutting_down mplex=%d"
                               (Oo.id mplex));
                    self # timer_event `Stop `D;
                    match exn_opt with
                      | None -> when_done (`Ok ())
                      | Some error -> when_done (`Error error)
                 )
      ();
    self # timer_event `Start `D

  method cancel_shutting_down () =
    self # timer_event `Stop `D;
    mplex # cancel_shutting_down()

  method inactivate () =
    dlogr (fun () ->
             sprintf "inactivate mplex=%d"
               (Oo.id mplex));
    self # stop_timer();
    mplex # inactivate()

  (* Some(notify, tmo) after set_timeout has been called *)
  val mutable timer = None
  (* Operations currently in progress: read, write, shutdown *)
  val mutable timer_r = `Stop
  val mutable timer_w = `Stop
  val mutable timer_d = `Stop
  (* Unixqueue group of the currently armed timer, if any *)
  val mutable timer_group = None

  method set_timeout ~notify tmo =
    timer <- Some(notify, tmo)

  (* Records that operation [which] started or stopped, then re-arms the
     timer. While at least one operation is in progress, [notify] is
     called if no further start/stop event happens within [tmo] seconds.
   *)
  method private timer_event start_stop which =
    ( match which with
        | `R -> timer_r <- start_stop
        | `W -> timer_w <- start_stop
        | `D -> timer_d <- start_stop
    );
    ( match timer with
        | None -> ()
        | Some(notify, tmo) ->
            (* Only cancel the pending timer here. Do not call stop_timer:
               it would also reset the activity flags just updated, and
               then the check below could never succeed. *)
            self # clear_timer_group();
            if timer_r = `Start || timer_w = `Start || timer_d = `Start then (
              let g = Unixqueue.new_group esys in
              timer_group <- Some g;
              Unixqueue.once esys g tmo
                (fun () -> 
                   timer_group <- None;
                   notify()
                )
            )
    )

  (* Cancels the armed timer, if any, keeping the activity flags *)
  method private clear_timer_group() =
    ( match timer_group with
        | None -> ()
        | Some g -> Unixqueue.clear esys g
    );
    timer_group <- None

  (* Cancels the timer and forgets all operations in progress *)
  method private stop_timer() =
    self # clear_timer_group();
    timer_r <- `Stop;
    timer_w <- `Stop;
    timer_d <- `Stop

end



(* Creates an AMQP multiplex controller for the connected socket [fd],
 * running on the event system [esys].
 * - If the local or peer address cannot be determined (Unix_error), it
 *   is reported as `Implied.
 * - [close_inactive_descr] and [preclose] are passed on to
 *   Uq_engines.create_multiplex_controller_for_connected_socket.
 *)
let tcp_amqp_multiplex_controller ?(close_inactive_descr=true)
                                  ?(preclose=fun()->()) fd esys =
  let address_of query =
    try `Sockaddr (query fd)
    with Unix.Unix_error _ -> `Implied in
  let sockname = address_of Unix.getsockname in
  let peername = address_of Netsys.getpeername in
  let mplex =
    Uq_engines.create_multiplex_controller_for_connected_socket
      ~close_inactive_descr ~preclose fd esys in
  new tcp_amqp_multiplex_controller sockname peername mplex esys
;;

This web site is published by Informatikbüro Gerd Stolpmann
Powered by Caml