(* $Id: netchannels.ml 1717 2012-02-20 17:31:58Z gerd $
 * ----------------------------------------------------------------------
 *
 *)

(* Raised when an operation is attempted on a channel that has already
 * been closed. *)
exception Closed_channel

(* Raised by [input_netbuffer] when no data is buffered but the end of the
 * stream has not yet been signalled via [shutdown]. *)
exception Buffer_underrun

(* Raised by [input_command] / [output_command] when closing the channel
 * shows that the child process did not end with exit code 0. *)
exception Command_failure of Unix.process_status

(* Register a printer so that Netexn can render [Command_failure]. *)
let () =
  Netexn.register_printer
    (Command_failure(Unix.WEXITED 0))
    (fun e ->
       match e with
         | Command_failure ps ->
             let ps_str =
               match ps with
                 | Unix.WEXITED n -> "WEXITED " ^ string_of_int n
                 | Unix.WSIGNALED n -> "WSIGNALED " ^ string_of_int n
                 | Unix.WSTOPPED n -> "WSTOPPED " ^ string_of_int n in
             "Netchannels.Command_failure(" ^ ps_str ^ ")"
         | _ ->
             assert false
    )

let () =
  Netsys_signal.init()

class type rec_in_channel =
object
  method input : string -> int -> int -> int
  method close_in : unit -> unit
end

class type raw_in_channel =
object
  inherit rec_in_channel
  method pos_in : int  (* number of read characters *)
end

type input_result =
  [ `Data of int
  | `Separator of string
  ]

class type enhanced_raw_in_channel =
object
  inherit raw_in_channel
  method private enhanced_input_line : unit -> string
  method private enhanced_input : string -> int -> int -> input_result
end

class type rec_out_channel =
object
  method output : string -> int -> int -> int
  method close_out : unit -> unit
  method flush : unit -> unit
end

class type raw_out_channel =
object
  inherit rec_out_channel
  method pos_out : int  (* number of written characters *)
end

class type raw_io_channel =
object
  inherit raw_in_channel
  inherit raw_out_channel
end

class type compl_in_channel =
object
  (* Classic operations: *)
  method really_input : string -> int -> int -> unit
  method input_char : unit -> char
  method input_line : unit -> string
  method input_byte : unit -> int
end

class type in_obj_channel =
object
  inherit raw_in_channel
  inherit compl_in_channel
end

class type compl_out_channel =
object
  (* Classic operations: *)
  method really_output : string -> int -> int -> unit
  method output_char : char -> unit
  method output_string : string -> unit
  method output_byte : int -> unit
  method output_buffer : Buffer.t -> unit
  method output_channel : ?len:int -> in_obj_channel -> unit
    (* ~len: optionally limit the number of bytes *)
end

class type out_obj_channel =
object
  inherit raw_out_channel
  inherit compl_out_channel
end

class type io_obj_channel =
object
  inherit in_obj_channel
  inherit out_obj_channel
end

class type trans_out_obj_channel =
object
  inherit out_obj_channel
  method commit_work : unit -> unit
  method rollback_work : unit -> unit
end
;;

(* error_behavior: currently not used. This was a proposal to control
 * error handling, but it is not clear whether it is really
 * useful or not.
 * I do not delete these types because they remind us of this
 * possibility. Maybe we find an outstanding example for them, and
 * want to have them back.
 *)
type error_behavior = [ `Close | `Fun of (unit -> unit) | `None ]
type extended_error_behavior = [ `Close | `Rollback | `Fun of (unit -> unit) | `None ]
type close_mode = [ `Commit | `Rollback ];;

(* Delegation: every method is forwarded to [ch]. With [~close:false] the
 * close method does nothing, so the wrapped channel stays open. *)

class rec_in_channel_delegation ?(close=true) (ch:rec_in_channel) =
object(self)
  method input = ch#input
  method close_in() = if close then ch#close_in()
end

class raw_in_channel_delegation ?(close=true) (ch:raw_in_channel) =
object(self)
  method input = ch#input
  method close_in() = if close then ch#close_in()
  method pos_in = ch#pos_in
end

class in_obj_channel_delegation ?(close=true) (ch:in_obj_channel) =
object(self)
  method input = ch#input
  method close_in() = if close then ch#close_in()
  method pos_in = ch#pos_in
  method really_input = ch#really_input
  method input_char = ch#input_char
  method input_line = ch#input_line
  method input_byte = ch#input_byte
end

class rec_out_channel_delegation ?(close=true) (ch:rec_out_channel) =
object(self)
  method output = ch#output
  method close_out() = if close then ch#close_out()
  method flush = ch#flush
end

class raw_out_channel_delegation ?(close=true) (ch:raw_out_channel) =
object(self)
  method output = ch#output
  method close_out() = if close then ch#close_out()
  method flush = ch#flush
  method pos_out = ch#pos_out
end

class out_obj_channel_delegation ?(close=true) (ch:out_obj_channel) =
object(self)
  method output = ch#output
  method close_out() = if close then ch#close_out()
  method flush = ch#flush
  method pos_out = ch#pos_out
  method really_output = ch#really_output
  method output_char = ch#output_char
  method output_string = ch#output_string
  method output_byte = ch#output_byte
  method output_buffer = ch#output_buffer
  method output_channel = ch#output_channel
end

(****************************** input ******************************)

(* Object wrapper around a Pervasives in_channel [ch]. [onclose] is called
 * once [close_in] has closed [ch]. *)
class input_channel ?(onclose=fun () -> ()) ch (* : in_obj_channel *) =
object (self)
  val ch = ch
  val mutable closed = false

  method private complain_closed() =
    raise Closed_channel

  (* Returns 0 for [len = 0] or when the read would block
   * (Sys_blocked_io); raises End_of_file at end of input. *)
  method input buf pos len =
    if closed then self # complain_closed();
    try
      if len=0 then raise Sys_blocked_io;
      let n = Pervasives.input ch buf pos len in
      if n=0 then raise End_of_file else n
    with
      | Sys_blocked_io -> 0

  method really_input buf pos len =
    if closed then self # complain_closed();
    Pervasives.really_input ch buf pos len

  method input_char () =
    if closed then self # complain_closed();
    Pervasives.input_char ch

  method input_line () =
    if closed then self # complain_closed();
    Pervasives.input_line ch

  method input_byte () =
    if closed then self # complain_closed();
    Pervasives.input_byte ch

  method close_in () =
    if not closed then (
      Pervasives.close_in ch;
      closed <- true;
      onclose()
    )

  method pos_in =
    if closed then self # complain_closed();
    Pervasives.pos_in ch
end
;;

(* Reads the standard output of the shell command [cmd]. [close_in] raises
 * [Command_failure] if the command did not exit with code 0. *)
class input_command cmd =
  let ch = Unix.open_process_in cmd in
object (self)
  inherit input_channel ch

  method close_in() =
    if not closed then (
      let p = Unix.close_process_in ch in
      closed <- true;
      if p <> Unix.WEXITED 0 then raise (Command_failure p);
    )
end
;;

(* Reads from the window [pos, pos+len) of the string [s]
 * (len defaults to the rest of [s]). *)
class input_string ?(pos = 0) ?len s : in_obj_channel =
object (self)
  val mutable str = s
  (* Exclusive end of the readable window *)
  val mutable str_len =
    match len with
      | None -> String.length s
      | Some l -> pos + l
  val mutable str_pos = pos
  val mutable closed = false

  initializer
    (* The window must lie inside [s]. Testing [str_len < str_pos] also
     * rejects a negative [len], which would otherwise lead to a
     * negative length in [String.blit] in [input]. *)
    if str_pos < 0 || str_pos > String.length str ||
       str_len < str_pos || str_len > String.length s
    then
      invalid_arg "new Netchannels.input_string"

  method private complain_closed() =
    raise Closed_channel

  method input buf pos len =
    if closed then self # complain_closed();
    if pos < 0 || len < 0 || pos+len > String.length buf then
      invalid_arg "input";
    let n = min len (str_len - str_pos) in
    String.blit str str_pos buf pos n;
    str_pos <- str_pos + n;
    if n=0 && len>0 then raise End_of_file else n

  method really_input buf pos len =
    if closed then self # complain_closed();
    if pos < 0 || len < 0 || pos+len > String.length buf then
      invalid_arg "really_input";
    let n = self # input buf pos len in
    if n <> len then raise End_of_file

  method input_char() =
    if closed then self # complain_closed();
    if str_pos >= str_len then raise End_of_file;
    let c = str.[ str_pos ] in
    str_pos <- str_pos + 1;
    c

  method input_line() =
    if closed then self # complain_closed();
    try
      let k = String.index_from str str_pos '\n' in
      (* A newline at or beyond the window end [str_len] does not belong
       * to this channel; handle it like a missing newline. *)
      if k >= str_len then raise Not_found;
      (* CHECK: Are the different end of line conventions important here?
*)
      let line_len = k - str_pos in
      let line = String.sub str str_pos line_len in
      str_pos <- k + 1;
      line
    with
      | Not_found ->
          (* No newline left in the window: whatever remains is handed
           * out as the final line, as if a linefeed followed it. *)
          if str_pos < str_len then begin
            let rest = String.sub str str_pos (str_len - str_pos) in
            str_pos <- str_len;
            rest
          end
          else raise End_of_file

  method input_byte() =
    let c = self # input_char() in
    Char.code c

  method close_in() =
    closed <- true;
    str <- ""

  method pos_in =
    if closed then self # complain_closed();
    str_pos
end
;;

(* An [in_obj_channel] that can additionally be told via [shutdown] that
 * no more data will arrive. *)
class type nb_in_obj_channel =
object
  inherit in_obj_channel
  method shutdown : unit -> unit
end

(* Reads from the netbuffer [b]. Bytes handed out are deleted from [b].
 * While [shutdown] has not been called, an empty buffer yields
 * [Buffer_underrun]; afterwards it yields [End_of_file]. *)
class input_netbuffer b : nb_in_obj_channel =
object (self)
  val mutable b = b
  val mutable eof = false
  val mutable closed = false
  val mutable ch_pos = 0

  method private complain_closed() =
    raise Closed_channel

  method input buf pos len =
    if closed then self # complain_closed();
    if pos < 0 || len < 0 || pos+len > String.length buf then
      invalid_arg "input";
    let avail = Netbuffer.length b in
    let n = if len < avail then len else avail in
    if n > 0 || len = 0 then begin
      Netbuffer.blit b 0 buf pos n;
      Netbuffer.delete b 0 n;
      ch_pos <- ch_pos + n;
      n
    end
    else if eof then raise End_of_file
    else raise Buffer_underrun

  method really_input buf pos len =
    if closed then self # complain_closed();
    if pos < 0 || len < 0 || pos+len > String.length buf then
      invalid_arg "really_input";
    if self # input buf pos len <> len then raise End_of_file

  method input_char() =
    if closed then self # complain_closed();
    let one = String.create 1 in
    match self # input one 0 1 with
      | 1 -> one.[0]
      | _ -> assert false

  method input_line() =
    if closed then self # complain_closed();
    try
      let k = Netbuffer.index_from b 0 '\n' in
      (* CHECK: Are the different end of line conventions important here?
*)
      let line = Netbuffer.sub b 0 k in
      Netbuffer.delete b 0 (k+1);
      ch_pos <- ch_pos + k + 1;
      line
    with
      | Not_found ->
          if eof then begin
            if Netbuffer.length b = 0 then raise End_of_file;
            (* Implicitly add linefeed at the end of the file: *)
            let line = Netbuffer.contents b in
            Netbuffer.clear b;
            (* Advance by the bytes just consumed. Using the buffer
             * length here would always add 0, because the buffer has
             * already been cleared. *)
            ch_pos <- ch_pos + String.length line;
            line
          end
          else raise Buffer_underrun

  method input_byte() =
    Char.code (self # input_char())

  method close_in() =
    closed <- true

  (* Number of bytes this channel has consumed from the netbuffer. *)
  method pos_in =
    if closed then self # complain_closed();
    ch_pos

  (* Declares that no more data will be appended to the netbuffer. *)
  method shutdown() =
    eof <- true
end
;;

(* Returns the channel reading from [b] together with its [shutdown]
 * function. *)
let create_input_netbuffer b =
  let ch = new input_netbuffer b in
  (ch :> in_obj_channel), (ch # shutdown)
;;

(* Builds a lexbuf pulling data from [objch]. End_of_file maps to an empty
 * read; a read returning 0 bytes is reported as a failure. *)
let lexbuf_of_in_obj_channel (objch : in_obj_channel) : Lexing.lexbuf =
  let fill_buffer buf len =
    try
      let n = objch # input buf 0 len in
      if n=0 then
        failwith "Netchannels.lexbuf_of_in_obj_channel: No data (non-blocking I/O?)";
      n
    with
      | End_of_file -> 0
  in
  Lexing.from_function fill_buffer
;;

(* Reads [objch] until End_of_file and returns everything as one string. *)
let string_of_in_obj_channel (objch : in_obj_channel) : string =
  (* There are similarities to copy_channel below. *)
  (* The following algorithm uses only up to 2 * N memory, not 3 * N
   * as with the Buffer module.
*)
  let chunk_size = 1024 in
  let chunks = ref [] in   (* newest chunk first *)
  let total = ref 0 in
  ( try
      while true do
        let chunk = String.create chunk_size in
        let got = objch # input chunk 0 chunk_size in
        if got = 0 then
          failwith "Netchannels.string_of_in_obj_channel: No data (non-blocking I/O?)";
        total := !total + got;
        let piece =
          if got < chunk_size then String.sub chunk 0 got else chunk in
        chunks := piece :: !chunks
      done
    with
      | End_of_file -> ()
  );
  (* Fill the result from the back, because the newest piece comes first. *)
  let result = String.create !total in
  let front =
    List.fold_left
      (fun stop piece ->
         let piece_len = String.length piece in
         let start = stop - piece_len in
         String.blit piece 0 result start piece_len;
         start)
      !total
      !chunks in
  assert (front = 0);
  result
;;

(* Reads lines with [input_line] until End_of_file; returns them in order. *)
let lines_of_in_obj_channel ch =
  let collected = ref [] in
  ( try
      while true do
        let line = ch # input_line () in
        collected := line :: !collected
      done
    with
      | End_of_file -> ()
  );
  List.rev !collected
;;

(* Applies [f] to [ch] and closes [ch] afterwards, also when [f] raises.
 * Closed_channel raised by the close is ignored. *)
let with_in_obj_channel ch f =
  let close_quietly () =
    try ch # close_in () with Closed_channel -> () in
  try
    let result = f ch in
    close_quietly ();
    result
  with
    | error ->
        close_quietly ();
        raise error
;;

(* Derives the convenience input methods from [input]. A zero-byte [input]
 * result is reported as Sys_blocked_io. *)
class virtual augment_raw_in_channel =
object (self)
  method virtual input : string -> int -> int -> int
  method virtual close_in : unit -> unit
  method virtual pos_in : int

  method really_input s pos len =
    let filled = ref 0 in
    while !filled < len do
      let m = self # input s (pos + !filled) (len - !filled) in
      if m = 0 then raise Sys_blocked_io;
      filled := !filled + m
    done

  method input_char () =
    let cell = String.create 1 in
    self # really_input cell 0 1;
    String.get cell 0

  method input_byte () =
    let cell = String.create 1 in
    self # really_input cell 0 1;
    Char.code (String.get cell 0)

  (* Reads byte by byte up to a newline; End_of_file after at least one
   * byte terminates the line, End_of_file on the first byte propagates. *)
  method input_line () =
    let cell = String.create 1 in
    let line = Buffer.create 80 in
    (* Reads one byte into [cell]; [false] at end of file. *)
    let next_byte () =
      try
        if self # input cell 0 1 = 0 then raise Sys_blocked_io;
        true
      with
        | End_of_file -> false in
    if self # input cell 0 1 = 0 then raise Sys_blocked_io;
    let more = ref (cell.[0] <> '\n') in
    while !more do
      Buffer.add_char line cell.[0];
      more := next_byte () && cell.[0] <> '\n'
    done;
    Buffer.contents line
end
;;

(* Lifts a raw_in_channel-like object [r] to the full input interface. *)
class lift_raw_in_channel r =
object(self)
  inherit augment_raw_in_channel
  method input buf start count = r # input buf start count
  method close_in () = r # close_in()
  method pos_in = r # pos_in
end;;

(* Lifts [r] to the full input interface, counting the position itself
 * starting at [start_pos_in]. *)
class lift_rec_in_channel ?(start_pos_in = 0) (r : rec_in_channel) =
object(self)
  inherit augment_raw_in_channel
  val mutable closed = false
  val mutable pos_in = start_pos_in

  method input buf start count =
    if closed then raise Closed_channel;
    let got = r # input buf start count in
    pos_in <- pos_in + got;
    got

  method close_in () =
    if not closed then begin
      closed <- true;
      r # close_in()
    end

  method pos_in =
    if closed then raise Closed_channel;
    pos_in
end;;

type eol_status =
    EOL_not_found
  | EOL_partially_found of int        (* Position *)
  | EOL_found of int * int            (* Position, length *)

exception Pass_through

(* Buffers [ch]. Requests of at least [pass_through] bytes on an empty
 * buffer bypass the buffer. [eol] lists the accepted line delimiters. *)
class buffered_raw_in_channel
        ?(eol = [ "\n" ]) ?(buffer_size = 4096) ?(pass_through = max_int)
        (ch : raw_in_channel) : enhanced_raw_in_channel =
object (self)
  val out = ch
  val bufsize = buffer_size
  val buf = String.create buffer_size
  (* Bytes between bufpos and buflen have been read from [ch] but not yet
   * handed out. *)
  val mutable bufpos = 0
  val mutable buflen = 0
  val mutable eof = false
  val mutable closed = false

  initializer
    let bad_eol s = s = "" || String.length s > buffer_size in
    if List.exists bad_eol eol then
      invalid_arg "Netchannels.buffered_raw_in_channel"

  method input s pos len =
    if closed then raise Closed_channel;
    if len <= 0 then 0
    else if bufpos = buflen && len >= pass_through then
      ch # input s pos len
    else begin
      if bufpos = buflen then self # refill();
      let n = min len (buflen - bufpos) in
      String.blit buf bufpos s pos n;
      bufpos <- bufpos + n;
      n
    end

  (* Moves the unconsumed bytes to the front of [buf] and appends data
   * from [ch]. Sets [eof] and re-raises End_of_file at end of input. *)
  method private refill() =
    let consumed = bufpos in
    if consumed > 0 && consumed < buflen then
      String.blit buf consumed buf 0 (buflen - consumed);
    bufpos <- 0;
    buflen <- buflen - consumed;
    try
      assert(bufsize > buflen);   (* otherwise problems... *)
      let n = ch # input buf buflen (bufsize - buflen) in  (* or End_of_file *)
      if n = 0 then raise Sys_blocked_io;
      buflen <- buflen + n
    with
      | End_of_file as exn ->
          eof <- true;
          raise exn

  method close_in () =
    if not closed then begin
      ch # close_in();
      closed <- true
    end

  method pos_in =
    let pending = buflen - bufpos in
    ch # pos_in - pending

  method private find_eol() =
    (* Try all strings from [eol] in turn.
For every string we may
     * have three results:
     *  - Not found
     *  - Partially found: the candidate runs into the end of the buffered
     *    data, so more data is needed to decide
     *  - Found
     * The candidate starting at the lowest buffer position wins. At the
     * same position a partial match beats a full match (it is necessarily
     * longer, so we must wait for more data), and among full matches the
     * longest string is taken.
     *)
    let find_this_eol eol =
      (* Find the first position in [bufpos] to [buflen] where [eol]
       * starts, or where a prefix of [eol] runs into [buflen]. Every
       * occurrence of the first character is examined, not only the
       * first one, otherwise a stray first character hides a later
       * delimiter. Return [eol_status].
       *)
      let eol0 = eol.[0] in
      let eol_len = String.length eol in
      let rec search from =
        let k = String.index_from buf from eol0 in   (* or Not_found *)
        if k >= buflen then raise Not_found;
        let k' = min buflen (k + eol_len) in
        let s = String.sub buf k (k' - k) in
        if s = eol then
          EOL_found(k, eol_len)
        else if not eof && String.sub eol 0 (String.length s) = s then
          EOL_partially_found k
        else
          search (k+1)
      in
      try search bufpos
      with Not_found -> EOL_not_found
    in
    let position = function
      | EOL_not_found -> max_int
      | EOL_partially_found p -> p
      | EOL_found(p, _) -> p in
    (* Chooses the better of two candidates, see the rules above. *)
    let prefer best r =
      match best, r with
        | _, EOL_not_found -> best
        | EOL_not_found, _ -> r
        | _ ->
            let pb = position best and pr = position r in
            if pr < pb then r
            else if pr > pb then best
            else (
              match best, r with
                | EOL_partially_found _, _ -> best
                | _, EOL_partially_found _ -> r
                | EOL_found(_, lb), EOL_found(_, lr) ->
                    if lr > lb then r else best
                | _ -> best
            )
    in
    List.fold_left prefer EOL_not_found (List.map find_this_eol eol)

  (* Returns either data up to the next delimiter or the delimiter
   * itself as `Separator. *)
  method private enhanced_input s pos len : input_result =
    if closed then raise Closed_channel;
    if len > 0 then (
      if bufpos = buflen then (
        self # refill()   (* may raise End_of_file *)
      );
      let result = ref None in
      while !result = None do
        let best = self # find_eol() in
        match best with
          | EOL_not_found ->
              let n = min len (buflen - bufpos) in
              String.blit buf bufpos s pos n;
              bufpos <- bufpos + n;
              result := Some(`Data n)
          | EOL_found(p,l) ->
              if p = bufpos then (
                bufpos <- bufpos + l;
                result := Some(`Separator(String.sub buf p l))
              )
              else (
                let n = min len (p - bufpos) in
                String.blit buf bufpos s pos n;
                bufpos <- bufpos + n;
                result := Some(`Data n)
              )
          | EOL_partially_found p ->
              if p = bufpos then (
                try self # refill()
                with End_of_file -> ()   (* ... and continue! *)
              )
              else (
                let n = min len (p - bufpos) in
                String.blit buf bufpos s pos n;
                bufpos <- bufpos + n;
                result := Some(`Data n)
              )
      done;
      match !result with
        | None -> assert false
        | Some r -> r
    )
    else
      `Data 0

  (* Returns the next line without its delimiter. End_of_file terminates
   * a non-empty line. *)
  method private enhanced_input_line() =
    if closed then raise Closed_channel;
    let b = Buffer.create 80 in
    let eol_found = ref false in
    if bufpos = buflen then (
      self # refill()   (* may raise End_of_file *)
    );
    while not !eol_found do
      let best = self # find_eol() in
      try
        ( match best with
            | EOL_not_found ->
                Buffer.add_substring b buf bufpos (buflen-bufpos);
                bufpos <- buflen;
                self # refill()   (* may raise End_of_file *)
            | EOL_partially_found pos ->
                Buffer.add_substring b buf bufpos (pos-bufpos);
                bufpos <- pos;
                self # refill()   (* may raise End_of_file *)
            | EOL_found(pos,len) ->
                Buffer.add_substring b buf bufpos (pos-bufpos);
                bufpos <- pos+len;
                eol_found := true
        )
      with
        | End_of_file ->
            (* [refill] moved the unconsumed bytes to the front of [buf]
             * before failing. After a partial match these are the start
             * of a delimiter that will never be completed, so they are
             * line data and must not be dropped. After EOL_not_found
             * nothing is left here. *)
            Buffer.add_substring b buf bufpos (buflen-bufpos);
            bufpos <- 0;
            buflen <- 0;
            eof <- true;
            eol_found := true
    done;
    Buffer.contents b
end
;;

(* Buffered version of [lift_raw_in_channel]; [input_line] honours [eol]. *)
class lift_raw_in_channel_buf ?eol ?buffer_size ?pass_through r =
object(self)
  inherit buffered_raw_in_channel ?eol ?buffer_size ?pass_through r
  inherit augment_raw_in_channel
  method input_line () = self # enhanced_input_line()
end;;

type lift_in_arg = [ `Rec of rec_in_channel | `Raw of raw_in_channel ]

(* Lifts [x] to an in_obj_channel, buffered unless [~buffered:false].
 * The unbuffered variants only support the default delimiter and raise
 * Invalid_argument for any other [eol]. *)
let lift_in ?(eol = ["\n"]) ?(buffered=true) ?buffer_size ?pass_through
            (x : lift_in_arg) =
  let check_unbuffered_eol () =
    if eol <> ["\n"] then invalid_arg "Netchannels.lift_in" in
  match x with
    | `Rec r ->
        if buffered then (
          let r' = new lift_rec_in_channel r in
          new lift_raw_in_channel_buf
            ~eol ?buffer_size ?pass_through (r' :> raw_in_channel)
        )
        else (
          check_unbuffered_eol();
          new lift_rec_in_channel r
        )
    | `Raw r ->
        if buffered then (
          new lift_raw_in_channel_buf ~eol ?buffer_size ?pass_through r
        )
        else (
          check_unbuffered_eol();
          new lift_raw_in_channel r
        )
;;

(****************************** output ******************************)

exception No_end_of_file

(* Copies contents from [src_ch] to [dest_ch] (at most [len] bytes if
 * given). Returns [true] if EOF of [src_ch] was reached, and [false] if
 * the [len] limit stopped the copy first. *)
let copy_channel
      ?(buf = String.create 1024)
      ?len
      (src_ch : in_obj_channel)
      (dest_ch : out_obj_channel) =
  let slen = String.length buf in
  let k = ref 0 in
  try
    while true do
      let m = min slen (match len with Some x -> x - !k | None -> max_int) in
      if m <= 0 then raise No_end_of_file;
      let n = src_ch # input buf 0 m in
      if n = 0 then raise Sys_blocked_io;
      dest_ch # really_output buf 0 n;
      k := !k + n
    done;
    assert false
  with
    | End_of_file -> true
    | No_end_of_file -> false
;;

(* Object wrapper around a Pervasives out_channel [ch]. [errflag] records
 * whether the most recent operation raised, so that [close_out] can
 * avoid raising a second error. *)
class output_channel ?(onclose = fun () -> ()) ch (* : out_obj_channel *) =
  let errflag = ref false in
  let monitored f arg =
    try
      let r = f arg in
      errflag := false;
      r
    with
      | error ->
          errflag := true;
          raise error
  in
object (self)
  val ch = ch
  val onclose = onclose
  val mutable closed = false

  method private complain_closed() =
    raise Closed_channel

  method output buf pos len =
    if closed then self # complain_closed();
    (* Pervasives.output does not support non-blocking I/O directly.
* Work around it: *)
    let p0 = Pervasives.pos_out ch in
    try
      Pervasives.output ch buf pos len;
      errflag := false;
      len
    with
      | Sys_blocked_io ->
          let p1 = Pervasives.pos_out ch in
          errflag := false;
          p1 - p0
      | error ->
          errflag := true;
          raise error

  method really_output buf pos len =
    if closed then self # complain_closed();
    monitored (Pervasives.output ch buf pos) len

  method output_char c =
    if closed then self # complain_closed();
    monitored (Pervasives.output_char ch) c

  method output_string s =
    if closed then self # complain_closed();
    monitored (Pervasives.output_string ch) s

  method output_byte b =
    if closed then self # complain_closed();
    monitored (Pervasives.output_byte ch) b

  method output_buffer b =
    if closed then self # complain_closed();
    monitored (Buffer.output_buffer ch) b

  method output_channel ?len ch =
    if closed then self # complain_closed();
    ignore
      (monitored
         (copy_channel ?len ch) (self : #out_obj_channel :> out_obj_channel))

  method flush() =
    if closed then self # complain_closed();
    monitored Pervasives.flush ch

  method close_out() =
    if not closed then (
      ( try
          (* if !errflag is set, we know that the immediately preceding
           * operation raised an exception, and we are now likely in the
           * exception handler *)
          if !errflag then
            Pervasives.close_out_noerr ch
          else
            Pervasives.close_out ch;
          closed <- true
        with
          | error ->
              let bt = Printexc.get_backtrace() in
              Netlog.logf `Err
                "Netchannels.output_channel: \ Suppressed error in close_out: %s - backtrace: %s"
                (Netexn.to_string error) bt;
              Pervasives.close_out_noerr ch;
              closed <- true
      );
      onclose()
    )

  method pos_out =
    if closed then self # complain_closed();
    Pervasives.pos_out ch
end
;;

(* Writes to the standard input of the shell command [cmd]. [close_out]
 * raises [Command_failure] if the command did not exit with code 0. *)
class output_command ?onclose cmd =
  let ch = Unix.open_process_out cmd in
object (self)
  inherit output_channel ?onclose ch

  method close_out() =
    if not closed then (
      let p = Unix.close_process_out ch in
      closed <- true;
      onclose();
      if p <> Unix.WEXITED 0 then raise (Command_failure p);
      (* Keep this *)
    )
end
;;

(* Appends all output to the Buffer.t [buffer]; [pos_out] is the current
 * buffer length. *)
class output_buffer ?(onclose = fun () -> ()) buffer : out_obj_channel =
object(self)
  val buffer = buffer
  val onclose = onclose
  val mutable closed = false

  method private complain_closed() =
    raise Closed_channel

  method output buf pos len =
    if closed then self # complain_closed();
    Buffer.add_substring buffer buf pos len;
    len

  method really_output buf pos len =
    if closed then self # complain_closed();
    Buffer.add_substring buffer buf pos len

  method output_char c =
    if closed then self # complain_closed();
    Buffer.add_char buffer c

  method output_string s =
    if closed then self # complain_closed();
    Buffer.add_string buffer s

  (* Like Pervasives.output_byte, the integer is taken modulo 256, so
   * that out-of-range values do not make Char.chr raise. *)
  method output_byte b =
    if closed then self # complain_closed();
    Buffer.add_char buffer (Char.chr (b land 0xff))

  method output_buffer b =
    if closed then self # complain_closed();
    Buffer.add_buffer buffer b

  method output_channel ?len ch =
    if closed then self # complain_closed();
    ignore (copy_channel ?len ch (self : #out_obj_channel :> out_obj_channel))

  method flush() =
    if closed then self # complain_closed()

  method close_out() =
    if not closed then (
      closed <- true;
      onclose()
    )

  method pos_out =
    if closed then self # complain_closed();
    Buffer.length buffer
end
;;

(* Appends all output to the netbuffer [buffer]. *)
class output_netbuffer ?(onclose = fun () -> ()) buffer : out_obj_channel =
object(self)
  val buffer = buffer
  val onclose = onclose
  val mutable closed = false
  val mutable ch_pos = 0

  method private complain_closed() =
    raise Closed_channel

  method output buf pos len =
    if closed then self # complain_closed();
    Netbuffer.add_sub_string buffer buf pos len;
    ch_pos <- ch_pos + len;
    len

  method really_output buf pos len =
    if closed then self # complain_closed();
    Netbuffer.add_sub_string buffer buf pos len;
    ch_pos <- ch_pos + len

  method output_char c =
    if closed then self # complain_closed();
    Netbuffer.add_string buffer (String.make 1 c);
    ch_pos <- ch_pos + 1

  method output_string s =
    if closed then self # complain_closed();
    Netbuffer.add_string buffer s;
    ch_pos <- ch_pos + String.length s

  (* Like Pervasives.output_byte, the integer is taken modulo 256. *)
  method output_byte b =
    if closed then self # complain_closed();
    Netbuffer.add_string buffer (String.make 1 (Char.chr (b land 0xff)));
    ch_pos <- ch_pos + 1

  method output_buffer b =
    if closed then self # complain_closed();
    Netbuffer.add_string buffer (Buffer.contents b);
    ch_pos <- ch_pos + Buffer.length b

  method output_channel ?len ch =
    if closed then self # complain_closed();
    ignore (copy_channel ?len ch (self : #out_obj_channel :> out_obj_channel))

  method flush() =
    if closed then self # complain_closed()

  method close_out() =
    if not closed then (
      closed <- true;
      onclose()
    )

  method pos_out =
    if closed then self # complain_closed();
    ch_pos
    (* We cannot return Netbuffer.length b as [pos_out] (like in the class
     * [output_buffer]) because the user of this class is allowed to delete
     * data from the netbuffer. So we manually count how many bytes are
     * ever appended to the netbuffer.
     * This behavior is especially needed by [pipe_channel] below.
     *)
end
;;

(* Discards all output; [pos_out] counts the bytes written. *)
class output_null ?(onclose = fun () -> ()) () : out_obj_channel =
object(self)
  val mutable closed = false
  val mutable pos = 0

  method private complain_closed() =
    raise Closed_channel

  method output s start len =
    if closed then self # complain_closed();
    pos <- pos + len;
    len

  method really_output s start len =
    if closed then self # complain_closed();
    pos <- pos + len

  method output_char _ =
    if closed then self # complain_closed();
    pos <- pos + 1

  method output_string s =
    if closed then self # complain_closed();
    pos <- pos + String.length s

  method output_byte _ =
    if closed then self # complain_closed();
    pos <- pos + 1

  method output_buffer b =
    if closed then self # complain_closed();
    pos <- pos + Buffer.length b

  method output_channel ?len ch =
    if closed then self # complain_closed();
    ignore (copy_channel ?len ch (self : #out_obj_channel :> out_obj_channel))

  method flush() =
    if closed then self # complain_closed()

  method close_out() =
    closed <- true

  method pos_out =
    if closed then self # complain_closed();
    pos
end
;;

(* Applies [f] to [ch], then flushes and closes [ch]; on an exception
 * [ch] is closed and the exception re-raised. *)
let with_out_obj_channel ch f =
  try
    let result = f ch in
    (* we _have_ to flush here
because close_out often no longer reports exceptions *)
    ( try ch # flush() with Closed_channel -> ());
    ( try ch # close_out() with Closed_channel -> ());
    result
  with
    | e ->
        ( try ch # close_out() with Closed_channel -> ());
        raise e
;;

(* Derives the convenience output methods from [output]. A zero-byte
 * [output] result is reported as Sys_blocked_io, never silently
 * dropped. *)
class virtual augment_raw_out_channel =
object (self)
  method virtual output : string -> int -> int -> int
  method virtual close_out : unit -> unit
  method virtual flush : unit -> unit
  method virtual pos_out : int

  method really_output s pos len =
    let rec print_rest n =
      if n < len then (
        let m = self # output s (pos+n) (len-n) in
        if m=0 then raise Sys_blocked_io;
        print_rest (n+m)
      )
    in
    print_rest 0

  (* [really_output] instead of a plain [output]: the latter may accept
   * 0 bytes (would block), and the character would then be lost. *)
  method output_char c =
    self # really_output (String.make 1 c) 0 1

  (* Taken modulo 256 like Pervasives.output_byte. *)
  method output_byte n =
    self # really_output (String.make 1 (Char.chr (n land 0xff))) 0 1

  method output_string s =
    self # really_output s 0 (String.length s)

  method output_buffer b =
    self # output_string (Buffer.contents b)

  method output_channel ?len ch =
    ignore (copy_channel ?len ch (self : #out_obj_channel :> out_obj_channel))
end
;;

(* Lifts a raw_out_channel to the full output interface. *)
class lift_raw_out_channel (r : raw_out_channel) =
object(self)
  inherit augment_raw_out_channel
  method output s p l = r # output s p l
  method flush () = r # flush()
  method close_out () = r # close_out()
  method pos_out = r # pos_out
end;;

(* Lifts [r] to the full output interface, counting the position itself
 * starting at [start_pos_out]. *)
class lift_rec_out_channel ?(start_pos_out = 0) (r : rec_out_channel) =
object(self)
  inherit augment_raw_out_channel
  val mutable closed = false
  val mutable pos_out = start_pos_out

  method output s p l =
    if closed then raise Closed_channel;
    let n = r # output s p l in
    pos_out <- pos_out + n;
    n

  method flush() =
    if closed then raise Closed_channel;
    r # flush()

  method close_out () =
    if not closed then (
      closed <- true;
      r # close_out()
    )

  method pos_out =
    if closed then raise Closed_channel;
    pos_out
end;;

(* Collects output in a buffer of [buffer_size] bytes before passing it to
 * [ch]. Writes of at least [pass_through] bytes on an empty buffer go to
 * [ch] directly. *)
class buffered_raw_out_channel
        ?(buffer_size = 4096)
        ?(pass_through = max_int)
        (ch : raw_out_channel) : raw_out_channel =
object (self)
  val out = ch
  val bufsize = buffer_size
  val buf = String.create buffer_size
  val mutable bufpos = 0
  val mutable closed = false

  method output s pos len =
    if closed then raise Closed_channel;
    if bufpos=0 && len >= pass_through then
      ch # output s pos len
    else (
      let n = min len (bufsize - bufpos) in
      String.blit s pos buf bufpos n;
      bufpos <- bufpos + n;
      if bufpos = bufsize then self # flush();
      n
    )

  method flush() =
    let k = ref 0 in
    while !k < bufpos do
      k := !k + (ch # output buf !k (bufpos - !k))
    done;
    bufpos <- 0;
    ch # flush()

  method close_out() =
    if not closed then (
      ( try self # flush()
        with
          | error ->
              let bt = Printexc.get_backtrace() in
              Netlog.logf `Err
                "Netchannels.buffered_raw_out_channel: \ Suppressed error in close_out: %s - backtrace: %s"
                (Netexn.to_string error) bt
      );
      ch # close_out();
      closed <- true
    )

  method pos_out =
    (ch # pos_out) + bufpos
end
;;

type lift_out_arg = [ `Rec of rec_out_channel | `Raw of raw_out_channel ]

(* Lifts [x] to an out_obj_channel, buffered unless [~buffered:false]. *)
let lift_out ?(buffered=true) ?buffer_size ?pass_through (x : lift_out_arg) =
  match x with
    | `Rec r ->
        if buffered then (
          let r' = new lift_rec_out_channel r in
          let r'' =
            new buffered_raw_out_channel
              ?buffer_size ?pass_through (r' :> raw_out_channel) in
          new lift_raw_out_channel r''
        )
        else
          new lift_rec_out_channel r
    | `Raw r ->
        if buffered then (
          let r' = new buffered_raw_out_channel ?buffer_size ?pass_through r in
          new lift_raw_out_channel r'
        )
        else
          new lift_raw_out_channel r
;;

(************************* raw channels *******************************)

(* Reads from the descriptor [fd]. EINTR is retried; EAGAIN/EWOULDBLOCK
 * either waits until readable ([blocking], the default) or returns 0. *)
class input_descr_prelim ?(blocking=true) ?(start_pos_in = 0) fd =
  let fd_style = Netsys.get_fd_style fd in
object (self)
  val fd_in = fd
  val mutable pos_in = start_pos_in
  val mutable closed_in = false

  method private complain_closed() =
    raise Closed_channel

  method input buf pos len =
    if closed_in then self # complain_closed();
    try
      let n = Netsys.gread fd_style fd_in buf pos len in
      pos_in <- pos_in + n;
      if n=0 && len>0 then raise End_of_file;
      n
    with
      | Unix.Unix_error(Unix.EINTR,_,_) ->
          self # input buf pos len
      | Unix.Unix_error(Unix.EAGAIN,_,_)
      | Unix.Unix_error(Unix.EWOULDBLOCK,_,_) ->
          if blocking then (
            let _ =
              Netsys.restart (Netsys.wait_until_readable fd_style fd) (-1.0) in
            self # input buf pos len
          )
          else 0

  method close_in () =
    if not closed_in then (
      Netsys.gclose fd_style fd_in;
      closed_in <- true
    )

  method pos_in =
    if closed_in then self # complain_closed();
    pos_in
end
;;

class input_descr ?blocking ?start_pos_in fd : raw_in_channel =
  input_descr_prelim ?blocking ?start_pos_in fd
;;

(* Writes to the descriptor [fd], with the same EINTR/EAGAIN handling as
 * [input_descr_prelim]. *)
class output_descr_prelim ?(blocking=true) ?(start_pos_out = 0) fd =
  let fd_style = Netsys.get_fd_style fd in
object (self)
  val fd_out = fd
  val mutable pos_out = start_pos_out
  val mutable closed_out = false

  method private complain_closed() =
    raise Closed_channel

  method output buf pos len =
    if closed_out then self # complain_closed();
    try
      let n = Netsys.gwrite fd_style fd_out buf pos len in
      pos_out <- pos_out + n;
      n
    with
      | Unix.Unix_error(Unix.EINTR,_,_) ->
          self # output buf pos len
      | Unix.Unix_error(Unix.EAGAIN,_,_)
      | Unix.Unix_error(Unix.EWOULDBLOCK,_,_) ->
          if blocking then (
            let _ =
              Netsys.restart (Netsys.wait_until_writable fd_style fd) (-1.0) in
            self # output buf pos len
          )
          else 0

  method close_out () =
    if not closed_out then (
      ( try
          Netsys.gshutdown fd_style fd Unix.SHUTDOWN_SEND
        with
          | Netsys.Shutdown_not_supported ->
              ()
          | Unix.Unix_error(Unix.EAGAIN, _, _) ->
              (* FIXME. We block here even when non-blocking semantics
               * is requested. We do this because most programmers would
               * be surprised to get EAGAIN when closing a channel.
               * Actually, this only affects Win32 output threads.
*)
              ignore
                (Netsys.restart
                   (Netsys.wait_until_writable fd_style fd) (-1.0));
              Netsys.gshutdown fd_style fd Unix.SHUTDOWN_SEND
          | Unix.Unix_error(Unix.EPERM, _, _) ->
              ()
      );
      Netsys.gclose fd_style fd_out;
      closed_out <- true
    )

  method pos_out =
    if closed_out then self # complain_closed();
    pos_out

  method flush () =
    if closed_out then self # complain_closed()
end
;;

(* [output_descr_prelim] constrained to the raw_out_channel interface. *)
class output_descr ?blocking ?start_pos_out fd : raw_out_channel =
  output_descr_prelim ?blocking ?start_pos_out fd
;;

(* Bidirectional channel over a socket-like descriptor. Closing one
 * direction shuts it down; once both directions are closed the
 * descriptor itself is closed. Other descriptor styles are rejected
 * with Failure. *)
class socket_descr ?blocking ?(start_pos_in = 0) ?(start_pos_out = 0) fd
                   : raw_io_channel =
  let fd_style = Netsys.get_fd_style fd in
  let supported =
    match fd_style with
      | `Recv_send _ | `Recv_send_implied | `W32_pipe -> true
      | _ -> false in
  let () =
    if not supported then
      failwith "Netchannels.socket_descr: This type of descriptor is \ unsupported" in
object (self)
  inherit input_descr_prelim ?blocking ~start_pos_in fd
  inherit output_descr_prelim ?blocking ~start_pos_out fd

  method private gen_close cmd =
    ( try
        Netsys.gshutdown fd_style fd cmd
      with
        | Netsys.Shutdown_not_supported -> ()
        | Unix.Unix_error(Unix.EAGAIN, _, _) -> assert false
        | Unix.Unix_error(Unix.EPERM, _, _) -> ()
    );
    match cmd with
      | Unix.SHUTDOWN_ALL -> Netsys.gclose fd_style fd
      | Unix.SHUTDOWN_RECEIVE | Unix.SHUTDOWN_SEND -> ()

  method close_in () =
    if not closed_in then (
      closed_in <- true;
      self # gen_close
        (if closed_out then Unix.SHUTDOWN_ALL else Unix.SHUTDOWN_RECEIVE)
    )

  method close_out () =
    if not closed_out then (
      closed_out <- true;
      self # gen_close
        (if closed_in then Unix.SHUTDOWN_ALL else Unix.SHUTDOWN_SEND)
    )
end
;;

(************************** transactional *****************************)

(* Keeps all output in an in-memory transaction buffer. [commit_work]
 * writes the buffer to [ch] and flushes [ch]; [rollback_work] discards it.
 * [close_mode] selects which of the two [close_out] performs. *)
class buffered_trans_channel ?(close_mode = (`Commit : close_mode))
                             (ch : out_obj_channel)
                             : trans_out_obj_channel =
  let closed = ref false in
  let transbuf = ref (Buffer.create 50) in
  let trans = ref (new output_buffer !transbuf) in
  (* Start a fresh, empty transaction buffer. *)
  let reset () =
    transbuf := Buffer.create 50;
    trans := new output_buffer !transbuf
  in
object (self)
  val out = ch
  val close_mode = close_mode

  method output = !trans # output
  method really_output = !trans # really_output
  method output_char = !trans # output_char
  method output_string = !trans # output_string
  method output_byte = !trans # output_byte
  method output_buffer = !trans # output_buffer
  method output_channel = !trans # output_channel
  method flush = !trans # flush

  method close_out() =
    if not !closed then begin
      let finish =
        match close_mode with
          | `Commit -> self # commit_work
          | `Rollback -> self # rollback_work in
      ( try finish ()
        with
          | error ->
              let bt = Printexc.get_backtrace() in
              Netlog.logf `Err
                "Netchannels.buffered_trans_channel: \ Suppressed error in close_out: %s - backtrace: %s"
                (Netexn.to_string error) bt
      );
      !trans # close_out();
      out # close_out();
      closed := true
    end

  method pos_out =
    let committed = out # pos_out in
    committed + (!trans) # pos_out

  method commit_work() =
    try
      (* Take the pending data and reset first, so the contents of the
       * transaction buffer can never be printed twice. *)
      let pending = !transbuf in
      reset ();
      out # output_buffer pending;
      out # flush()
    with
      | error ->
          self # rollback_work();   (* reset anyway *)
          raise error

  method rollback_work() =
    reset ()
end
;;

let make_temporary_file
      ?(mode = 0o600) ?(limit = 1000)
      ?(tmp_directory = Netsys_tmp.tmp_directory() )
      ?(tmp_prefix = "netstring")
      () =
  (* Returns (filename, in_channel, out_channel).
*) let rec try_creation n = try let fn = Filename.concat tmp_directory (Netsys_tmp.tmp_prefix tmp_prefix ^ "-" ^ (string_of_int n)) in let fd_in = Unix.openfile fn [ Unix.O_RDWR; Unix.O_CREAT; Unix.O_EXCL ] mode in let fd_out = Unix.openfile fn [ Unix.O_RDWR ] mode in (* For security reasons check that fd_in and fd_out are the same file: *) let stat_in = Unix.fstat fd_in in let stat_out = Unix.fstat fd_out in if stat_in.Unix.st_dev <> stat_out.Unix.st_dev || stat_in.Unix.st_rdev <> stat_out.Unix.st_rdev || stat_in.Unix.st_ino <> stat_out.Unix.st_ino then raise(Sys_error("File has been replaced (security alert)")); let ch_in = Unix.in_channel_of_descr fd_in in let ch_out = Unix.out_channel_of_descr fd_out in fn, ch_in, ch_out with Unix.Unix_error(Unix.EEXIST,_,_) -> (* This does not look very intelligent, but it is the only chance * to limit the number of trials. * Note that we get EACCES if the directory is not writeable. *) if n > limit then failwith ("Netchannels: Cannot create temporary file - too many files in this temp directory: " ^ tmp_directory); try_creation (n+1) | Unix.Unix_error(e,_,_) -> raise (Sys_error("Cannot create a temporary file in the directory " ^ tmp_directory ^ ": " ^ Unix.error_message e)) in try_creation 0 ;; class tempfile_trans_channel ?(close_mode = (`Commit : close_mode)) ?tmp_directory ?tmp_prefix (ch : out_obj_channel) : trans_out_obj_channel = let _transname, _transch_in, _transch_out = make_temporary_file ?tmp_directory ?tmp_prefix () in let closed = ref false in object (self) val transch_out = _transch_out val mutable transch_in = _transch_in val trans = new output_channel _transch_out val mutable out = ch val close_mode = close_mode val mutable need_clear = false initializer try Sys.remove _transname; (* Remove the file immediately. This requires "Unix semantics" of the * underlying file system, because we don't remove the file but only * the entry in the directory. 
So we can read and write the file and * allocate disk space, but the file is private from now on. (It's * not fully private, because another process can obtain a descriptor * between creation of the file and removal of the entry. We should * keep that in mind if privacy really matters.) * The disk space will be freed when the descriptor is closed. *) with err -> close_in _transch_in; close_out _transch_out; raise err method output = if need_clear then self#clear(); trans # output method really_output = if need_clear then self#clear(); trans # really_output method output_char = if need_clear then self#clear(); trans # output_char method output_string = if need_clear then self#clear(); trans # output_string method output_byte = if need_clear then self#clear(); trans # output_byte method output_buffer = if need_clear then self#clear(); trans # output_buffer method output_channel = if need_clear then self#clear(); trans #output_channel method flush = if need_clear then self#clear(); trans # flush method close_out() = if not !closed then ( if need_clear then self#clear(); ( try ( match close_mode with `Commit -> self # commit_work() | `Rollback -> self # rollback_work() ) with | error -> let bt = Printexc.get_backtrace() in Netlog.logf `Err "Netchannels.tempfile_trans_channel: \ Suppressed error in close_out: %s - backtrace: %s" (Netexn.to_string error) bt; ); Pervasives.close_in transch_in; trans # close_out(); (* closes transch_out *) out # close_out(); closed := true ) method pos_out = if need_clear then self#clear(); out # pos_out + trans # pos_out method commit_work() = need_clear <- true; let len = trans # pos_out in trans # flush(); Pervasives.seek_in transch_in 0; let trans' = new input_channel transch_in in ( try out # output_channel ~len trans'; out # flush(); with err -> self # rollback_work(); raise err ); self # clear() method rollback_work() = self # clear() method private clear() = (* delete the contents of the file *) (* First empty the file and reset 
the output channel: *) Pervasives.seek_out transch_out 0; Unix.ftruncate (Unix.descr_of_out_channel transch_out) 0; (* Renew the input channel. We create a new channel to avoid problems * with the internal buffer of the channel. * (Problem: transch_in has an internal buffer, and the buffer contains * old data now. So we drop the channel and create a new channel for the * same file descriptor. Note that we cannot set the file offset with * seek_in because neither the old nor the new channel is properly * synchronized with the file. So we fall back to lseek.) *) let fd = Unix.descr_of_in_channel transch_in in ignore(Unix.lseek fd 0 Unix.SEEK_END); (* set the offset *) transch_in <- Unix.in_channel_of_descr fd; (* renew channel *) (* Now check that everything worked: *) assert(pos_in transch_in = 0); assert(in_channel_length transch_in = 0); (* Note: the old transch_in will be automatically finalized, but the * underlying file descriptor will not be closed in this case *) need_clear <- false end ;; let id_conv incoming incoming_eof outgoing = (* Copies everything from [incoming] to [outgoing] *) let len = Netbuffer.length incoming in ignore (Netbuffer.add_inplace ~len outgoing (fun s_outgoing pos len' -> assert (len = len'); Netbuffer.blit incoming 0 s_outgoing pos len'; Netbuffer.clear incoming; len' )) ;; let call_input refill f arg = (* Try to satisfy the request: *) try f arg with Buffer_underrun -> (* Not enough data in the outgoing buffer. *) refill(); f arg ;; class pipe ?(conv = id_conv) ?(buffer_size = 1024) () : io_obj_channel = let _incoming = Netbuffer.create buffer_size in let _outgoing = Netbuffer.create buffer_size in object(self) (* The properties as "incoming buffer" [output_super] are simply inherited * from [output_netbuffer]. The "outgoing buffer" [input_super] invocations * are delegated to [input_netbuffer]. Inheritance does not work because * there is no way to make the public method [shutdown] private again. 
*) inherit output_netbuffer _incoming as output_super val conv = conv val incoming = _incoming val outgoing = _outgoing val input_super = new input_netbuffer _outgoing val mutable incoming_eof = false val mutable pos_in = 0 (* We must count positions ourselves. Can't use input_super#pos_in * because conv may manipulate the buffer. *) val mutable output_closed = false (* Input methods: *) method private refill() = conv incoming incoming_eof outgoing; if incoming_eof then input_super # shutdown() method input str pos len = let n = call_input self#refill (input_super#input str pos) len in pos_in <- pos_in + n; n method input_line() = let p = input_super # pos_in in let line = call_input self#refill (input_super#input_line) () in let p' = input_super # pos_in in pos_in <- pos_in + (p' - p); line method really_input str pos len = call_input self#refill (input_super#really_input str pos) len; pos_in <- pos_in + len method input_char() = let c = call_input self#refill (input_super#input_char) () in pos_in <- pos_in + 1; c method input_byte() = let b = call_input self#refill (input_super#input_byte) () in pos_in <- pos_in + 1; b method close_in() = (* [close_in] implies [close_out]: *) if not output_closed then ( output_super # close_out(); output_closed <- true; ); input_super # close_in() method pos_in = pos_in (* [close_out] also shuts down the input side of the pipe. 
*) method close_out () = if not output_closed then ( output_super # close_out(); output_closed <- true; ); incoming_eof <- true end class output_filter (p : io_obj_channel) (out : out_obj_channel) : out_obj_channel = object(self) val p = p val mutable p_closed = false (* output side of p is closed *) val out = out val buf = String.create 1024 (* for copy_channel *) method output s pos len = if p_closed then raise Closed_channel; let n = p # output s pos len in self # transfer(); n method really_output s pos len = if p_closed then raise Closed_channel; p # really_output s pos len; self # transfer(); method output_char c = if p_closed then raise Closed_channel; p # output_char c; self # transfer(); method output_string s = if p_closed then raise Closed_channel; p # output_string s; self # transfer(); method output_byte b = if p_closed then raise Closed_channel; p # output_byte b; self # transfer(); method output_buffer b = if p_closed then raise Closed_channel; p # output_buffer b; self # transfer(); method output_channel ?len ch = (* To avoid large intermediate buffers, the channel is copied * chunk by chunk *) if p_closed then raise Closed_channel; let len_to_do = ref (match len with None -> -1 | Some l -> max 0 l) in let buf = buf in while !len_to_do <> 0 do let n = if !len_to_do < 0 then 1024 else min !len_to_do 1024 in if copy_channel ~buf ~len:n ch (p :> out_obj_channel) then (* EOF *) len_to_do := 0 else if !len_to_do >= 0 then (len_to_do := !len_to_do - n; assert(!len_to_do >= 0)); self # transfer(); done method flush() = p # flush(); self # transfer(); out # flush() method close_out() = if not p_closed then ( p # close_out(); p_closed <- true; ( try self # transfer() with | error -> (* We report the error. However, we prevent that another, immediately following [close_out] reports the same error again. This is done by setting p_closed. 
*) raise error ) ) method pos_out = p # pos_out method private transfer() = (* Copy as much as possible from [p] to [out] *) try (* Call [copy_channel] directly (and not the method [output_channel]) * because we can pass the copy buffer ~buf *) ignore(copy_channel ~buf (p :> in_obj_channel) out); out # flush(); with Buffer_underrun -> () end let rec filter_input refill f arg = (* Try to satisfy the request: *) try f arg with Buffer_underrun -> (* Not enough data in the outgoing buffer. *) refill(); filter_input refill f arg ;; class input_filter (inp : in_obj_channel) (p : io_obj_channel) : in_obj_channel = object(self) val inp = inp val p = p val buf = String.create 1024 (* for copy_channel *) method private refill() = (* Copy some data from [inp] to [p] *) (* Call [copy_channel] directly (and not the method [output_channel]) * because we can pass the copy buffer ~buf *) let eof = copy_channel ~len:(String.length buf) ~buf inp (p :> out_obj_channel) in if eof then p # close_out(); method input str pos = filter_input self#refill (p#input str pos) method input_line = filter_input self#refill (p#input_line) method really_input str pos = filter_input self#refill (p#really_input str pos) method input_char = filter_input self#refill (p#input_char) method input_byte = filter_input self#refill (p#input_byte) method close_in() = p#close_in(); method pos_in = p#pos_in end