(* $Id$
 * ----------------------------------------------------------------------
 * This module is derived from rpc_transport.ml, which is published
 * as part of Ocamlnet, and whose copyright holder is Gerd Stolpmann.
 *)

open Netamqp_types
open Printf

(* Outcome of an asynchronous operation: a value, or the exception that
   made it fail. *)
type 't result =
    [ `Ok of 't
    | `Error of exn
    ]

(* Like [result], with an additional case for a clean end of the input
   stream (used by [start_reading]). *)
type 't result_eof =
    [ 't result
    | `End_of_file
    ]

(* A socket address, or [`Implied] when no address could be determined
   (e.g. getsockname/getpeername failed). *)
type sockaddr =
    [ `Implied
    | `Sockaddr of Unix.sockaddr
    ]

(* Human-readable form of a [sockaddr]; [`Implied] is rendered as
   "<implied>", real addresses via [Netsys.string_of_sockaddr]. *)
let string_of_sockaddr =
  function
    | `Implied -> "<implied>"
    | `Sockaddr sa -> Netsys.string_of_sockaddr sa

(* Hex dump of [len] characters of the string [s] starting at [pos],
   delegated to [Rpc_util.hex_dump_b]. The string is viewed as bytes
   without copying - NOTE(review): this relies on hex_dump_b only
   reading its argument. *)
let hex_dump_s s pos len =
  let b = Bytes.unsafe_of_string s in
  Rpc_util.hex_dump_b b pos len

(* Protocol/framing errors detected by this module, e.g. a bad frame
   header, a too long frame, a bad frame end, or EOF inside a frame.
   Usually delivered to callbacks as [`Error (Error msg)]. *)
exception Error of string

(* Frame-level transport for AMQP: reads and writes whole frames over an
   underlying connection, driven by an event system. *)
class type amqp_multiplex_controller =
object
  method alive : bool
    (* Whether the underlying connection is still usable *)

  method event_system : Unixqueue.event_system
    (* The event system driving the I/O *)

  method getsockname : sockaddr
  method getpeername : sockaddr
    (* Local/remote address, [`Implied] if unknown *)

  method transport_type : transport_type
    (* [`TCP] or [`TLS] *)

  method set_max_frame_size : int -> unit
    (* Sets the maximum frame size; fails for values below 255, and
       larger values are capped by a platform limit *)

  method eff_max_frame_size : int
    (* The maximum frame size currently in effect *)

  method reading : bool
  method read_eof : bool

  method start_reading : 
    when_done:( frame result_eof -> unit) -> unit -> unit
    (* Reads the next frame and passes it (or an error / EOF) to
       [when_done]. Must not be called while a read is pending. *)

  method cancel_rd_polling : unit -> unit
    (* Cancels a pending read, if any *)

  method abort_rw : unit -> unit
    (* Cancels reading and writing; afterwards callbacks are suppressed *)

  method writing : bool

  method start_writing :
    when_done:(unit result -> unit) -> frame -> unit
    (* Writes one frame completely, then calls [when_done]. Must not be
       called while a write is pending. *)

  method start_shutting_down :
    when_done:(unit result -> unit) -> unit -> unit
  method cancel_shutting_down : unit -> unit

  method set_timeout : notify:(unit -> unit) -> float -> unit
    (* [set_timeout ~notify tmo]: [notify] is called when pending I/O
       makes no progress for [tmo] seconds *)

  method inactivate : unit -> unit
    (* Stops timers and inactivates the underlying connection *)

  method tls_session_props : Nettls_support.tls_session_props option
    (* TLS session properties, [None] for plain TCP *)
end

(* Debugging switch for this module (registered with Netlog below). *)
module Debug = struct
  let enable = ref false
end

(* Debug loggers bound to the [Debug.enable] switch: [dlog] takes a
   message string, [dlogr] a thunk that produces the message. *)
let dlog = Debug.enable |> Netlog.Debug.mk_dlog "Netamqp_transport"
let dlogr = Debug.enable |> Netlog.Debug.mk_dlogr "Netamqp_transport"

(* Make the switch known to Netlog under the module name *)
let () =
  Debug.enable |> Netlog.Debug.register_module "Netamqp_transport"

(* Block size of the default Netsys_mem pool; used as the page size of
   the read buffer (bigarray-based I/O). *)
let mem_size =
  Netsys_mem.default_pool |> Netsys_mem.pool_block_size

(* Size of the [Bytes] read buffer used when the multiplex controller
   does not support reading into bigarray memory. *)
let fallback_size = 16 * 1024

(* Allocates a memory block from the default Netsys_mem pool. *)
let mem_alloc =
  fun () -> Netsys_mem.pool_alloc_memory Netsys_mem.default_pool

(* A fresh, empty (length 0) char bigarray, usable as a placeholder for
   memory buffers. *)
let mem_dummy() =
  Bigarray.Array1.create Bigarray.char Bigarray.c_layout 0

(* Wraps the whole byte sequence [s] as a bytes-based managed string.
   NOTE(review): the final [false] is presumably the "must copy" flag,
   i.e. [s] is shared and must not be mutated afterwards - confirm. *)
let mk_mstring s =
  let len = Bytes.length s in
  Netxdr_mstring.bytes_based_mstrings # create_from_bytes s 0 len false

(* [buffer_add_subbytes buf by pos len] appends [len] bytes of [by],
   starting at [pos], to [buf]. Raises [Invalid_argument] for an invalid
   range. Kept as a wrapper for existing callers; the file already needs
   OCaml >= 4.02 (Bytes), so the stdlib function is used directly and
   no temporary copy is made. *)
let buffer_add_subbytes buf by pos len =
  Buffer.add_subbytes buf by pos len

(* Returns a fresh copy of the contents of [buf] as bytes. Kept as a
   wrapper for existing callers; uses the stdlib [Buffer.to_bytes]
   (OCaml >= 4.02) instead of an unsafe string/bytes conversion. *)
let buffer_to_bytes buf =
  Buffer.to_bytes buf

(* Internal control flow: raised inside the frame header decoder of
   [start_reading] to leave its exception handler; the handler then
   calls the carried continuation at tail position. *)
exception Continue of (unit -> unit)

(* Implementation of [amqp_multiplex_controller] over an Ocamlnet
   multiplex controller [mplex] (plain TCP, or TLS when [mplex] has TLS
   session properties). [sockname]/[peername] are reported as-is, and
   [esys] is used for timers.

   Reading: data is appended to [rd_buffer] and cut into AMQP frames:
   a 7-byte header (frame type, 2-byte channel, 4-byte payload size),
   the payload, and the frame-end octet 0xCE. At the very beginning of
   the stream an 8-byte protocol header ("AMQP\000" plus 3 bytes) is
   accepted as well.

   Writing: a frame is turned into a list of byte/memory items that are
   written one after the other.

   NOTE(review): the reviewed copy of this class had missing lines and
   tokens, which were reconstructed here (e.g. [Oo.id],
   [Netpagebuffer.add_subbytes], the try/with around header decoding).
   Confirm against the Ocamlnet version in use. *)
class tcp_amqp_multiplex_controller sockname peername 
        (mplex : Uq_engines.multiplex_controller) esys 
      : amqp_multiplex_controller =
  let () = 
    dlogr (fun () ->
             sprintf "new tcp_amqp_multiplex_controller mplex=%d"
               (Oo.id mplex))
  in
  (* The AMQP size field has 32 bits; on 32-bit platforms we are also
     limited by the maximum string length *)
  let lim_frame_size =
    if Sys.word_size = 64 then
      Int64.to_int 0xffff_ffffL
    else
      Sys.max_string_length in
object(self)
  val mutable rd_buffer = Netpagebuffer.create mem_size
  (* Only used when [mplex] cannot read into bigarray memory *)
  val mutable rd_buffer_nomem = 
    if mplex#mem_supported then Bytes.create 0 else Bytes.create fallback_size

  (* [`Frame_header n]: expecting a frame header, n bytes seen so far;
     [`Payload(ty,ch,size,start)]: header decoded, the payload starts at
     offset [start] of [rd_buffer] *)
  val mutable rd_mode = `Frame_header 0
  (* true when buffered data may still contain unprocessed frames *)
  val mutable rd_processing = false
  (* true until the first frame was delivered; only then a protocol
     header is accepted *)
  val mutable rd_stream_at_beginning = true

  val mutable max_frame_size = lim_frame_size

  method alive = mplex # alive
  method event_system = esys
  method getsockname = sockname
  method getpeername = peername
  method transport_type =
    if mplex#tls_session_props = None then `TCP else `TLS
  method tls_session_props = mplex # tls_session_props

  method set_max_frame_size size =
    if size < 255 then
      failwith "Netamqp_transport.set_max_frame_size: too low";
    max_frame_size <- min lim_frame_size size

  method eff_max_frame_size = max_frame_size

  method reading = mplex # reading
  method read_eof = mplex # read_eof
  method writing = mplex # writing

  (* Set by [abort_rw]; suppresses all further user callbacks *)
  val mutable aborted = false

  method start_reading ~when_done () =
    assert(not mplex#reading);

    (* Starts one read on [mplex]; the data is appended to [rd_buffer]
       and then [process] runs *)
    let rec est_reading() =
      let mplex_when_done exn_opt _n =
        self # timer_event `Stop `R;
        match exn_opt with
          | None ->
              process ()
          | Some End_of_file ->
              if rd_mode = `Frame_header 0 && Netpagebuffer.length rd_buffer=0
              then
                return_eof()   (* EOF between messages *)
              else
                return_error (Error "EOF within message")
          | Some Uq_engines.Cancelled ->
              ()   (* Ignore *)
          | Some error ->
              return_error error 
      in
      rd_processing <- false;
      if mplex#mem_supported then (
        let (b, start, len) = Netpagebuffer.page_for_additions rd_buffer in
        mplex # start_mem_reading 
          ~when_done:(fun exn_opt n ->
                        dlogr (fun () ->
                                 sprintf "Reading [mem]: %s%s"
                                   (Rpc_util.hex_dump_m b start (min n 200))
                                   (if n > 200 then "..." else ""));
                        Netpagebuffer.advance rd_buffer n;
                        mplex_when_done exn_opt n)
          b start len
      )
      else (
        mplex # start_reading
          ~when_done:(fun exn_opt n ->
                        dlogr (fun () ->
                                 sprintf "Reading [str]: %s%s"
                                   (Rpc_util.hex_dump_b
                                      rd_buffer_nomem 0 (min n 200))
                                   (if n > 200 then "..." else ""));
                        Netpagebuffer.add_subbytes
                          rd_buffer rd_buffer_nomem 0 n;
                        mplex_when_done exn_opt n)
          rd_buffer_nomem 0 (Bytes.length rd_buffer_nomem)
      );
      self # timer_event `Start `R

    (* Tries to extract a frame from [rd_buffer]; reads more if needed *)
    and process () =
      rd_processing <- true;
      let len = Netpagebuffer.length rd_buffer in
      match rd_mode with
        | `Frame_header _n ->
            (* _n: we already saw _n bytes *)
            let n' = min len 7 in
            rd_mode <- `Frame_header n';
            if n' = 7 then (
              (* Decode the header. If we are at the beginning of the
                 stream, it is also possible that we see a protocol header
                 (with 8 bytes). Continuations leave the [try] via
                 [Continue], so that [when_done] is never run inside the
                 handler. *)
              try
                let s = Netpagebuffer.sub_bytes rd_buffer 0 7 in
                let frame_type =
                  match Bytes.get s 0 with
                    | '\001' -> `Method
                    | '\002' -> `Header
                    | '\003' -> `Body
                    | '\008' -> `Heartbeat
                    | '\065' when rd_stream_at_beginning -> `Proto_header
                    | _ -> raise(Error "Bad frame header") in
                if frame_type = `Proto_header then (
                  if Bytes.sub_string s 0 5 = "AMQP\000" then (
                    if len >= 8 then (
                      let p = Netpagebuffer.sub_bytes rd_buffer 5 3 in
                      let frame =
                        { frame_type = `Proto_header;
                          frame_channel = 0;
                          frame_payload = [mk_mstring p]
                        } in
                      Netpagebuffer.delete_hd rd_buffer 8;
                      raise (Continue (fun () -> return_msg frame))
                    )
                    else raise (Continue est_reading)
                  )
                  else
                    raise(Error "Bad frame header")
                )
                else (
                  let channel =
                    Netamqp_rtypes.read_uint2_unsafe s 1 in
                  let size =
                    Netnumber.BE.read_uint4_unsafe s 3 in
                  (* 8 = 7 header bytes + 1 frame-end octet *)
                  let max_size =
                    Netnumber.uint4_of_int (max_frame_size-8) in
                  if Netnumber.gt_uint4 size max_size then
                    raise(Error "Frame too long");
                  let size =
                    Netnumber.int_of_uint4 size in
                  rd_mode <- `Payload(frame_type, channel, size, 7)
                );
                raise (Continue process)
              with
                | Continue f -> (* call f at tail position *)
                    f ()
                | error ->
                    return_error error
            )
            else est_reading()
        | `Payload(frame_type, channel, size, payload_start) ->
            if len >= size+payload_start+1 then (
              let trailer =
                Netpagebuffer.sub rd_buffer (payload_start+size) 1 in
              if trailer = "\xCE" then (
                let data =
                  Netpagebuffer.sub_bytes rd_buffer payload_start size in
                let ms =
                  mk_mstring data in
                let frame = 
                  { frame_type = frame_type;
                    frame_channel = channel;
                    frame_payload = [ms]
                  } in
                Netpagebuffer.delete_hd rd_buffer (payload_start+size+1);
                rd_mode <- `Frame_header 0;
                return_msg frame
              )
              else return_error (Error "Bad frame end")
            )
            else est_reading()

    and return_msg msg =
      rd_stream_at_beginning <- false;
      if not aborted then
        when_done (`Ok msg)

    and return_error e =
      rd_processing <- false;
      if not aborted then
        when_done (`Error e)

    and return_eof () =
      rd_processing <- false;
      if not aborted then
        when_done `End_of_file 
    in
    (* After a delivered frame, [rd_buffer] may already contain further
       frames: look at them before reading again *)
    if rd_processing then
      process ()
    else
      est_reading ()

  method start_writing ~when_done frame =

    assert(not mplex#writing);

    (* Items:
       - `Bytes(s,p,l): We have still to write s[p] to s[p+l-1]
       - `Memory(m,p,l,ms,q): We have still to write
          m[p] to m[p+l-1], followed by ms[q] to end of ms
          (where ms is the managed string)
     *)

    let item_of_mstring ms r =
      (* Create the item for ms, starting at offset r *)
      let l = ms#length in
      assert(r <= l);
      match ms # preferred with
        | `Bytes ->
            let (s,pos) = ms#as_bytes in (* usually only r=0 *)
            `Bytes(s,pos+r,l-r)
        | `Memory ->
            if mplex#mem_supported then (
              let (m,pos) = ms#as_memory in
              `Memory(m, pos+r, l-r, ms, l)
            )
            else (
              let (s,pos) = ms#as_bytes in
              `Bytes(s,pos+r,l-r)
            ) in

    let rec optimize_items items =
      (* Merge adjacent short items (only for strings) *)
      match items with
        | (`Bytes(s1,p1,l1) as i1) :: (`Bytes(s2,p2,l2) as i2) :: items' ->
            if l1 < 256 && l2 < 256 then (
              let b = Buffer.create (l1+l2) in
              buffer_add_subbytes b s1 p1 l1;
              buffer_add_subbytes b s2 p2 l2;
              gather_items b items'
            )
            else
              i1 :: optimize_items (i2 :: items')
        | other :: items' ->
            other :: optimize_items items'
        | [] ->
            []

    (* Appends further short byte items to [b], then emits [b] as one item *)
    and gather_items b items =
      match items with
        | `Bytes(s,p,l) :: items' when l < 256 ->
            buffer_add_subbytes b s p l;
            gather_items b items'
        | _ ->
            `Bytes(buffer_to_bytes b, 0, Buffer.length b) :: 
              optimize_items items in

    let item_is_empty = function
      | `Bytes(_,_,l) -> l=0
      | `Memory(_,_,l,ms,q) -> l=0 && ms#length=q in

    let rec est_writing item remaining =
      (* [item] is the current buffer to write; and [remaining] need to be
         written after that *)
      let mplex_when_done exn_opt n = (* n bytes written *)
        self # timer_event `Stop `W;
        match exn_opt with
          | None ->
              ( match item with
                  | `Memory(m,p,l,ms,q) ->
                      let l' = l-n in
                      if l' > 0 then
                        est_writing (`Memory(m,p+n,l',ms,q)) remaining
                      else (
                        let mlen = ms#length in
                        if q < mlen then (
                          let item' = item_of_mstring ms q in
                          est_writing item' remaining
                        )
                        else
                          est_writing_next remaining
                      )
                  | `Bytes(s,p,l) ->
                      let l' = l-n in
                      if l' > 0 then
                        est_writing (`Bytes(s,p+n,l')) remaining
                      else
                        est_writing_next remaining
              )
          | Some Uq_engines.Cancelled ->
              ()  (* ignore *)
          | Some error ->
              if not aborted then
                when_done (`Error error)
      in
      ( match item with
          | `Memory(m,p,l,_,_) ->
              dlogr (fun () ->
                       sprintf "Writing [mem]: %s%s" 
                         (Rpc_util.hex_dump_m m p (min l 200))
                         (if l > 200 then "..." else ""));
              mplex # start_mem_writing
                ~when_done:mplex_when_done m p l
          | `Bytes(s,p,l) ->
              dlogr (fun () ->
                       sprintf "Writing [str]: %s%s" 
                         (Rpc_util.hex_dump_b s p (min l 200))
                         (if l > 200 then "..." else ""));
              mplex # start_writing
                ~when_done:mplex_when_done s p l
      );
      self # timer_event `Start `W

    (* Writes the next non-empty item, or reports success when done *)
    and est_writing_next remaining =
      match remaining with
        | item :: remaining' ->
            if item_is_empty item then
              est_writing_next remaining'
            else
              est_writing item remaining'
        | [] ->
            if not aborted then
              when_done (`Ok ())
    in

    let write mstrings =
      est_writing_next
        (optimize_items
           (List.map (fun ms -> item_of_mstring ms 0) mstrings)) in

    match frame.frame_type with
      | `Proto_header ->
          let s = Netxdr_mstring.concat_mstrings frame.frame_payload in
          if String.length s <> 3 then
            raise(Error "The `Proto_header frame requires a 3-byte payload");
          let u = Bytes.of_string ("AMQP\000" ^ s) in
          write [mk_mstring u]
      | _ ->
          (* Create frame header and frame end mstrings. The whole frame
             (7 header bytes + payload + 1 frame-end octet) must fit into
             max_frame_size - the same limit [start_reading] enforces on
             incoming frames. *)
          let l = Netxdr_mstring.length_mstrings frame.frame_payload in
          if l > max_frame_size - 8 then (
            dlogr
              (fun () -> sprintf "l=%d max_frame_size=%d" l max_frame_size);
            raise(Error "The frame is too large")
          );
          let s = Bytes.create 7 in
          let c0 = 
            match frame.frame_type with
              | `Method -> '\001'
              | `Header -> '\002'
              | `Body -> '\003'
              | `Heartbeat -> '\008'
              | `Proto_header -> assert false in
          Bytes.set s 0 c0;
          Netamqp_rtypes.write_uint2 s 1 frame.frame_channel;
          Netnumber.BE.write_uint4 s 3 (Netnumber.uint4_of_int l);
          let header = mk_mstring s in
          let trailer = mk_mstring (Bytes.of_string "\xCE") in
          write (header :: (frame.frame_payload @ [trailer]))

  method cancel_rd_polling () =
    if mplex#reading then
      mplex # cancel_reading()

  method abort_rw () =
    aborted <- true;
    mplex # cancel_reading();
    mplex # cancel_writing()

  method start_shutting_down ~when_done () =
    dlogr (fun () ->
             sprintf "start_shutting_down mplex=%d"
               (Oo.id mplex));
    mplex # start_shutting_down
      ~when_done:(fun exn_opt ->
                    dlogr (fun () ->
                             sprintf "done shutting_down mplex=%d"
                               (Oo.id mplex));
                    self # timer_event `Stop `D;
                    match exn_opt with
                      | None -> when_done (`Ok ())
                      | Some error -> when_done (`Error error))
      ();
    self # timer_event `Start `D

  method cancel_shutting_down () =
    self # timer_event `Stop `D;
    mplex # cancel_shutting_down()

  method inactivate () =
    dlogr (fun () ->
             sprintf "inactivate mplex=%d"
               (Oo.id mplex));
    self # stop_timer();
    mplex # inactivate()

  val mutable timer : ((unit -> unit) * float) option = None
  (* Which operations (read/write/shutdown) currently want the timer *)
  val mutable timer_r : [ `Start | `Stop ] = `Stop
  val mutable timer_w : [ `Start | `Stop ] = `Stop
  val mutable timer_d : [ `Start | `Stop ] = `Stop
  val mutable timer_group = None

  method set_timeout ~notify tmo =
    timer <- Some(notify, tmo)

  (* Records that operation [which] started/stopped and (re)arms the
     timer while any operation is pending. *)
  method private timer_event (start_stop : [ `Start | `Stop ])
                             (which : [ `R | `W | `D ]) =
    match timer with
      | None -> ()
      | Some(notify, tmo) ->
          ( match which with
              | `R -> timer_r <- start_stop
              | `W -> timer_w <- start_stop
              | `D -> timer_d <- start_stop
          );
          (* Only cancel the pending timer here. [stop_timer] would also
             reset timer_r/w/d to `Stop, so the test below could never
             succeed and the timer would never be armed. *)
          self # clear_timer_group();
          if timer_r = `Start || timer_w = `Start || timer_d = `Start then (
            let g = Unixqueue.new_group esys in
            timer_group <- Some g;
            Unixqueue.once esys g tmo
              (fun () -> 
                 timer_group <- None;
                 notify ())
          )

  (* Cancels the pending timer, if any, keeping the per-operation flags *)
  method private clear_timer_group() =
    ( match timer_group with
        | None -> ()
        | Some g -> Unixqueue.clear esys g
    );
    timer_group <- None

  (* Cancels the pending timer and forgets all pending operations *)
  method private stop_timer() =
    self # clear_timer_group();
    timer_r <- `Stop;
    timer_w <- `Stop;
    timer_d <- `Stop
end

(* [tcp_amqp_multiplex_controller fd esys]: creates an AMQP frame
   controller for the connected socket [fd], driven by [esys].
   - [close_inactive_descr] (default: true) and [preclose] (default: do
     nothing) are passed to the socket multiplex controller.
   - [tls_config]: when [Some(config, host_opt)], TLS is run on top of
     the socket in the client role, with [host_opt] as peer name.
   Socket and peer address are [`Implied] if they cannot be determined.
   NOTE(review): the parameter list and the Uq_multiplex calls were
   reconstructed from a damaged copy; confirm against the .mli and the
   Ocamlnet version in use. *)
let tcp_amqp_multiplex_controller ?(close_inactive_descr=true)
                                  ?(preclose = fun () -> ())
                                  ?tls_config
                                  fd esys =
  let sockname = 
    try
      `Sockaddr(Unix.getsockname fd) 
    with
      | Unix.Unix_error(_,_,_) -> `Implied in
  let peername = 
    try
      `Sockaddr(Netsys.getpeername fd)
    with
      | Unix.Unix_error(_,_,_) -> `Implied in
  let mplex1 = 
    Uq_multiplex.create_multiplex_controller_for_connected_socket
      ~close_inactive_descr ~preclose
      fd esys in
  let mplex =
    match tls_config with
      | None -> mplex1
      | Some(c, host_opt) ->
          Uq_multiplex.tls_multiplex_controller
            ~role:`Client ~peer_name:host_opt
            c mplex1 in
  new tcp_amqp_multiplex_controller sockname peername mplex esys

