(* $Id: nethttpd_kernel.ml 2195 2015-01-01 12:23:39Z gerd $
 *
 *)

(*
 * Copyright 2005 Baretta s.r.l. and Gerd Stolpmann
 *
 * This file is part of Nethttpd.
 *
 * Nethttpd is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * Nethttpd is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Nethttpd; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 *)

(* Debug switch of this module. It is registered with Netlog.Debug under
 * the name "Nethttpd_kernel" so that logging can be enabled at runtime.
 *)
module Debug = struct
  let enable = ref false
end

(* Logging helpers bound to [Debug.enable]. [dlogr] takes a thunk, so the
 * message is only built when it is actually needed.
 *)
let dlog = Netlog.Debug.mk_dlog "Nethttpd_kernel" Debug.enable
let dlogr = Netlog.Debug.mk_dlogr "Nethttpd_kernel" Debug.enable

let () =
  Netlog.Debug.register_module "Nethttpd_kernel" Debug.enable

open Nethttpd_types
open Nethttp
open Nethttp.Header
open Printf

(* Errors after which the connection cannot be continued. *)
type fatal_error =
  [ `Broken_pipe
  | `Broken_pipe_ignore
  | `Message_too_long
  | `Timeout
  | `Unix_error of Unix.error
  | `TLS_error of string * string
  | `Server_error
  ]

(* Renders a [fatal_error] as a human-readable message. *)
let string_of_fatal_error e =
  match e with
  | `Broken_pipe ->
      "Nethttpd: Broken pipe"
  | `Broken_pipe_ignore ->
      "Nethttpd: Ignorable broken pipe"
  | `Message_too_long ->
      "Nethttpd: Message too long, dropping it"
  | `Timeout ->
      "Nethttpd: Connection timed out"
  | `Unix_error ue ->
      sprintf "Nethttpd: System error: %s" (Unix.error_message ue)
  | `TLS_error (code, msg) ->
      sprintf "Nethttpd: TLS error code %s: %s" code msg
  | `Server_error ->
      "Nethttpd: Terminating connection because of internal error"

(* Errors caused by a malformed or unsupported request. *)
type bad_request_error =
  [ `Bad_header_field of string
  | `Bad_header
  | `Bad_trailer
  | `Bad_request_line
  | `Request_line_too_long
  | `Protocol_not_supported
  | `Unexpected_eof
  | `Format_error of string
  ]
(* Renders a [bad_request_error] as a human-readable message. *)
let string_of_bad_request_error = function
  | `Bad_header_field s -> "Nethttpd: Bad request header field: " ^ s
  | `Bad_header -> "Nethttpd: Bad request header"
  | `Bad_trailer -> "Nethttpd: Bad request trailer"
  | `Bad_request_line -> "Nethttpd: Bad request line"
  | `Request_line_too_long -> "Nethttpd: Request line too long"
  | `Protocol_not_supported -> "Nethttpd: Protocol not supported"
  | `Unexpected_eof -> "Nethttpd: Bad request: unexpected EOF"
  | `Format_error s -> "Nethttpd: Bad request: " ^ s

(* The HTTP status to answer a bad request with: a specific status for an
 * overlong request line and an unsupported protocol, otherwise
 * [`Bad_request].
 *)
let status_of_bad_request_error = function
  | `Request_line_too_long -> `Request_uri_too_long
  | `Protocol_not_supported -> `Http_version_not_supported
  | _ -> `Bad_request

(* A substring given as (string, position, length). *)
type data_chunk = string * int * int

(* Status code and reason phrase. *)
type status_line = int * string

type transfer_coding =
  [ `Identity
  | `Chunked
  ]

(* Tokens the application passes to [http_response#send]. *)
type resp_token =
  [ `Resp_info_line of (status_line * http_header)
  | `Resp_status_line of status_line
  | `Resp_header of http_header
  | `Resp_body of data_chunk
  | `Resp_trailer of http_trailer
  | `Resp_end
  | `Resp_action of (unit -> unit)
  ]

let resp_100_continue =
  `Resp_info_line((100, "Continue"), new Netmime.basic_mime_header [])

exception Send_queue_empty

(* Tokens as seen by the transmitter. *)
type front_token =
  [ `Resp_wire_data of data_chunk (* everything else *)
  | `Resp_end
  ]

type front_token_x =
  [ front_token
  | `Resp_wire_action of (unit->unit)
  ]

type front_token_opt =
  [ `None
  | front_token_x
  ]

type resp_state =
  [ `Inhibited | `Queued | `Active | `Processed | `Error | `Dropped ]

type announcement =
  [ `Ignore | `Ocamlnet | `Ocamlnet_and of string | `As of string ]

(* Debug representation of a front token; data is truncated to 20 bytes. *)
let string_of_front_token_x = function
  | `Resp_wire_data(s,pos,len) ->
      let n = min len 20 in
      sprintf "Resp_wire_data(%s%s)"
        (String.escaped (String.sub s pos n))
        (if len > n then "..." else "")
  | `Resp_end -> "Resp_end"
  | `Resp_wire_action _ -> "Resp_wire_action"

class type http_response =
object
  method state : resp_state
  method set_state : resp_state -> unit
  method bidirectional_phase : bool
  method set_callback : (unit -> unit) -> unit
  method send : resp_token -> unit
  method send_queue_empty : bool
  method protocol : protocol
  method close_connection : bool
  method transfer_encoding : transfer_coding
  method front_token : front_token
  method advance : int -> unit
  method body_size : int64
end

let string_of_state = function
  | `Inhibited -> "Inhibited"
  | `Queued -> "Queued"
  | `Active -> "Active"
  | `Processed -> "Processed"
  | `Error -> "Error"
  | `Dropped -> "Dropped"

(* TODO:
 * - make http_response_impl thread-safe
 * - implement trailers
 *)

class http_response_impl ?(close=false) ?(suppress_body=false) fdi
        (req_version : protocol) (ann_server : announcement)
        : http_response =
  (* - [close]: If true, the connection will be closed after this response
   * - [suppress_body]: If true, the body will not be transmitted (e.g. in
   *   response to a HEAD request)
   * - [req_version]: The version of the request. Used to limit features in
   *   the response to what the client understands.
   * - [fdi]: the descriptor number, only used in debug messages
   * - [ann_server]: how to fill the "Server" header field
   *
   * The response is built by [send] as a queue of wire tokens, which the
   * transmitter consumes through [front_token] and [advance].
   *)
object(self)
  val resp_version = `Http((1,1),[])
  val mutable state = `Queued
  val mutable accept = (`Status : [ `Status | `Header | `Body | `End | `None] )
  val mutable suppress_body = suppress_body
  val mutable front_token = (`None : front_token_opt)
  val mutable queue = (Queue.create() : front_token_x Queue.t)
  val mutable close_connection = close
  val mutable transfer_encoding = `Identity
  val mutable announced_content_length = None
  val mutable real_length = 0L   (* int64 *)
  val mutable callback = (fun () -> ())
  val mutable bidirectional_phase = false

  method state = state

  method bidirectional_phase = bidirectional_phase

  (* Sets the state. If the state changes to [`Processed], [`Error] or
   * [`Dropped], the remaining queue is drained: data is discarded and queued
   * actions are run, and exceptions raised by them are ignored. The callback
   * is invoked whenever the state actually changes.
   *)
  method set_state s =
    let old_state = state in
    state <- s;
    dlogr (fun () ->
             sprintf "FD %Ld resp-%d: response state: %s"
               fdi (Oo.id self) (string_of_state s));
    if s <> old_state && (s = `Processed || s = `Error || s = `Dropped) then (
      (* do all actions on the queue *)
      try
        while true do
          match Queue.take queue with
          | `Resp_wire_action f -> ( try f() with _ -> () )
          | _ -> ()
        done
      with Queue.Empty -> ()
    );
    if s <> old_state then callback()

  method set_callback f = callback <- f

  (* Appends a response token to the outgoing queue.
   *
   * Tokens must arrive in this order: any number of info lines, then the
   * status line, the header, any number of body chunks, an optional
   * trailer, and [`Resp_end]. [`Resp_action] is accepted at any time.
   * A token arriving out of order raises [Failure]. A body chunk whose
   * range lies outside its string raises [Invalid_argument].
   *)
  method send tok =
    match tok with
    | `Resp_info_line((code, phrase), info_header) ->
        if code < 100 || code > 199 then
          failwith "Nethttpd_kernel.http_response: Bad informational status code";
        if accept <> `Status then
          failwith "Nethttpd_kernel.http_response: Cannot send status line now";
        ( match req_version with
          | `Http((1,n),_) when n >= 1 ->
              let s =
                Printf.sprintf "%s %3d %s\r\n"
                  (string_of_protocol resp_version) code phrase in
              Queue.push (`Resp_wire_data (s, 0, String.length s)) queue;
              (* Convert the header to a data chunk: *)
              let b = Netbuffer.create 256 in
              (* Expect a short/empty header in most cases *)
              let ch = new Netchannels.output_netbuffer b in
              Netmime_string.write_header ch info_header#fields;
              Queue.push
                (`Resp_wire_data (Netbuffer.unsafe_buffer b, 0, Netbuffer.length b))
                queue;
              Queue.push
                (`Resp_wire_action (fun () -> bidirectional_phase <- true))
                queue
          | _ ->
              ()  (* Suppress this for HTTP 1.0 and lower *)
        )

    | `Resp_status_line (code, phrase) ->
        if code < 200 || code > 999 then
          failwith "Nethttpd_kernel.http_response: Bad final status code";
        if accept <> `Status then
          failwith "Nethttpd_kernel.http_response: Cannot send status line now";
        let s =
          Printf.sprintf "%s %03d %s\r\n"
            (string_of_protocol resp_version) code phrase in
        Queue.push
          (`Resp_wire_action (fun () -> bidirectional_phase <- false)) queue;
        Queue.push (`Resp_wire_data (s, 0, String.length s)) queue;
        accept <- `Header

    | `Resp_header resp_header ->
        if accept <> `Header then
          failwith "Nethttpd_kernel.http_response: Cannot send header now";
        (* Set [announced_content_length]: *)
        ( try
            let len = get_content_length resp_header in (* or Not_found *)
            announced_content_length <- Some len;
            dlogr (fun () ->
                     sprintf "FD %Ld resp-%d: response announced_content_length=%Ld"
                       fdi (Oo.id self) len)
          with Not_found ->
            announced_content_length <- None
        );
        (* Update the values for [close_connection] and [transfer_encoding]: *)
        ( match req_version with
          | `Http((1,n),_) when n >= 1 ->
              (* Without a Content-Length the body length can only be
               * signalled by chunked encoding. *)
              transfer_encoding <-
                ( match announced_content_length with
                  | Some _ -> `Identity
                  | None -> `Chunked
                );
              dlogr (fun () ->
                       sprintf "FD %Ld resp-%d: response transfer_encoding=%s"
                         fdi (Oo.id self)
                         (match transfer_encoding with
                          | `Identity -> "identity"
                          | `Chunked -> "chunked"))
          | _ ->
              (* Other protocol version: fall back to conservative defaults *)
              close_connection <- true;
              transfer_encoding <- `Identity;
              dlogr (fun () ->
                       sprintf "FD %Ld resp-%d: response transfer_encoding=identity; close_connection=true"
                         fdi (Oo.id self))
        );
        (* Update the header: *)
        ( match transfer_encoding, suppress_body with
          | (`Identity, false)
          | (_, true) ->
              resp_header # delete_field "Transfer-Encoding"
          | (`Chunked, false) ->
              set_transfer_encoding resp_header ["chunked", []]
        );
        resp_header # delete_field "Trailer";
        set_date resp_header (Unix.time());
        ( match close_connection with
          | false -> resp_header # delete_field "Connection"
          | true -> set_connection resp_header ["close"]
        );
        resp_header # delete_field "Upgrade";
        ( match ann_server with
          | `Ignore -> ()
          | `Ocamlnet ->
              let sh = "Ocamlnet/" ^ Netconst.ocamlnet_version in
              set_server resp_header sh
          | `Ocamlnet_and s ->
              let sh = s ^ " Ocamlnet/" ^ Netconst.ocamlnet_version in
              set_server resp_header sh
          | `As sh ->
              set_server resp_header sh
        );
        (* Convert the header to a data chunk: *)
        let b = Netbuffer.create 4096 in
        let ch = new Netchannels.output_netbuffer b in
        Netmime_string.write_header ch resp_header#fields;
        Queue.push
          (`Resp_wire_data (Netbuffer.unsafe_buffer b, 0, Netbuffer.length b))
          queue;
        (* What is accepted next: *)
        accept <- `Body

    | `Resp_body ((s,pos,len) as data) ->
        if accept <> `Body then
          failwith "Nethttpd_kernel.http_response: Cannot send body now";
        if pos < 0 || len < 0 || pos + len > String.length s then
          invalid_arg "Nethttpd_kernel.http_response#send";
        if not suppress_body then (
          match transfer_encoding with
          | `Identity ->
              (* Never send more than announced. [len'] is negative once the
               * announced length has already been exceeded. *)
              let len' =
                match announced_content_length with
                | None -> len
                | Some ann_len ->
                    Int64.to_int
                      (min (Int64.of_int len) (Int64.sub ann_len real_length)) in
              if len' > 0 then
                Queue.push (`Resp_wire_data (s,pos,len')) queue;
              (* Warn whenever something of this chunk had to be dropped,
               * also for partial truncation and after an earlier overflow. *)
              if len > 0 && len' < len then
                dlogr (fun () ->
                         sprintf "FD %Ld resp-%d: response warning: response is longer than announced"
                           fdi (Oo.id self))
          | `Chunked ->
              if len > 0 then (
                (* Generate the chunk header: *)
                let u = Printf.sprintf "%x\r\n" len in
                Queue.push (`Resp_wire_data(u,0,String.length u)) queue;
                (* Output the chunk: *)
                Queue.push (`Resp_wire_data data) queue;
                (* Framing: *)
                Queue.push (`Resp_wire_data ("\r\n", 0, 2)) queue
              )
        );
        real_length <- Int64.add real_length (Int64.of_int len);
        if real_length < 0L then (* Check for wrap around *)
          failwith "Nethttpd_kernel: response too long"

    | `Resp_trailer _ ->
        if accept <> `Body then
          failwith "Nethttpd_kernel.http_response: Cannot send trailer now";
        accept <- `End
        (* trailers are ignored for now *)

    | `Resp_end ->
        if accept <> `Body && accept <> `End then
          failwith "Nethttpd_kernel.http_response: Cannot finish response now";
        if not suppress_body then (
          match transfer_encoding with
          | `Identity ->
              (* If less than announced was sent, the peer cannot find the
               * end of the message, so the connection has to be closed: *)
              ( match announced_content_length with
                | None -> ()
                | Some ann_len ->
                    if ann_len > real_length then close_connection <- true
              )
          | `Chunked ->
              (* Add the last-chunk: *)
              let s = "0\r\n\r\n" in
              Queue.push (`Resp_wire_data(s,0,String.length s)) queue
        );
        (* [`Resp_end] marks the end of the response in the front-token stream
         * (see [advance]). It must be queued even when the body is suppressed
         * (HEAD). Otherwise the response never appears to be finished. *)
        Queue.push `Resp_end queue;
        accept <- `None

    | `Resp_action f ->
        Queue.push (`Resp_wire_action f) queue

  method send_queue_empty =
    state = `Inhibited
    || ( (match front_token with `None -> true | _ -> false)
         && Queue.is_empty queue )

  method protocol = resp_version

  method close_connection = close_connection

  method transfer_encoding = transfer_encoding

  method body_size = real_length

  (* Returns the token to transmit next. Queued actions in front of it are
   * executed on the way. Raises [Send_queue_empty] when the response is
   * inhibited or nothing is queued.
   *)
  method front_token : front_token =
    if state = `Inhibited then raise Send_queue_empty;
    match front_token with
    | `None ->
        ( try
            let tok = Queue.take queue in
            front_token <- (tok :> front_token_opt);
            dlogr (fun () ->
                     sprintf "FD %Ld resp-%d: response new front_token: %s"
                       fdi (Oo.id self) (string_of_front_token_x tok));
            self # front_token
          with Queue.Empty -> raise Send_queue_empty
        )
    | `Resp_wire_action f ->
        front_token <- `None;
        dlogr (fun () ->
                 sprintf "FD %Ld resp-%d: response Resp_wire_action"
                   fdi (Oo.id self));
        f();
        self # front_token
    | #front_token as other ->
        other

  (* Records that [n] bytes of the current data token have been
   * transmitted. The callback is invoked when the token is fully consumed
   * and nothing else is queued.
   * Raises [Invalid_argument] if [n] exceeds the current chunk. Raises
   * [Failure] if the front token is [`Resp_end].
   *)
  method advance n =
    if n > 0 then (
      ignore(self # front_token);  (* such that we can assert front_token <> `None *)
      match front_token with
      | `Resp_wire_data (s,pos,len) ->
          if n > len then
            invalid_arg "Nethttpd_kernel#http_response: Cannot advance past the current data chunk";
          let len' = len - n in
          front_token <-
            if len' = 0 then `None else `Resp_wire_data(s,pos+n,len');
          if len' = 0 && Queue.is_empty queue then callback()
      | `Resp_end ->
          failwith "Nethttpd_kernel#http_response: Cannot advance past the end of the response"
      | `Resp_wire_action _ -> assert false
      | `None -> assert false
    )
end

(* Sends a complete response with an in-memory [body]. [hdr_opt] is used as
 * the header if given. "Content-Type" defaults to "text/html" if missing, and
 * "Content-Length" is set to the body length.
 *)
let send_static_response resp status hdr_opt body =
  let code = int_of_http_status status in
  let text = string_of_http_status status in
  let h =
    ( match hdr_opt with
      | None -> new Netmime.basic_mime_header []
      | Some h -> h
    ) in
  ( try ignore(h # field "Content-Type")
    with Not_found -> h # update_field "Content-type" "text/html"
  );
  h # update_field "Content-Length" (string_of_int (String.length body));
  resp # send (`Resp_status_line(code, text));
  resp # send (`Resp_header h);
  resp # send (`Resp_body(body, 0, String.length body));
  resp # send `Resp_end
;;

(* Sends a response whose body consists of (up to) [length] bytes read from
 * [fd]. The descriptor is read lazily, in pieces of at most 8192 bytes,
 * from [`Resp_action] callbacks while the response is [`Active]. [fd] is
 * closed in three cases: at EOF, when the response reaches a terminal state,
 * and when reading fails.
 * A "Content-Type" of "text/html" is added if missing, except for
 * [`No_content], [`Reset_content] and [`Not_modified].
 * [fd] is expected to be a regular file: EAGAIN is treated as impossible.
 *)
let send_file_response resp status hdr_opt fd length =
  Netlog.Debug.track_fd ~owner:"Nethttpd_kernel" ~descr:"file response" fd;
  let hdr =
    match hdr_opt with
    | None -> new Netmime.basic_mime_header []
    | Some h -> h in
  ( match status with
    | `No_content | `Reset_content | `Not_modified -> ()
    | _ ->
        ( try ignore(hdr # field "Content-Type")
          with Not_found -> hdr # update_field "Content-type" "text/html"
        )
  );
  hdr # update_field "Content-Length" (Int64.to_string length);
  let code = int_of_http_status status in
  let phrase = string_of_http_status status in
  resp # send (`Resp_status_line (code, phrase));
  resp # send (`Resp_header hdr);
  let fd_open = ref true in
  (* Closes [fd] at most once. *)
  let close_fd () =
    if !fd_open then (
      fd_open := false;
      Netlog.Debug.release_fd fd;
      Unix.close fd
    ) in
  let buf = String.create 8192 in
  let len = ref length in
  let rec feed() =
    match resp # state with
    | `Inhibited | `Queued -> assert false
    | `Active ->
        ( try
            let m = min 8192L !len in
            let n = Unix.read fd buf 0 (Int64.to_int m) in  (* or Unix_error *)
            if n > 0 then (
              len := Int64.sub !len (Int64.of_int n);
              resp # send (`Resp_body(buf, 0, n));  (* no copy of [buf]! *)
              resp # send (`Resp_action feed)      (* loop *)
            )
            else (
              resp # send `Resp_end;
              close_fd()
            )
          with
          | Unix.Unix_error((Unix.EAGAIN | Unix.EWOULDBLOCK),_,_) ->
              assert false  (* Cannot happen when [fd] is a file! *)
          | Unix.Unix_error(Unix.EINTR, _, _) ->
              feed()
          | Unix.Unix_error _ as e ->
              (* Do not leak the descriptor on a read error. *)
              close_fd();
              raise e
        )
    | `Processed | `Error | `Dropped ->
        close_fd()
  in
  resp # send (`Resp_action feed)
;;

type request_line = http_method * protocol

(* Tokens produced by the request parser. *)
type req_token =
  [ `Req_header of request_line * http_header * http_response
  | `Req_expect_100_continue
  | `Req_body of data_chunk
  | `Req_trailer of http_trailer
  | `Req_end
  | `Eof
  | `Fatal_error of fatal_error
  | `Bad_request_error of bad_request_error * http_response
  | `Timeout
  ]

(* Debug representation of a request token; body data is truncated to 20
 * bytes.
 *)
let string_of_req_token = function
  | `Req_header(((req_method, req_uri), _req_proto), _hd, _resp) ->
      sprintf "Req_header(%s %s)" req_method req_uri
  | `Req_expect_100_continue ->
      "Req_expect_100_continue"
  | `Req_body(s,pos,len) ->
      let n = min len 20 in
      sprintf "Req_body_data(%s%s)"
        (String.escaped (String.sub s pos n))
        (if len > n then "..." else "")
  | `Req_trailer _ ->
      "Req_trailer"
  | `Req_end ->
      "Req_end"
  | `Eof ->
      "Eof"
  | `Fatal_error e ->
      sprintf "Fatal_error(%s)" (string_of_fatal_error e)
  | `Bad_request_error (e, _resp) ->
      sprintf "Bad_request_error(%s)" (string_of_bad_request_error e)
  | `Timeout ->
      "Timeout"

exception Recv_queue_empty

exception Buffer_exceeded
  (* Internally used by HTTP implementation: The buffer was not large enough for the
   * current token
   *)

exception Timeout
  (* Internally used by HTTP implementation: socket blocks for too long *)

exception Fatal_error of fatal_error
  (* Internally used by HTTP implementation: Indicate fatal error *)

exception Bad_request of bad_request_error
  (* Internally used by HTTP implementation: Indicate bad request *)

class type http_protocol_config =
object
  method config_max_reqline_length : int
  method config_max_header_length : int
  method config_max_trailer_length : int
  method config_limit_pipeline_length : int
  method config_limit_pipeline_size : int
  method config_announce_server : announcement
  method config_suppress_broken_pipe : bool
  method config_tls : Netsys_crypto_types.tls_config option
end
class type http_protocol_hooks = object method tls_set_cache : store:(string -> string -> unit) -> remove:(string -> unit) -> retrieve:(string -> string) -> unit end let http_find_line_start s pos len = try Netmime_string.find_line_start s pos len with Not_found -> raise Buffer_exceeded let http_find_double_line_start s pos len = try Netmime_string.find_double_line_start s pos len with Not_found -> raise Buffer_exceeded let parse_req_line s pos len = (* Parses a request line: "WORD WORD WORD\r\n", where \r is optional. The words are separated by spaces or TABs. Once we did this with a regexp, but this caused stack overflows in the regexp interpreter for long request lines. Raises Not_found if not parseable. We don't support HTTP/0.9. One could recognize it easily because the third word is missing. *) let e = pos+len in let rec next_sep p = if p >= e then raise Not_found else let c = s.[p] in if c = ' ' || c = '\t' then p else if c = '\r' || c = '\n' then raise Not_found else next_sep(p+1) in let rec next_end p = if p >= e then raise Not_found else let c = s.[p] in if c = '\r' || c = '\n' then p else next_end(p+1) in let rec skip_sep p = if p >= e then raise Not_found else let c = s.[p] in if c = ' ' || c = '\t' then skip_sep(p+1) else p in let p1 = next_sep pos in let q1 = skip_sep p1 in let p2 = next_sep q1 in let q2 = skip_sep p2 in let p3 = next_end q2 in if s.[p3] = '\n' then ( if p3+1 <> e then raise Not_found ) else ( if s.[p3] <> '\r' then raise Not_found; if p3+2 <> e then raise Not_found; if s.[p3+1] <> '\n' then raise Not_found; ); let w1 = String.sub s pos (p1-pos) in let w2 = String.sub s q1 (p2-q1) in let w3 = String.sub s q2 (p3-q2) in (w1, w2, w3) let is_hex = function | '0'..'9' | 'a'..'f' | 'A'..'F' -> true | _ -> false let parse_chunk_header s pos len = (* Parses "HEXNUMBER OPTIONAL_SEMI_AND_IGNORED_EXTENSION CRLF", or raises Not_found. 
*) let e = pos+len in let rec skip_hex_number p = if p >= e then raise Not_found else let c = s.[p] in if is_hex c then skip_hex_number(p+1) else p in let p1 = skip_hex_number pos in if p1 = pos || p1 >= e then raise Not_found; let c1 = s.[p1] in if c1=';' then ( let p2 = Netmime_string.find_line_start s p1 (e - p1) in if p2 <> e then raise Not_found ) else ( if c1 = '\n' then ( if p1+1 <> e then raise Not_found ) else ( if p1+2 <> e then raise Not_found; if s.[p1+1] <> '\n' then raise Not_found ) ); String.sub s pos (p1-pos) type cont = [ `Continue of unit -> cont | `Restart | `Restart_with of unit -> cont ] module StrSet = Set.Make(String) class http_protocol (config : #http_protocol_config) (fd : Unix.file_descr) = let pa = Netsys_posix.create_poll_array 1 in let fdi = Netsys.int64_of_file_descr fd in let tls = match config # config_tls with | None -> None | Some tc -> Some(Netsys_tls.endpoint (Netsys_tls.create_file_endpoint ~role:`Server ~rd:fd ~wr:fd ~peer_name:None tc)) in let tls_message code = match config # config_tls with | None -> code | Some tc -> let module Config = (val tc : Netsys_crypto_types.TLS_CONFIG) in Config.TLS.error_message code in let hooks = ( object method tls_set_cache ~store ~remove ~retrieve = match tls with | None -> () | Some ep -> let module Endpoint = (val ep : Netsys_crypto_types.TLS_ENDPOINT) in Endpoint.TLS.set_session_cache ~store ~remove ~retrieve Endpoint.endpoint end ) in object(self) val mutable override_dir = None (* For TLS: can be set to [Some `R] or [Some `W] if the descriptor needs to be read or written *) val mutable tls_handshake = (tls <> None) (* Whether the handshake is not yet complete *) val mutable tls_shutdown = false (* Whether a TLS shutdown (for sending) needs to be done *) val mutable tls_shutdown_done = false (* Whether the TLS shutdown (for sending) is over *) val mutable tls_session_props = None (* Session properties (available after handshake) *) val mutable resp_queue = Queue.create() (* The queue of 
[http_response] objects. The first is currently being transmitted *) val mutable recv_buf = Netbuffer.create 8192 (* The buffer of received data that have not yet been tokenized *) val mutable recv_eof = false (* Whether EOF has been seen. This is also set if the protocol engine is no * longer interested in any input because of processing errors *) val mutable recv_fd_eof = false (* Whether the descriptor is at EOF. This can be different from recv_eof if TLS is active. *) val mutable recv_queue = (Queue.create() : (req_token * int) Queue.t) (* The queue of received tokens. The integer is the estimated buffer size *) val mutable recv_cont = (fun () -> `Restart) (* The continuation processing the beginning of [recv_buf] *) val mutable test_coverage = StrSet.empty (* Only used for certain tests: This set contains tokens for cases the program * ran into. *) val mutable pipeline_len = 0 val mutable recv_queue_byte_size = 0 val mutable waiting_for_next_message = true (* Updated by the input acceptor as side effect *) val mutable need_linger = true (* Whether we need a lingering close to reliable close the connection from the * server side. 
*) val linger_buf = String.create 256 (* A small buffer for data thrown away *) initializer ( recv_cont <- self # accept_header 0; Unix.set_nonblock fd ) method hooks = hooks method cycle ?(block=0.0) () = override_dir <- None; try (* Block until we have something to read or write *) if block <> 0.0 then self # block block; (* Maybe the TLS handshake is in progress: *) if tls_handshake then self # do_tls_handshake() else if tls_shutdown && not tls_shutdown_done then self # do_tls_shutdown() else ( (* Accept any arriving data, and process that *) self # accept_data(); (* Transmit any outgoing data *) self # transmit_response(); ) with | Fatal_error e -> dlog (sprintf "FD %Ld: fatal error" fdi); self # abort e; | Timeout -> dlog (sprintf "FD %Ld: timeout" fdi); self # timeout() | Bad_request e -> (* Stop only the input side of the engine! *) dlog (sprintf "FD %Ld: bad request" fdi); self # stop_input_acceptor(); let resp = new http_response_impl ~close:true fdi (`Http((1,0),[])) config#config_announce_server in self # push_recv (`Bad_request_error(e, resp), 0); self # push_recv (`Eof, 0); resp # set_state `Queued; (* allow response from now on *) Queue.push resp resp_queue method private block d = (* If d < 0 wait for undefinite time. If d >= 0 wait for a maximum of d seconds. * On expiration, raise [Timeout]. *) dlogr (fun () -> sprintf "FD %Ld: block %f" fdi d); let f_input = self#do_input in let f_output = self#do_output in if not f_input && not f_output then raise Timeout; Netsys_posix.set_poll_cell pa 0 { Netsys_posix.poll_fd = fd; poll_req_events = Netsys_posix.poll_req_events f_input f_output false; poll_act_events = Netsys_posix.poll_null_events() }; let t = Unix.gettimeofday() in try let n = Netsys_posix.poll pa 1 d in if n = 0 then raise Timeout; (* Check for error: *) let c = Netsys_posix.get_poll_cell pa 0 in let have_error = Netsys_posix.poll_err_result c.Netsys_posix.poll_act_events in if have_error then ( (* Now find out which error. 
Unfortunately, a simple Unix.read on the socket seems not to work. *) (* -- interpreting POLL_HUP here is difficult. Generally, POLL_HUP can also be set when the other side sends EOF, i.e. a harmless condition if Netsys_posix.poll_hup_result c.Netsys_posix.poll_act_events then raise(Fatal_error `Broken_pipe); *) let code = Unix.getsockopt_int fd Unix.SO_ERROR in let error = Netsys.unix_error_of_code code in let exn_arg = match error with | Unix.EPIPE | Unix.ECONNRESET -> `Broken_pipe | _ -> `Unix_error error in raise(Fatal_error exn_arg) ) with Unix.Unix_error(Unix.EINTR,_,_) -> dlog (sprintf "FD %Ld: block: EINTR" fdi); if d < 0.0 then self # block d else ( let t' = Unix.gettimeofday() in self # block (max 0.0 (t' -. t)) ) method private case name = test_coverage <- StrSet.add name test_coverage method test_coverage = StrSet.elements test_coverage (* ---- TLS ---- *) method private do_tls_handshake() = try ( match tls with | None -> () | Some t -> Netsys_tls.handshake t; tls_handshake <- false; ) with | Netsys_types.EAGAIN_RD -> dlogr (fun () -> sprintf "FD %Ld: handshake EAGAIN_RD" fdi); override_dir <- Some `R | Netsys_types.EAGAIN_WR -> dlogr (fun () -> sprintf "FD %Ld: handshake EAGAIN_WR" fdi); override_dir <- Some `W | Unix.Unix_error(Unix.EINTR,_,_) -> dlogr (fun () -> sprintf "FD %Ld: handshake EINTR" fdi) | Netsys_types.TLS_error code as e -> dlogr (fun () -> sprintf "FD %Ld: handshake TLS_ERROR %s" fdi (Netexn.to_string e)); self # abort(`TLS_error(code,tls_message code)) method private do_tls_shutdown() = try ( match tls with | None -> () | Some t -> Netsys_tls.shutdown t Unix.SHUTDOWN_SEND ); tls_shutdown_done <- true with | Netsys_types.EAGAIN_RD -> dlogr (fun () -> sprintf "FD %Ld: handshake EAGAIN_RD" fdi); override_dir <- Some `R | Netsys_types.EAGAIN_WR -> dlogr (fun () -> sprintf "FD %Ld: handshake EAGAIN_WR" fdi); override_dir <- Some `W | Unix.Unix_error(Unix.EINTR,_,_) -> dlogr (fun () -> sprintf "FD %Ld: handshake EINTR" fdi) | 
Netsys_types.TLS_error code as e -> dlogr (fun () -> sprintf "FD %Ld: handshake TLS_ERROR %s" fdi (Netexn.to_string e)); self # abort (`TLS_error(code,tls_message code)) (* ---- Process received data ---- *) method private stop_input_acceptor() = recv_cont <- (fun () -> `Restart); recv_eof <- true method private accept_data () = (* Check whether the socket is ready and we can receive input data. New data * are appended to [recv_buf]. Continue with the next acceptor *) let continue = try if recv_eof then ( if need_linger then ( (* Try to linger in the background... *) dlogr (fun () -> sprintf "FD %Ld: lingering for remaining input" fdi); let n = Unix.recv fd linger_buf 0 (String.length linger_buf) [] in (* or Unix_error *) if n=0 then need_linger <- false (* that's it! *) ); false (* no new data *) ) else ( if self # do_input then ( dlogr (fun () -> sprintf "FD %Ld: recv" fdi); let n = Netbuffer.add_inplace (* or Unix_error *) recv_buf (fun s pos len -> match tls with | None -> Unix.recv fd s pos len [] | Some t -> Netsys_tls.recv t s pos len ) in if n=0 then ( recv_eof <- true; ( match tls with | None -> recv_fd_eof <- true; need_linger <- false | Some t -> recv_fd_eof <- Netsys_tls.at_transport_eof t; tls_shutdown <- true; ); dlog (sprintf "FD %Ld: got EOF (fd_eof=%B)" fdi recv_fd_eof); ) else dlogr (fun () -> sprintf "FD %Ld: got %d bytes" fdi n); true ) else false (* no new data *) ) with | Unix.Unix_error((Unix.EAGAIN | Unix.EWOULDBLOCK), _,_) | Netsys_types.EAGAIN_RD -> dlogr (fun () -> sprintf "FD %Ld: recv EWOULDBLOCK" fdi); false (* socket not ready *) | Unix.Unix_error(Unix.EINTR,_,_) -> dlogr (fun () -> sprintf "FD %Ld: recv EINTR" fdi); false (* got signal *) | Unix.Unix_error(Unix.ECONNRESET, _,_) -> dlogr (fun () -> sprintf "FD %Ld: recv ECONNRESET" fdi); self # abort `Broken_pipe; false | Unix.Unix_error(e, _, _) -> dlogr (fun () -> sprintf "FD %Ld: recv ERROR" fdi); self # abort (`Unix_error e); false | Netsys_types.EAGAIN_WR -> (* Currently 
impossible *) assert false | Netsys_types.TLS_error code as e -> dlogr (fun () -> sprintf "FD %Ld: rev TLS_ERROR %s" fdi (Netexn.to_string e)); self # abort(`TLS_error(code,tls_message code)); false in if continue then self # accept_loop() method private accept_loop () = let next_cont_opt = recv_cont() in ( match next_cont_opt with | `Continue next_cont -> recv_cont <- next_cont; self # accept_loop() | `Restart -> () (* Stop here for now, restart the last function the next time *) | `Restart_with next_cont -> recv_cont <- next_cont (* Stop here, too, but use [next_cont] the next time *) ) (* The following methods are only called by [accept_loop] when [recv_cont] * is set to one of the methods. They return [`Continue] when another method * should continue to parse the message. They return [`Restart] when they could not * finish parsing, and they want to be called again the next time [cycle] is * invoked. They return [`Restart_with] when they could not finish parsing, * but another method is scheduled for the next time. * * Programming rules: * - When the buffer boundary is hit (often indicated by [Buffer_exceeded]), * check on EOF. Furthermore, delete processed parts from the input buffer. * * - After parts of the buffer have been deleted, [`Restart] must not be returned * (the position would be wrong) *) method private accept_header pos () : cont = (* Check whether the beginning of [recv_buf] contains the full request line and * the full header. If so, process that. If not, still check on premature EOF. *) (* (1) Skip any CRLF sequences at the beginning * (2) Check there is at least one character * (2a) Try to parse the request line * (3) Search the next occurence of CRLF CRLF * CHECK: Maybe use a faster algorithm, e.g. Knuth-Morris-Pratt * (4) Try to parse this block * (5) Create the corresponding response object, and put the token onto the queue * (6) Go on with body parsing * * If we ever hit the bounding of the buffer, raise Buffer_exceeded. 
This means * we don't have the header block yet. *) #ifdef Testing self # case "accept_header"; #endif waiting_for_next_message <- true; let l = Netbuffer.length recv_buf in let s = Netbuffer.unsafe_buffer recv_buf in let block_start = Netmime_string.skip_line_ends s pos (l - pos) in (* (1) *) try (* (2) *) if block_start = l || (block_start+1 = l && s.[block_start] = '\013') then ( #ifdef Testing self # case "accept_header/1"; #endif raise Buffer_exceeded; ); (* (2a) *) let reqline_end = try http_find_line_start s block_start (l - block_start) with Buffer_exceeded -> #ifdef Testing self # case "accept_header/reqline_ex"; #endif waiting_for_next_message <- false; if l-block_start > config#config_max_reqline_length then raise (Bad_request `Request_line_too_long); raise Buffer_exceeded in if reqline_end-block_start > config#config_max_reqline_length then raise (Bad_request `Request_line_too_long); waiting_for_next_message <- false; let ((meth,uri),req_version) as request_line = try let (meth, uri, proto_s) = parse_req_line s block_start (reqline_end - block_start) in (* or Not_found *) #ifdef Testing self # case "accept_header/4"; #endif let proto = protocol_of_string proto_s in ( match proto with | `Http((1,_),_) -> () | _ -> raise (Bad_request `Protocol_not_supported) ); ((meth, uri), proto) with | Not_found -> (* This is a bad request. 
Response should be "Bad Request" *) #ifdef Testing self # case "accept_header/3"; #endif raise (Bad_request `Bad_request_line) in (* (3) *) let config_max_header_length = config # config_max_header_length in let block_end = try http_find_double_line_start s block_start (l - block_start) with Buffer_exceeded -> #ifdef Testing self # case "accept_header/2"; #endif if l-block_start > config_max_header_length then raise (Fatal_error `Message_too_long); raise Buffer_exceeded in if block_end - block_start > config_max_header_length then raise (Fatal_error `Message_too_long); (* (4) *) (* For simplicity, we create an in_obj_channel reading the portion of the * buffer. *) let ch = new Netchannels.input_string ~pos:reqline_end ~len:(block_end - reqline_end) s in let str = new Netstream.input_stream ch in (* TODO: This is quite expensive. Create a new netstream class for cheaper access * in this case where we only read from a constant string. *) let req_h = try Netmime_channels.read_mime_header str with Failure _ -> #ifdef Testing self # case "accept_header/5"; #endif raise(Bad_request `Bad_header) in (* TLS: check whether Host header (if set) equals the SNI host name *) ( match self#tls_session_props with | None -> () | Some props -> ( try let host_hdr = match req_h # multiple_field "host" with | [ host_hdr ] -> host_hdr | [] -> raise Not_found | _ -> raise(Bad_request (`Bad_header_field "Host")) in let host_name, _ = try Nethttp.split_host_port host_hdr with _ -> host_hdr, None in if not (Nettls_support.is_endpoint_host host_name props) then raise(Bad_request (`Bad_header_field "Host")) with Not_found -> (* No "Host" header => no checks *) () ) ); (* (5) *) let close = match req_version with | `Http((1,0),_) -> false (* Ignore "Connection" header *) | `Http((1,n),_) when n >= 1 -> (try List.mem "close" (get_connection req_h) with Not_found -> false) | _ -> false in let suppress_body = (meth = "HEAD") in let resp = new http_response_impl ~close ~suppress_body fdi (snd 
request_line) config#config_announce_server in
      self # push_recv (`Req_header (request_line, req_h, resp),
                        block_end-block_start);
      Queue.push resp resp_queue;
      (* (6) *)
      `Continue(self # accept_body_start meth req_version req_h resp block_end)
    with
      | Buffer_exceeded ->
#ifdef Testing
          self # case "accept_header/exceeded";
#endif
          if recv_eof then (
            if l = block_start then (
              (* Regular EOF *)
#ifdef Testing
              self # case "accept_header/regeof";
#endif
              self # push_recv (`Eof, 0);
              `Restart_with (fun () -> `Restart)
            )
            else (
#ifdef Testing
              self # case "accept_header/eof";
#endif
              raise(Bad_request `Unexpected_eof)
            )
          )
          else (
#ifdef Testing
            self # case "accept_header/restart";
#endif
            Netbuffer.delete recv_buf 0 block_start;
            `Restart_with (self # accept_header 0)
          )

  method private accept_body_start meth req_version req_h resp pos () : cont =
    (* Parse the start of the body at byte [pos] of [recv_buf]. This function
     * only checks the transfer encoding, and passes over to
     * [accept_body_identity] or [accept_body_chunked].
     *
     * For HTTP/1.x (x >= 1) requests that expect a "100 Continue" (explicit
     * "Expect: 100-continue", or an HTTP/1.1 POST/PUT), the token
     * [`Req_expect_100_continue] is queued and an [`Inhibited] response is
     * switched to [`Queued].
     *
     * Raises [Bad_request (`Bad_header_field _)] when one of the inspected
     * header fields is malformed.
     *)
#ifdef Testing
    self # case "accept_body_start";
#endif
    let is_http_1_1 = function
      | `Http((1,1),_) -> true
      | _ -> false in
    try
      ( match req_version with
          | `Http((1,n),_) when n>=1 ->
              let expect_list =
                try get_expect req_h with Not_found -> [] in
              let rfc2068_expect =
                (is_http_1_1 req_version && (meth = "POST" || meth = "PUT")) in
              let rfc2616_expect =
                List.exists (fun (tok,_,_) -> tok = "100-continue") expect_list in
              if rfc2068_expect || rfc2616_expect then (
#ifdef Testing
                self # case "accept_body_start/100-continue";
#endif
                self # push_recv (`Req_expect_100_continue, 0);
                if resp#state = `Inhibited then
                  resp # set_state `Queued  (* allow response from now on *)
              )
          | _ -> ()
      );
      let enc_list =
        try get_transfer_encoding req_h with Not_found -> [] in
      let chunked_encoding =
        (* The RFC talks about "non-identity transfer encoding"... *)
        match enc_list with
          | [] | ["identity",_] -> false
          | _ -> true in
      if chunked_encoding then (
#ifdef Testing
        self # case "accept_body_start/chunked";
#endif
        `Continue (self # accept_body_chunked req_h resp pos)
      )
      else (
        let remaining_length =
          try Some(get_content_length req_h) with Not_found -> None in
        (* A negative Content-length would make [accept_body_identity]
         * continue with [accept_header] at a position before [pos], i.e.
         * inside data that has already been parsed. Reject it like any
         * other malformed header field (converted to [Bad_request] below).
         *)
        ( match remaining_length with
            | Some n when n < 0L -> raise(Bad_header_field "Content-length")
            | _ -> ()
        );
#ifdef Testing
        if remaining_length = None then
          self # case "accept_body_start/empty"
        else
          self # case "accept_body_start/identity";
#endif
        `Continue (self # accept_body_identity req_h resp pos remaining_length)
      )
    with
        Bad_header_field name ->
#ifdef Testing
          self # case "accept_body_start/bad_header_field";
#endif
          raise(Bad_request (`Bad_header_field name))

  method private accept_body_identity req_h resp pos remaining_length () : cont =
    (* Accept a body with no transfer encoding. The body continues at byte [pos].
     * In [remaining_length], the number of missing bytes is remembered until
     * the body is complete. [None] means there was neither [Content-length] nor
     * [Transfer-Encoding], so the body is empty (e.g. for GET).
     *)
#ifdef Testing
    self # case "accept_body_identity";
#endif
    let l = Netbuffer.length recv_buf in
    match remaining_length with
      | Some rl ->
          let have_length = Int64.of_int (l - pos) in
          let take_length = min rl have_length in
          let n = Int64.to_int take_length in
          if n > 0 then
            self # push_recv (`Req_body(Netbuffer.sub recv_buf pos n, 0, n), n);
          let rl' = Int64.sub rl take_length in
          if rl' > 0L then (
#ifdef Testing
            self # case "accept_body_identity/exceeded";
#endif
            (* We hit the buffer boundary *)
            if recv_eof then (
              (* This request was prematurely terminated by EOF. Simply drop it. *)
#ifdef Testing
              self # case "accept_body_identity/eof";
#endif
              raise(Bad_request `Unexpected_eof)
            )
            else (
              (* Need to read the remaining part of the request. Everything
               * in the buffer has been consumed, so it can be cleared: *)
#ifdef Testing
              self # case "accept_body_identity/restart";
#endif
              Netbuffer.clear recv_buf;
              `Restart_with(self # accept_body_identity req_h resp 0 (Some rl'))
            )
          )
          else (
            (* This was the last part of the message. *)
#ifdef Testing
            self # case "accept_body_identity/last";
#endif
            self # push_recv (`Req_end, 0);
            pipeline_len <- pipeline_len + 1;
            if resp#state = `Inhibited then
              resp # set_state `Queued;  (* allow response from now on *)
            `Continue(self # accept_header (pos+n))
          )
      | None ->
          self # push_recv (`Req_end, 0);
          pipeline_len <- pipeline_len + 1;
          if resp#state = `Inhibited then
            resp # set_state `Queued;  (* allow response from now on *)
          `Continue(self # accept_header pos)

  method private accept_body_chunked req_h resp pos () : cont =
    (* Check for a chunk header at byte position [pos]. If complete, parse the
     * number of bytes the chunk will consist of, and continue with
     * [accept_body_chunked_contents] (or [accept_body_chunked_end] for the
     * zero-size last chunk).
     *)
#ifdef Testing
    self # case "accept_body_chunked";
#endif
    let l = Netbuffer.length recv_buf in
    let s = Netbuffer.unsafe_buffer recv_buf in
    try
      let p = http_find_line_start s pos (l - pos) (* or Buffer_exceeded *) in
      ( try
          let hex_digits = parse_chunk_header s pos (p-pos) in (* or Not_found *)
          let chunk_length =
            try
              let n = int_of_string("0x" ^ hex_digits) in
              (* [int_of_string] reads hexadecimal input in the unsigned
               * range and wraps values above [max_int] to negative numbers.
               * A negative length must not be mistaken for the terminating
               * zero-size chunk, so treat it like an overflow: *)
              if n < 0 then failwith "negative chunk length";
              n
            with Failure _ ->
#ifdef Testing
              self # case "accept_body_chunked/ch_large";
#endif
              raise(Bad_request (`Format_error "Chunk too large")) in
          (* Continue with chunk data or chunk end *)
          if chunk_length > 0 then (
#ifdef Testing
            self # case "accept_body_chunked/go_on";
#endif
            `Continue(self # accept_body_chunked_contents req_h resp p chunk_length)
          )
          else (
#ifdef Testing
            self # case "accept_body_chunked/end";
#endif
            `Continue(self # accept_body_chunked_end req_h resp p)
          )
        with
            Not_found ->
#ifdef Testing
              self # case "accept_body_chunked/invalid_ch";
#endif
              raise(Bad_request (`Format_error "Invalid chunk"))
      )
    with
      | Buffer_exceeded ->
#ifdef Testing
          self # case "accept_body_chunked/exceeded";
#endif
          if recv_eof then (
#ifdef Testing
            self # case "accept_body_chunked/eof";
#endif
            raise(Bad_request `Unexpected_eof);
          );
          if pos > 0 then
            Netbuffer.delete recv_buf 0 pos;
          (* An incomplete chunk header longer than this is rejected: *)
          if Netbuffer.length recv_buf > 500 then (
#ifdef Testing
            self # case "accept_body_chunked/ch_hdr_large";
#endif
            raise(Bad_request (`Format_error "Chunk header too large"));
          );
#ifdef Testing
          self # case "accept_body_chunked/restart";
#endif
          `Restart_with(self # accept_body_chunked req_h resp 0)

  method private accept_body_chunked_contents req_h resp pos remaining_length () : cont =
    (* Read the chunk body at [pos], at most [remaining_length] bytes. When
     * [remaining_length] is 0, the line terminator after the chunk data is
     * expected at [pos].
     *)
#ifdef Testing
    self # case "accept_body_chunked_contents";
#endif
    let l = Netbuffer.length recv_buf in
    let s = Netbuffer.unsafe_buffer recv_buf in
    if remaining_length > 0 then (
#ifdef Testing
      self # case "accept_body_chunked_contents/data";
#endif
      (* There are still data to read *)
      let have_length = l - pos in
      let take_length = min have_length remaining_length in
      let rem_length' = remaining_length - take_length in
      if take_length > 0 then
        self # push_recv
          (`Req_body(Netbuffer.sub recv_buf pos take_length, 0, take_length),
           take_length);
      if take_length = remaining_length then
        `Continue
          (self # accept_body_chunked_contents req_h resp (pos+take_length) 0)
      else (
#ifdef Testing
        self # case "accept_body_chunked_contents/exceeded";
#endif
        (* We hit the buffer boundary. Delete the buffer *)
        if recv_eof then (
#ifdef Testing
          self # case "accept_body_chunked_contents/eof";
#endif
          raise(Bad_request `Unexpected_eof);
        );
        Netbuffer.clear recv_buf;
#ifdef Testing
        self # case "accept_body_chunked_contents/restart";
#endif
        `Restart_with
          (self # accept_body_chunked_contents req_h resp 0 rem_length')
      )
    )
    else (
#ifdef Testing
      self # case "accept_body_chunked_contents/end";
#endif
      (* End of chunk reached. There must be a (single) line end at the end
       * of the chunk (LF or CR LF). *)
      if (l > pos && s.[pos] = '\010') then (
#ifdef Testing
        self # case "accept_body_chunked_contents/lf";
#endif
        `Continue(self # accept_body_chunked req_h resp (pos+1))
      )
      else if (l > pos+1 && s.[pos] = '\013' && s.[pos+1] = '\010') then (
#ifdef Testing
        self # case "accept_body_chunked_contents/crlf";
#endif
        `Continue(self # accept_body_chunked req_h resp (pos+2))
      )
      else if l > pos+1 then (
#ifdef Testing
        self # case "accept_body_chunked_contents/no_eol";
#endif
        raise (Bad_request
                 (`Format_error "Chunk not followed by line terminator"))
      )
      else (
#ifdef Testing
        self # case "accept_body_chunked_contents/e_exceeded";
#endif
        (* We hit the buffer boundary *)
        if recv_eof then (
#ifdef Testing
          self # case "accept_body_chunked_contents/e_eof";
#endif
          raise(Bad_request `Unexpected_eof);
        );
        Netbuffer.delete recv_buf 0 pos;
#ifdef Testing
        self # case "accept_body_chunked_contents/e_restart";
#endif
        `Restart_with (self # accept_body_chunked_contents req_h resp 0 0)
      )
    )

  method private accept_body_chunked_end req_h resp pos () : cont =
    (* Read the trailer that follows the zero-size last chunk, starting at
     * [pos]. Either an empty line (LF or CR LF) or a header block terminated
     * by an empty line is expected. The trailer is limited to
     * [max 2 config_max_trailer_length] bytes.
     *)
#ifdef Testing
    self # case "accept_body_chunked_end";
#endif
    let l = Netbuffer.length recv_buf in
    let s = Netbuffer.unsafe_buffer recv_buf in
    let config_max_trailer_length = max 2 (config # config_max_trailer_length) in
    try
      (* Check if there is a trailer *)
      if l > pos && s.[pos] = '\010' then (
#ifdef Testing
        self # case "accept_body_chunked_end/lf";
#endif
        self # push_recv (`Req_end, 0);
        pipeline_len <- pipeline_len + 1;
        if resp#state = `Inhibited then
          resp # set_state `Queued;  (* allow response from now on *)
        `Continue(self # accept_header (pos+1))
      )
      else if l > pos+1 && s.[pos] = '\013' && s.[pos+1] = '\010' then (
#ifdef Testing
        self # case "accept_body_chunked_end/crlf";
#endif
        self # push_recv (`Req_end, 0);
        pipeline_len <- pipeline_len + 1;
        if resp#state = `Inhibited then
          resp # set_state `Queued;  (* allow response from now on *)
        `Continue(self # accept_header (pos+2))
      )
      else (
#ifdef Testing
        self # case "accept_body_chunked_end/trailer";
#endif
        (* Assume there is a trailer. *)
        let trailer_end =
          http_find_double_line_start s pos (l-pos) in (* or Buffer_exceeded *)
#ifdef Testing
        self # case "accept_body_chunked_end/tr_found";
#endif
        (* Now we are sure there is a trailer! *)
        if trailer_end - pos > config_max_trailer_length then (
#ifdef Testing
          self # case "accept_body_chunked_end/tr_long";
#endif
          raise(Bad_request (`Format_error "Trailer too long"));
        );
        let ch =
          new Netchannels.input_string ~pos:pos ~len:(trailer_end - pos) s in
        let str = new Netstream.input_stream ch in
        let req_tr =
          try Netmime_channels.read_mime_header str
          with Failure _ ->
#ifdef Testing
            self # case "accept_body_chunked_end/bad_tr";
#endif
            raise(Bad_request `Bad_trailer) in
        self # push_recv (`Req_trailer req_tr, trailer_end-pos);
        self # push_recv (`Req_end, 0);
        pipeline_len <- pipeline_len + 1;
        if resp#state = `Inhibited then
          resp # set_state `Queued;  (* allow response from now on *)
        `Continue(self # accept_header trailer_end)
      )
    with
        Buffer_exceeded ->
#ifdef Testing
          self # case "accept_body_chunked_end/exceeded";
#endif
          if recv_eof then (
#ifdef Testing
            self # case "accept_body_chunked_end/eof";
#endif
            raise(Bad_request `Unexpected_eof);
          );
          if l-pos > config_max_trailer_length then (
#ifdef Testing
            self # case "accept_body_chunked_end/tr_long";
#endif
            raise(Bad_request (`Format_error "Trailer too long"));
          );
          Netbuffer.delete recv_buf 0 pos;
#ifdef Testing
          self # case "accept_body_chunked_end/restart";
#endif
          `Restart_with (self # accept_body_chunked_end req_h resp 0)

  (* ---- Process responses ---- *)

  method resp_queue_len =
    Queue.length resp_queue

  method private transmit_response() =
    (* Try to transmit the response. Do nothing if the socket is not ready for
     * transmission. If a fatal error happens, the connection is aborted.
     *)
    if not (Queue.is_empty resp_queue) then
      let resp = Queue.peek resp_queue in
      try
        resp # set_state `Active;
        match resp # front_token with  (* or Send_queue_empty, Unix_error *)
          | `Resp_wire_data (s,pos,len) ->
              (* Try to write: *)
              dlogr (fun () -> sprintf "FD %Ld: send" fdi);
              let n =
                match tls with
                  | None ->
                      Unix.send fd s pos len []  (* or Unix.error *)
                  | Some t ->
                      Netsys_tls.send t s pos len in
              dlogr (fun () -> sprintf "FD %Ld: sent %d bytes" fdi n);
              (* Successful. Advance by [n] *)
              resp # advance n
          | `Resp_end ->
              dlogr (fun () ->
                       sprintf "FD %Ld: found Resp_end in the queue" fdi);
              pipeline_len <- pipeline_len - 1;
              resp # set_state `Processed;
              (* Check if we have to close the connection: *)
              if resp # close_connection then (
                match tls with
                  | None -> self # shutdown_sending()
                  | Some _ -> tls_shutdown <- true
              );
              (* Remove [resp] - it is the head of the queue, as returned by
               * [Queue.peek] above - and continue with the next response,
               * if any, and if possible: *)
              ignore(Queue.take resp_queue);
              if not (Queue.is_empty resp_queue) then
                (Queue.peek resp_queue) # set_state `Active;  (* ... unless dropped *)
              (* If the queue is still non-empty, and if the connection is
               * closed, the remaining, already computed responses cannot be
               * sent at all. Drop the responses in this case.
               *)
              if resp # close_connection then
                self # drop_remaining_responses()
      with
        | Send_queue_empty ->
            ()  (* nothing to do *)
        | Unix.Unix_error((Unix.EAGAIN | Unix.EWOULDBLOCK),_,_)
        | Netsys_types.EAGAIN_WR ->
            ()  (* socket not ready *)
        | Unix.Unix_error(Unix.EINTR,_,_) ->
            ()  (* Signal happened, try later again *)
        | Unix.Unix_error((Unix.EPIPE | Unix.ECONNRESET), _,_) ->
            resp # set_state `Error;
            ignore(Queue.take resp_queue);
            self # abort `Broken_pipe
        | Unix.Unix_error(e, _, _) ->
            resp # set_state `Error;
            ignore(Queue.take resp_queue);
            self # abort (`Unix_error e)
        | Netsys_types.EAGAIN_RD ->
            (* NOTE(review): assumes Netsys_tls.send never needs to read
             * (e.g. during a TLS rehandshake) - confirm this cannot fire. *)
            assert false  (* not possible here *)
        | Netsys_types.TLS_error code as e ->
            dlogr (fun () -> sprintf "FD %Ld: rev TLS_ERROR %s"
                     fdi (Netexn.to_string e));
            self # abort (`TLS_error(code,tls_message code))

  method private drop_remaining_responses() =
    (* Set the state to [`Dropped] for all responses in the [resp_queue],
     * then empty the queue: *)
    Queue.iter
      (fun resp -> resp # set_state `Dropped)
      resp_queue;
    Queue.clear resp_queue

  (* ---- Queue management ---- *)

  method receive () =
    (* Take the next token from the receive queue.
     * Raises [Recv_queue_empty] if there is none. *)
    try
      let (tok, size) = Queue.take recv_queue in
      recv_queue_byte_size <- recv_queue_byte_size - size;
      tok
    with
        Queue.Empty -> raise Recv_queue_empty

  method peek_recv () =
    (* Like [receive], but the token stays in the queue. *)
    try
      fst(Queue.peek recv_queue)
    with
        Queue.Empty -> raise Recv_queue_empty

  method private push_recv ( (token,size) as qelem ) =
    (* Append a token; [size] is added to [recv_queue_byte_size]. *)
    Queue.push qelem recv_queue;
    recv_queue_byte_size <- recv_queue_byte_size + size

  method recv_queue_len =
    Queue.length recv_queue

  method recv_queue_byte_size =
    recv_queue_byte_size

  (* ---- Socket stuff ---- *)

  val mutable fd_down = false

  method shutdown () =
    (* The shutdown issue is discussed here:
     * http://ftp.ics.uci.edu/pub/ietf/http/draft-ietf-http-connection-00.txt
     *
     * Recommendation:
     * - Only shutdown for sending
     * - Keep receiving data for a while ("lingering"). In principle, until the client has
     *   seen the half shutdown, but we do not know when. Apache lingers for 30 seconds.
     *)
    dlogr (fun () -> sprintf "FD %Ld: shutdown" fdi);
    self # stop_input_acceptor();
    self # drop_remaining_responses();
    tls_shutdown <- true;
    if tls = None then
      self # shutdown_sending()

  method private shutdown_sending() =
    (* Half-close the socket for sending, at most once. *)
    dlogr (fun () -> sprintf "FD %Ld: shutdown_sending" fdi);
    if not fd_down then (
      try
        Unix.shutdown fd Unix.SHUTDOWN_SEND;
      with
          Unix.Unix_error(Unix.ENOTCONN,_,_) ->
            need_linger <- false  (* the peer has already shut down in the meantime *)
    );
    fd_down <- true

  method waiting_for_next_message =
    waiting_for_next_message

  method input_timeout_class : [ `Normal | `Next_message | `None ] =
    (* Do we have an active response object? In this case, it might be the case the
     * connection is output-driven, and no timeout applies:
     *)
    try
      let first = Queue.peek resp_queue in  (* or Queue.Empty *)
      if first#state <> `Active then raise Queue.Empty;
      (* If the response object is in the bidirectional phase, the normal input
       * timeout applies nevertheless.
       *)
      if first#bidirectional_phase then
        `Normal
      else
        `None
    with
        Queue.Empty ->
          if waiting_for_next_message then
            `Next_message
          else
            `Normal

  method timeout () =
    if waiting_for_next_message then (
      (* Indicate a "soft" timeout. Processing is nevertheless similar to [abort]: *)
      dlogr (fun () -> sprintf "FD %Ld: soft timeout" fdi);
      need_linger <- false;
      self # shutdown();
      self # push_recv (`Timeout, 0);
      self # push_recv (`Eof, 0);
    )
    else
      self # abort `Timeout   (* "hard" timeout *)

  method abort (err : fatal_error) =
    (* Shut the connection down without lingering and without the TLS
     * shutdown protocol, then queue [`Fatal_error] and [`Eof]. *)
    dlogr (fun () -> sprintf "FD %Ld: abort %s" fdi (string_of_fatal_error err));
    need_linger <- false;
    self # shutdown();
    tls_shutdown_done <- true;  (* don't do the TLS shutdown protocol *)
    let err' =
      if err=`Broken_pipe && config#config_suppress_broken_pipe then
        `Broken_pipe_ignore
      else
        err in
    self # push_recv (`Fatal_error err', 0);
    self # push_recv (`Eof, 0)

  method fd = fd

  method config = (config :> http_protocol_config)

  method pipeline_len = pipeline_len

  method do_input =
    match override_dir with
      | None ->
          not recv_fd_eof &&  (* CHECK: can we get further alerts after TLS END ? *)
          (pipeline_len <= config#config_limit_pipeline_length) &&
          (recv_queue_byte_size <= config#config_limit_pipeline_size)
      | Some `R -> true
      | Some `W -> false

  method do_output =
    match override_dir with
      | None -> self # resp_queue_filled
      | Some `R -> false
      | Some `W -> true

  method resp_queue_filled =
    ( not (Queue.is_empty resp_queue) &&
      not ((Queue.peek resp_queue) # send_queue_empty)
    ) || ( tls_shutdown && not tls_shutdown_done )

  method need_linger =
    (* TLS: never need lingering close, as TLS alerts indicate the end of the data stream *)
    tls = None && need_linger

  method tls_session_props =
    (* Computed lazily once the TLS handshake is over, then cached. *)
    if not tls_handshake && tls_session_props = None then (
      match tls with
        | Some t ->
            tls_session_props <- Some (Nettls_support.get_tls_session_props t)
        | None ->
            ()
    );
    tls_session_props
end


class lingering_close ?(preclose = fun () -> ()) fd =
  let fd_style = Netsys.get_fd_style fd in
object(self)
  val start_time = Unix.gettimeofday()
  val timeout = 60.0
  val junk_buffer = String.create 256
  val mutable lingering = true

  method cycle ?(block=false) () =
    (* Read and discard data from [fd]. The descriptor is closed (after
     * calling [preclose]) on EOF, on any error other than EAGAIN/EWOULDBLOCK/
     * EINTR, or when [block] is set and [timeout] seconds have elapsed since
     * the object was created. Once closed, further calls do nothing: the
     * descriptor number may already be reused by an unrelated open, so it
     * must be neither read from nor closed again.
     *)
    if lingering then (
      try
        if block then (
          let now = Unix.gettimeofday() in
          let sel_time = max (timeout -. (now -. start_time)) 0.0 in
          if sel_time = 0.0 then raise Not_found;
          ignore(Netsys.wait_until_readable fd_style fd sel_time);
          ()
        );
        let n = Unix.recv fd junk_buffer 0 (String.length junk_buffer) [] in
        if n = 0 then (
          lingering <- false;
          preclose();
          Unix.close fd
        )
      with
        | Unix.Unix_error((Unix.EAGAIN | Unix.EWOULDBLOCK), _,_)->
            ()  (* socket not ready *)
        | Unix.Unix_error(Unix.EINTR,_,_) ->
            ()  (* got signal *)
        | Unix.Unix_error(_, _,_)
        | Not_found ->
            (* Any other error means we are done! *)
            lingering <- false;
            preclose();
            Unix.close fd
    )

  method lingering = lingering

  method fd = fd
end


let default_http_protocol_config =
  ( object
      method config_max_reqline_length = 32768
      method config_max_header_length = 65536
      method config_max_trailer_length = 32768
      method config_limit_pipeline_length = 5
      method config_limit_pipeline_size = 65536
      method config_announce_server = `Ocamlnet
      method config_suppress_broken_pipe = false
      method config_tls = None
    end
  )


(* [override v opt]: the value inside [opt] if present, otherwise [v]. *)
let override v opt =
  match opt with
    | None -> v
    | Some x -> x


(* A copy of [config] in which every setting passed as an optional
 * argument replaces the corresponding value of [config]. *)
class modify_http_protocol_config
        ?config_max_reqline_length
        ?config_max_header_length
        ?config_max_trailer_length
        ?config_limit_pipeline_length
        ?config_limit_pipeline_size
        ?config_announce_server
        ?config_suppress_broken_pipe
        ?config_tls
        (config : http_protocol_config) : http_protocol_config =
  let config_max_reqline_length =
    override config#config_max_reqline_length config_max_reqline_length in
  let config_max_header_length =
    override config#config_max_header_length config_max_header_length in
  let config_max_trailer_length =
    override config#config_max_trailer_length config_max_trailer_length in
  let config_limit_pipeline_length =
    override config#config_limit_pipeline_length config_limit_pipeline_length in
  let config_limit_pipeline_size =
    override config#config_limit_pipeline_size config_limit_pipeline_size in
  let config_announce_server =
    override config#config_announce_server config_announce_server in
  let config_suppress_broken_pipe =
    override config#config_suppress_broken_pipe config_suppress_broken_pipe in
  let config_tls =
    override config#config_tls config_tls in
object
  method config_max_reqline_length = config_max_reqline_length
  method config_max_header_length = config_max_header_length
  method config_max_trailer_length = config_max_trailer_length
  method config_limit_pipeline_length = config_limit_pipeline_length
  method config_limit_pipeline_size = config_limit_pipeline_size
  method config_announce_server = config_announce_server
  method config_suppress_broken_pipe = config_suppress_broken_pipe
  method config_tls = config_tls
end