(* $Id: netmech_scram.ml 1954 2014-03-03 09:40:56Z gerd $ *)

(* SCRAM (RFC 5802) in its GSS-API flavor, restricted to SHA-1, plus the
   AES/CTS based cryptosystem used for protecting messages.

   Steps:

   client               <->               server
   ----------------------------------------------------------------------
   username, nonce       ->
                         <-               salt, i, nonce'
   clientproof, nonce'   ->
     (=algo(password, salt, i))
                         <-               serversignature
 *)

open Printf

type ptype = [ `GSSAPI ]

type mechanism = [ `SHA_1 ]

type profile =
    { ptype : ptype;
      mechanism : mechanism;
      return_unknown_user : bool;
      iteration_count_limit : int;
    }

type client_first =  (* actually client_first_bare *)
    { c1_username : string;  (* "=xx" encoding not yet applied *)
      c1_nonce : string;     (* anything but comma *)
      c1_extensions : (string * string) list
    }

type server_first =
    { s1_nonce : string;     (* anything but comma *)
      s1_salt : string;      (* decoded *)
      s1_iteration_count : int;
      s1_extensions : (string * string) list
    }

type client_final =
    { cf_chanbind : string;
      cf_nonce : string;     (* anything but comma *)
      cf_extensions : (string * string) list;
      cf_proof : string option;   (* decoded *)
    }

type server_error =
    [ `Invalid_encoding
    | `Extensions_not_supported
    | `Invalid_proof
    | `Channel_bindings_dont_match
    | `Server_does_support_channel_binding
    | `Channel_binding_not_supported
    | `Unsupported_channel_binding_type
    | `Unknown_user
    | `Invalid_username_encoding
    | `No_resources
    | `Other_error
    | `Extension of string
    ]

type server_error_or_verifier =
    [ `Error of server_error
    | `Verifier of string
    ]

type server_final =
    { sf_error_or_verifier : server_error_or_verifier;
      sf_extensions : (string * string) list;
    }

type specific_keys =
    { kc : string;
      ke : string;
      ki : string
    }

(* NB. Sessions are serialized with Marshal (see client_export and
   server_export), so the order of the fields is part of the exported
   format.
 *)

type client_session =
    { cs_profile : profile;
      mutable cs_state :
                 [ `Start | `C1 | `S1 | `CF | `SF | `Connected | `Error ];
      mutable cs_c1 : client_first option;
      mutable cs_s1 : server_first option;
      mutable cs_s1_raw : string;
      mutable cs_cf : client_final option;
      mutable cs_sf : server_final option;
      mutable cs_salted_pw : string;
      mutable cs_auth_message : string;
      mutable cs_proto_key : string option;
      cs_username : string;
      cs_password : string;
      mutable cs_chanbind : string;
    }

type server_session =
    { ss_profile : profile;
      mutable ss_state :
                 [ `Start | `C1 | `S1 | `CF | `SF | `Connected | `Error ];
      mutable ss_c1 : client_first option;
      mutable ss_c1_raw : string;
      mutable ss_s1 : server_first option;
      mutable ss_s1_raw : string;
      mutable ss_cf : client_final option;
      mutable ss_cf_raw : string;
      mutable ss_sf : server_final option;
      mutable ss_spw: string option;
      mutable ss_err : server_error option;
      mutable ss_proto_key : string option;
      ss_authenticate_opt : (string -> (string * string * int)) option;
    }

(* Exported: *)
exception Invalid_encoding of string * string
exception Invalid_username_encoding of string * string
exception Extensions_not_supported of string * string
exception Protocol_error of string
exception Invalid_server_signature
exception Server_error of server_error

(* Not exported: *)
exception Invalid_proof of string


module Debug = struct
  let enable = ref false
end

let dlog = Netlog.Debug.mk_dlog "Netmech_scram" Debug.enable
let dlogr = Netlog.Debug.mk_dlogr "Netmech_scram" Debug.enable

let () =
  Netlog.Debug.register_module "Netmech_scram" Debug.enable


(* Creates a profile for the protocol type [pt]. The mechanism is always
   `SHA_1.
 *)
let profile ?(return_unknown_user=false) ?(iteration_count_limit=100000) pt =
  { ptype = pt;
    mechanism = `SHA_1;
    return_unknown_user = return_unknown_user;
    iteration_count_limit = iteration_count_limit;
  }


(* Returns [s] unchanged if it consists only of the characters 0x20..0x7e,
   and raises Invalid_encoding otherwise.
 *)
let saslprep s =
  (* SASLprep proper is deliberately not implemented here: the ambiguities
     it resolves hardly occur in practice. The RFC says we have to limit
     the strings then to US-ASCII.
   *)
  for k = 0 to String.length s - 1 do
    let c = s.[k] in
    if c < '\x20' || c >= '\x7f' then
      raise(Invalid_encoding("Netmech_scram.saslprep: restricted to US-ASCII",
                             s));
  done;
  s


(* Same as saslprep, but failures are reported as
   Invalid_username_encoding
 *)
let username_saslprep s =
  try
    saslprep s
  with
    | Invalid_encoding(s1,s2) ->
        raise(Invalid_username_encoding(s1,s2))


let comma_re = Netstring_str.regexp ","

let comma_split s =
  Netstring_str.split_delim comma_re s

let n_value_re = Netstring_str.regexp "\\([a-zA-Z]\\)=\\(.*\\)"

(* Splits an attribute "x=value" into ("x", "value"). The name must be a
   single letter. Raises Invalid_encoding if [s] does not have this form.
 *)
let n_value_split s =
  match Netstring_str.string_match n_value_re s 0 with
    | None -> raise (Invalid_encoding("n_value_split", s))
    | Some r ->
        (Netstring_str.matched_group r 1 s,
         Netstring_str.matched_group r 2 s)

(* Checks that [s] is UTF-8 not containing NUL, "," and "=" *)
let check_value_safe_chars s =
  let enc =
    `Enc_subset(`Enc_utf8,
                fun i -> i <> 0 && i <> 0x2c && i <> 0x3d) in
  try
    Netconversion.verify enc s
  with _ -> raise(Invalid_encoding("check_value_safe_chars",s))

(* Checks that [s] is UTF-8 not containing NUL and "," *)
let check_value_chars s =
  let enc =
    `Enc_subset(`Enc_utf8,
                fun i -> i <> 0 && i <> 0x2c) in
  try
    Netconversion.verify enc s
  with _ -> raise(Invalid_encoding("check_value_chars",s))

(* Checks that [s] only consists of the characters 0x21..0x7e except the
   comma (0x2c)
 *)
let check_printable s =
  for i = 0 to String.length s - 1 do
    match s.[i] with
      | '\x21'..'\x2b' -> ()
      | '\x2d'..'\x7e' -> ()
      | _ -> raise(Invalid_encoding("check_printable",s))
  done

(* A positive decimal number without leading zeros. Single-digit numbers
   must be accepted, too (formerly the pattern demanded at least two
   digits, so that the iteration counts 1..9 were rejected).
 *)
let pos_re = Netstring_str.regexp "[1-9][0-9]*$"

let check_positive_number s =
  match Netstring_str.string_match pos_re s 0 with
    | None -> raise(Invalid_encoding("check_positive_number",s))
    | Some _ -> ()

(* The saslname encoding of RFC 5802 represents "," as "=2C" and "=" as
   "=3D" (0x3D is the code of "="). Formerly, this module escaped "/"
   instead of "=", which is not what the RFC says, and which made user
   names containing "=" impossible. The names of the two regexps are kept
   as they were.
 *)
let comma_slash_re = Netstring_str.regexp "[,=]"

let rev_comma_slash_re = Netstring_str.regexp "\\(=2C\\|=3D\\|=\\|,\\)"

(* Checks that [s] is UTF-8, and escapes "," and "=". Raises
   Invalid_username_encoding if [s] is not UTF-8.
 *)
let encode_saslname s =
  ( try
      Netconversion.verify `Enc_utf8 s
    with _ -> raise(Invalid_username_encoding("encode_saslname",s))
  );
  Netstring_str.global_substitute
    comma_slash_re
    (fun r s ->
       match Netstring_str.matched_string r s with
         | "," -> "=2C"
         | "=" -> "=3D"
         | _ -> assert false
    )
    s

(* The reverse of encode_saslname. A "=" not starting one of the two
   escape sequences, a plain ",", and a result that is not UTF-8 are
   all reported with Invalid_username_encoding.
 *)
let decode_saslname s =
  let s' =
    Netstring_str.global_substitute
      rev_comma_slash_re
      (fun r s ->
         match Netstring_str.matched_string r s with
           | "=2C" -> ","
           | "=3D" -> "="
           | "=" | "," ->
               raise(Invalid_username_encoding("decode_saslname",s))
           | _ -> assert false
      )
      s in
  ( try
      Netconversion.verify `Enc_utf8 s'
    with _ -> raise(Invalid_username_encoding("decode_saslname",s))
  );
  s'


(* client-first message: "n=<saslname>,r=<nonce>[,<extensions>]" *)
let encode_c1_message c1 =
  (* No gs2-header in GSS-API *)
  "n=" ^ encode_saslname(username_saslprep c1.c1_username) ^
  ",r=" ^ c1.c1_nonce ^
  (if c1.c1_extensions <> [] then
      "," ^
      String.concat "," (List.map (fun (n,v) -> n ^ "=" ^ v) c1.c1_extensions)
   else ""
  )


let decode_c1_message s =
  let l = List.map n_value_split (comma_split s) in
  match l with
    | [] ->
        raise(Invalid_encoding("decode_c1_mesage: empty", s))
    | ("m",_) :: _ ->
        raise(Extensions_not_supported("decode_c1_mesage: unsupported", s))
    | ("n", username_raw) :: ("r", nonce) :: l' ->
        let username = decode_saslname username_raw in
        let username' = username_saslprep username in
        if username <> username' then
          raise(Invalid_username_encoding("Netmech_scram.decode_c1_message",
                                          s));
        { c1_username = username;
          c1_nonce = nonce;
          c1_extensions = l'
        }
    | _ ->
        raise(Invalid_encoding("decode_c1_mesage", s))


(* server-first message: "r=<nonce>,s=<base64 salt>,i=<count>[,<ext>]" *)
let encode_s1_message s1 =
  "r=" ^ s1.s1_nonce ^
  ",s=" ^ Netencoding.Base64.encode s1.s1_salt ^
  ",i=" ^ string_of_int s1.s1_iteration_count ^
  ( if s1.s1_extensions <> [] then
      "," ^
      String.concat "," (List.map (fun (n,v) -> n ^ "=" ^ v) s1.s1_extensions)
    else ""
  )


let decode_s1_message s =
  let l = List.map n_value_split (comma_split s) in
  match l with
    | [] ->
        raise(Invalid_encoding("decode_s1_mesage: empty", s))
    | ("m",_) :: _ ->
        raise(Extensions_not_supported("decode_s1_mesage: unsupported", s))
    | ("r",nonce) :: ("s",salt_b64) :: ("i",icount_raw) :: l' ->
        let salt =
          try Netencoding.Base64.decode salt_b64
          with _ ->
            raise(Invalid_encoding("decode_s1_message: invalid s", s)) in
        check_positive_number icount_raw;
        let icount =
          (* may still fail if the number is too large for an int *)
          try int_of_string icount_raw
          with _ ->
            raise(Invalid_encoding("decode_s1_message: invalid i", s)) in
        { s1_nonce = nonce;
          s1_salt = salt;
          s1_iteration_count = icount;
          s1_extensions = l'
        }
    | _ ->
        raise(Invalid_encoding("decode_s1_mesage", s))


(* About the inclusion of "c": RFC 5802 is not entirely clear about this.
   I asked the authors of the RFC what to do. The idea is that the
   GSS-API flavor of SCRAM is obtained by removing the GS2 (RFC 5801)
   part from the description in RFC 5802 for SASL. This leads to the
   interpretation that the "c" parameter is required, and it includes
   the channel binding string as-is, without any prefixed gs2-header.
   (Remember that GS2 is a wrapper around GSS-API, and it can then pass
   the right channel binding string down, i.e. a string that includes
   the gs2-header.)
 *)

(* client-final message:
   "c=<base64 chanbind>,r=<nonce>[,<ext>][,p=<base64 proof>]"
 *)
let encode_cf_message cf =
  "c=" ^ Netencoding.Base64.encode cf.cf_chanbind ^
  ",r=" ^ cf.cf_nonce ^
  ( if cf.cf_extensions <> [] then
      "," ^
      String.concat "," (List.map (fun (n,v) -> n ^ "=" ^ v) cf.cf_extensions)
    else ""
  ) ^
  ( match cf.cf_proof with
      | None -> ""
      | Some p ->
          ",p=" ^ Netencoding.Base64.encode p
  )


(* If [expect_proof], the last attribute must be the proof "p". Otherwise
   all attributes after "r" are returned as extensions.
 *)
let decode_cf_message expect_proof s =
  let l = List.map n_value_split (comma_split s) in
  match l with
    | [] ->
        raise(Invalid_encoding("decode_cf_mesage: empty", s))
    | ("c",chanbind_b64) :: ("r",nonce) :: l' ->
        let chanbind =
          try Netencoding.Base64.decode chanbind_b64
          with _ ->
            raise(Invalid_encoding("decode_cf_mesage: invalid c", s)) in
        let p, l'' =
          if expect_proof then
            match List.rev l' with
              | ("p", proof_b64) :: l''_rev ->
                  let p =
                    try Netencoding.Base64.decode proof_b64
                    with _ ->
                      raise(Invalid_encoding("decode_cf_mesage: invalid p",
                                             s)) in
                  (Some p, List.rev l''_rev)
              | _ ->
                  raise(Invalid_encoding("decode_cf_mesage: proof not found",
                                         s))
          else
            None, l' in
        { cf_chanbind = chanbind;
          cf_nonce = nonce;
          cf_extensions = l'';
          cf_proof = p
        }
    | _ ->
        raise(Invalid_encoding("decode_cf_mesage", s))


(* Removes the trailing "p" attribute from a raw client-final message.
   It is an assertion failure if the last attribute is not "p", i.e. only
   call this for messages that have already been accepted by
   decode_cf_message true.
 *)
let strip_cf_proof s =
  let l = List.rev (List.map n_value_split (comma_split s)) in
  match l with
    | ("p",_) :: l' ->
        String.concat "," (List.map (fun (n,v) -> n ^ "=" ^ v) (List.rev l'))
    | _ ->
        assert false


let string_of_server_error =
  function
    | `Invalid_encoding -> "invalid-encoding"
    | `Extensions_not_supported -> "extensions-not-supported"
    | `Invalid_proof -> "invalid-proof"
    | `Channel_bindings_dont_match -> "channel-bindings-dont-match"
    | `Server_does_support_channel_binding ->
        "server-does-support-channel-binding"
    | `Channel_binding_not_supported -> "channel-binding-not-supported"
    | `Unsupported_channel_binding_type -> "unsupported-channel-binding-type"
    | `Unknown_user -> "unknown-user"
    | `Invalid_username_encoding -> "invalid-username-encoding"
    | `No_resources -> "no-resources"
    | `Other_error -> "other-error"
    | `Extension s -> s

(* Unknown tokens are mapped to `Extension *)
let server_error_of_string =
  function
    | "invalid-encoding" -> `Invalid_encoding
    | "extensions-not-supported" -> `Extensions_not_supported
    | "invalid-proof" -> `Invalid_proof
    | "channel-bindings-dont-match" -> `Channel_bindings_dont_match
    | "server-does-support-channel-binding" ->
        `Server_does_support_channel_binding
    | "channel-binding-not-supported" -> `Channel_binding_not_supported
    | "unsupported-channel-binding-type" -> `Unsupported_channel_binding_type
    | "unknown-user" -> `Unknown_user
    | "invalid-username-encoding" -> `Invalid_username_encoding
    | "no-resources" -> `No_resources
    | "other-error" -> `Other_error
    | s -> `Extension s


let () =
  Netexn.register_printer
    (Server_error `Invalid_encoding)
    (fun e ->
       match e with
         | Server_error token ->
             sprintf "Server_error(%s)" (string_of_server_error token)
         | _ -> assert false
    )


(* server-final message: "e=<error token>" or "v=<base64 verifier>",
   optionally followed by extensions
 *)
let encode_sf_message sf =
  ( match sf.sf_error_or_verifier with
      | `Error e ->
          "e=" ^ string_of_server_error e
      | `Verifier v ->
          "v=" ^ Netencoding.Base64.encode v
  ) ^
  ( if sf.sf_extensions <> [] then
      "," ^
      String.concat "," (List.map (fun (n,v) -> n ^ "=" ^ v) sf.sf_extensions)
    else ""
  )


let decode_sf_message s =
  let l = List.map n_value_split (comma_split s) in
  match l with
    | [] ->
        (* Formerly reported under the name of decode_cf_message *)
        raise(Invalid_encoding("decode_sf_mesage: empty", s))
    | ("v",verf_raw) :: l' ->
        let verf =
          try Netencoding.Base64.decode verf_raw
          with _ ->
            raise(Invalid_encoding("decode_sf_message: invalid v", s)) in
        { sf_error_or_verifier = `Verifier verf;
          sf_extensions = l'
        }
    | ("e",error_s) :: l' ->
        let error = server_error_of_string error_s in
        { sf_error_or_verifier = `Error error;
          sf_extensions = l'
        }
    | _ ->
        raise(Invalid_encoding("decode_sf_mesage", s))


let sha1 s =
  Cryptokit.hash_string (Cryptokit.Hash.sha1()) s

let hmac key str =
  Netauth.hmac
    ~h:sha1
    ~b:64
    ~l:20
    ~k:key
    ~message:str

(* [i] as 4 bytes, most significant byte first *)
let int_s i =
  let s = String.make 4 '\000' in
  s.[0] <- Char.chr ((i lsr 24) land 0xff);
  s.[1] <- Char.chr ((i lsr 16) land 0xff);
  s.[2] <- Char.chr ((i lsr 8) land 0xff);
  s.[3] <- Char.chr (i land 0xff);
  s

(* hi str salt i = U1 xor U2 xor ... xor Ui where

     U1 = hmac str (salt ^ int_s 1)
     Uk = hmac str U(k-1)

   This is now computed in a loop. The former recursive formulation
   needed a stack depth of [i], and never terminated for i < 1.
   Raises Invalid_argument if i < 1.
 *)
let hi str salt i =
  if i < 1 then invalid_arg "Netmech_scram.hi";
  let u = ref (hmac str (salt ^ int_s 1)) in
  let h = ref !u in
  for _k = 2 to i do
    u := hmac str !u;
    h := Netauth.xor_s !u !h
  done;
  !h


let lsb128 s =
  (* The least-significant 128 bits *)
  let l = String.length s in
  if l < 16 then failwith "Netmech_scram.lsb128";
  String.sub s (l-16) 16


(* 16 random bytes as 32 hex digits *)
let create_nonce() =
  let s = String.make 16 ' ' in
  Netsys_rng.fill_random s;
  Digest.to_hex s

let create_salt = create_nonce


(* Raises Invalid_encoding if [username] or [password] are rejected by
   saslprep
 *)
let create_client_session profile username password =
  ignore(saslprep username);
  ignore(saslprep password);  (* Check for errors *)
  { cs_profile = profile;
    cs_state = `Start;
    cs_c1 = None;
    cs_s1 = None;
    cs_s1_raw = "";
    cs_cf = None;
    cs_sf = None;
    cs_auth_message = "";
    cs_salted_pw = "";
    cs_username = username;
    cs_password = password;
    cs_proto_key = None;
    cs_chanbind = "";
  }


(* Whether the client has to emit a message next *)
let client_emit_flag cs =
  match cs.cs_state with
    | `Start | `S1 -> true
    | _ -> false

(* Whether the client has to receive a message next *)
let client_recv_flag cs =
  match cs.cs_state with
    | `C1 | `CF -> true
    | _ -> false

let client_finish_flag cs =
  cs.cs_state = `Connected

let client_error_flag cs =
  cs.cs_state = `Error

(* Runs [f arg]. Any exception is logged, puts the client session into
   the `Error state, and is then re-raised.
 *)
let catch_error cs f arg =
  try
    f arg
  with
    | error ->
        dlog (sprintf "Client caught error: %s"
                (Netexn.to_string error));
        cs.cs_state <- `Error;
        raise error

let client_protocol_key cs =
  cs.cs_proto_key

let client_user_name cs =
  cs.cs_username

(* The channel binding can only be set before the client-final message
   is emitted. Fails otherwise.
 *)
let client_configure_channel_binding cs cb =
  ( match cs.cs_state with
      | `Start | `C1 | `S1 -> ()
      | _ -> failwith "Netmech_scram.client_configure_channel_binding"
  );
  cs.cs_chanbind <- cb

let client_channel_binding cs =
  cs.cs_chanbind

(* Only established sessions can be exported. Note that the exported
   string includes all fields of the session, also the password.
 *)
let client_export cs =
  if not (client_finish_flag cs) then
    failwith "Netmech_scram.client_export: context not yet established";
  Marshal.to_string cs []

(* NOTE(review): Marshal.from_string does not validate its input. Only
   pass strings here that were created by client_export and that come
   from a trusted source.
 *)
let client_import s =
  ( Marshal.from_string s 0 : client_session)


let salt_password password salt iteration_count =
  let sp = hi (saslprep password) salt iteration_count in
  (* eprintf "salt_password(%S,%S,%d) = %S\n"
       password salt iteration_count sp; *)
  sp


(* Returns the next message to send to the server: the client-first
   message in state `Start, and the client-final message (with proof) in
   state `S1. Fails in any other state. On any error the session enters
   the `Error state (see catch_error).
 *)
let client_emit_message cs =
  catch_error cs
    (fun () ->
       match cs.cs_state with
         | `Start ->
             let c1 =
               { c1_username = cs.cs_username;
                 c1_nonce = create_nonce();
                 c1_extensions = []
               } in
             cs.cs_c1 <- Some c1;
             cs.cs_state <- `C1;
             let m = encode_c1_message c1 in
             dlog (sprintf "Client state `Start emitting message: %s" m);
             m

         | `S1 ->
             let c1 =
               match cs.cs_c1 with None -> assert false | Some c1 -> c1 in
             let s1 =
               match cs.cs_s1 with None -> assert false | Some s1 -> s1 in
             let salted_pw =
               salt_password
                 cs.cs_password s1.s1_salt s1.s1_iteration_count in
             let client_key = hmac salted_pw "Client Key" in
             let stored_key = sha1 client_key in
             let cf_no_proof =
               encode_cf_message { cf_chanbind = cs.cs_chanbind;
                                   cf_nonce = s1.s1_nonce;
                                   cf_extensions = [];
                                   cf_proof = None
                                 } in
             (* The server's part is taken as it was received, not
                re-encoded *)
             let auth_message =
               encode_c1_message c1 ^ "," ^
                 cs.cs_s1_raw ^ "," ^
                 cf_no_proof in
             let client_signature = hmac stored_key auth_message in
             let p = Netauth.xor_s client_key client_signature in
             let cf =
               { cf_chanbind = cs.cs_chanbind;
                 cf_nonce = s1.s1_nonce;
                 cf_extensions = [];
                 cf_proof = Some p;
               } in
             cs.cs_cf <- Some cf;
             cs.cs_state <- `CF;
             cs.cs_auth_message <- auth_message;
             cs.cs_salted_pw <- salted_pw;
             cs.cs_proto_key <- Some ( lsb128
                                         (hmac
                                            stored_key
                                            ("GSS-API session key" ^
                                               client_key ^ auth_message)));
             let m = encode_cf_message cf in
             dlog (sprintf "Client state `S1 emitting message: %s" m);
             m

         | _ ->
             failwith "Netmech_scram.client_emit_message"
    )
    ()


(* Processes a message from the server: the server-first message in
   state `C1, and the server-final message in state `CF. Fails in any
   other state. On any error the session enters the `Error state (see
   catch_error).
 *)
let client_recv_message cs message =
  catch_error cs
    (fun () ->
       match cs.cs_state with
         | `C1 ->
             dlog (sprintf "Client state `C1 receiving message: %s" message);
             let s1 = decode_s1_message message in
             let c1 =
               match cs.cs_c1 with None -> assert false | Some c1 -> c1 in
             (* The nonce of the server must start with our nonce *)
             if String.length s1.s1_nonce < String.length c1.c1_nonce then
               raise (Protocol_error
                        "client_recv_message: Nonce from the server is too short");
             if String.sub s1.s1_nonce 0 (String.length c1.c1_nonce)
                <> c1.c1_nonce
             then
               raise (Protocol_error
                        "client_recv_message: bad nonce from the server");
             if s1.s1_iteration_count > cs.cs_profile.iteration_count_limit
             then
               raise (Protocol_error
                        "client_recv_message: iteration count too high");
             cs.cs_state <- `S1;
             cs.cs_s1 <- Some s1;
             cs.cs_s1_raw <- message

         | `CF ->
             dlog (sprintf "Client state `CF receiving message: %s" message);
             let sf = decode_sf_message message in
             ( match sf.sf_error_or_verifier with
                 | `Verifier v ->
                     let salted_pw = cs.cs_salted_pw in
                     let server_key =
                       hmac salted_pw "Server Key" in
                     let server_signature =
                       hmac server_key cs.cs_auth_message in
                     if v <> server_signature then
                       raise Invalid_server_signature;
                     cs.cs_state <- `Connected;
                     dlog "Client is authenticated"
                 | `Error e ->
                     cs.cs_state <- `Error;
                     dlog (sprintf "Client got error token from server: %s"
                             (string_of_server_error e));
                     raise(Server_error e)
             )

         | _ ->
             failwith "Netmech_scram.client_recv_message"
    )
    ()


let create_server_session profile auth =
  (* auth: called as: let (salted_pw, salt, i) = auth username *)
  { ss_profile = profile;
    ss_state = `Start;
    ss_c1 = None;
    ss_c1_raw = "";
    ss_s1 = None;
    ss_s1_raw = "";
    ss_cf = None;
    ss_cf_raw = "";
    ss_sf = None;
    ss_authenticate_opt = Some auth;
    ss_spw = None;
    ss_err = None;
    ss_proto_key = None;
  }


(* Whether the server has to emit a message next *)
let server_emit_flag ss =
  match ss.ss_state with
    | `C1 | `CF -> true
    | _ -> false

(* Whether the server has to receive a message next *)
let server_recv_flag ss =
  match ss.ss_state with
    | `Start | `S1 -> true
    | _ -> false

let server_finish_flag ss =
  ss.ss_state = `Connected

let server_error_flag ss =
  ss.ss_state = `Error

let server_protocol_key ss =
  ss.ss_proto_key


(* Only established sessions can be exported. The authentication
   callback is dropped before marshalling.
 *)
let server_export ss =
  if not (server_finish_flag ss) then
    failwith "Netmech_scram.server_export: context not yet established";
  Marshal.to_string { ss with ss_authenticate_opt = None } []

(* NOTE(review): Marshal.from_string does not validate its input. Only
   pass strings here that were created by server_export and that come
   from a trusted source.
 *)
let server_import s =
  ( Marshal.from_string s 0 : server_session)


(* Runs [f arg], and records the first protocol error in ss_err instead
   of raising it. Exceptions other than the four listed ones pass
   through.
 *)
let catch_condition ss f arg =
  let debug e =
    dlog (sprintf "Server caught error: %s"
            (Netexn.to_string e)) in
  try
    f arg
  with
    (* After such an error the protocol will continue, but the final
       server message will return the condition
     *)
    | Invalid_encoding(_,_) as e ->
        debug e;
        if ss.ss_err = None then
          ss.ss_err <- Some `Invalid_encoding
    | Invalid_username_encoding _ as e ->
        debug e;
        if ss.ss_err = None then
          ss.ss_err <- Some `Invalid_username_encoding
    | Extensions_not_supported(_,_) as e ->
        debug e;
        if ss.ss_err = None then
          ss.ss_err <- Some `Extensions_not_supported
    | Invalid_proof _ as e ->
        debug e;
        if ss.ss_err = None then
          ss.ss_err <- Some `Invalid_proof


exception Skip_proto


(* Returns the next message to send to the client: the server-first
   message in state `C1, and the server-final message in state `CF.
   Fails in any other state.
 *)
let server_emit_message ss =
  match ss.ss_state with
    | `C1 ->
        let m =
          try
            let c1 =
              match ss.ss_c1 with
                | None -> raise Skip_proto
                | Some c1 -> c1 in
            let (spw, salt, i) =
              match ss.ss_authenticate_opt with
                | Some auth -> auth c1.c1_username
                | None -> assert false in
            let s1 =
              { s1_nonce = c1.c1_nonce ^ create_nonce();
                s1_salt = salt;
                s1_iteration_count = i;
                s1_extensions = []
              } in
            ss.ss_state <- `S1;
            ss.ss_s1 <- Some s1;
            ss.ss_spw <- Some spw;
            let s1 = encode_s1_message s1 in
            ss.ss_s1_raw <- s1;
            s1
          with Not_found | Skip_proto ->
            (* continue with a dummy auth *)
            dlog "Server does not know this user";
            let c1_nonce =
              match ss.ss_c1 with
                | None -> create_nonce()
                | Some c1 -> c1.c1_nonce in
            let s1 =
              { s1_nonce = c1_nonce ^ create_nonce();
                s1_salt = create_nonce();
                s1_iteration_count = 4096;
                s1_extensions = []
              } in
            ss.ss_state <- `S1;
            ss.ss_s1 <- Some s1;
            if ss.ss_err = None then
              ss.ss_err <- Some (if ss.ss_profile.return_unknown_user then
                                   `Unknown_user
                                 else
                                   `Invalid_proof);
            (* This will keep the client off being successful *)
            let s1 = encode_s1_message s1 in
            ss.ss_s1_raw <- s1;
            s1
        in
        dlog (sprintf "Server state `C1 emitting message: %s" m);
        m

    | `CF ->
        ( match ss.ss_err with
            | Some err ->
                let sf =
                  { sf_error_or_verifier = `Error err;
                    sf_extensions = []
                  } in
                ss.ss_sf <- Some sf;
                ss.ss_state <- `Error;
                let m = encode_sf_message sf in
                dlog (sprintf "Server state `CF[Err] emitting message: %s" m);
                m

            | None ->
                let spw =
                  match ss.ss_spw with
                    | None -> assert false
                    | Some spw -> spw in
                let cf_no_proof = strip_cf_proof ss.ss_cf_raw in
                let auth_message =
                  ss.ss_c1_raw ^ "," ^
                    ss.ss_s1_raw ^ "," ^
                    cf_no_proof in
                let server_key =
                  hmac spw "Server Key" in
                let server_signature =
                  hmac server_key auth_message in
                let sf =
                  { sf_error_or_verifier = `Verifier server_signature;
                    sf_extensions = []
                  } in
                ss.ss_sf <- Some sf;
                ss.ss_state <- `Connected;
                let m = encode_sf_message sf in
                dlog (sprintf "Server state `CF emitting message: %s" m);
                m
        )

    | _ ->
        failwith "Netmech_scram.server_emit_message"


(* Processes a message from the client: the client-first message in
   state `Start, and the client-final message in state `S1. Fails in any
   other state. Protocol errors are not raised, but recorded in the
   session (see catch_condition), and reported to the client with the
   server-final message.
 *)
let server_recv_message ss message =
  match ss.ss_state with
    | `Start ->
        dlog (sprintf "Server state `Start receiving message: %s" message);
        catch_condition ss
          (fun () ->
             let c1 = decode_c1_message message in
             ss.ss_c1 <- Some c1;
          ) ();
        ss.ss_c1_raw <- message;
        ss.ss_state <- `C1
          (* Username is checked later *)

    | `S1 ->
        dlog (sprintf "Server state `S1 receiving message: %s" message);
        catch_condition ss
          (fun () ->
             try
               let s1 =
                 match ss.ss_s1 with
                   | None -> raise Skip_proto
                   | Some s1 -> s1 in
               let salted_pw =
                 match ss.ss_spw with
                   | None -> raise Skip_proto
                   | Some spw -> spw in
               let cf = decode_cf_message true message in
               if s1.s1_nonce <> cf.cf_nonce then
                 raise (Invalid_proof "nonce mismatch");
               let client_key = hmac salted_pw "Client Key" in
               let stored_key = sha1 client_key in
               let cf_no_proof = strip_cf_proof message in
               let auth_message =
                 ss.ss_c1_raw ^ "," ^
                   ss.ss_s1_raw ^ "," ^
                   cf_no_proof in
               let client_signature = hmac stored_key auth_message in
               let p = Netauth.xor_s client_key client_signature in
               if Some p <> cf.cf_proof then
                 raise (Invalid_proof "bad client signature");
               ss.ss_cf <- Some cf;
               ss.ss_proto_key <- Some ( lsb128
                                           (hmac
                                              stored_key
                                              ("GSS-API session key" ^
                                                 client_key ^ auth_message)));
             with
               | Skip_proto -> ()
          ) ();
        ss.ss_cf_raw <- message;
        ss.ss_state <- `CF

    | _ ->
        failwith "Netmech_scram.server_recv_message"


(* The channel binding the client has sent. This is only available after
   the proof of the client has been accepted.
 *)
let server_channel_binding ss =
  match ss.ss_cf with
    | None -> None
    | Some cf -> Some(cf.cf_chanbind)


let server_user_name ss =
  match ss.ss_c1 with
    | None -> None
    | Some c1 -> Some c1.c1_username


let transform_mstrings (trafo:Cryptokit.transform) ms_list =
  (* Like Cryptokit's transform_string, but for "mstring list" *)
  let blen = 256 in
  let s = String.create blen in    (* buffer for memory-based mstrings *)

  let rec loop in_list out_list =
    match in_list with
      | ms :: in_list' ->
          let ms_len = ms#length in
          ( match ms#preferred with
              | `String ->
                  let (s,start) = ms#as_string in
                  trafo#put_substring s start ms_len;
                  if trafo#available_output > 0 then (
                    let o = trafo#get_string in
                    let ms' = Xdr_mstring.string_to_mstring o in
                    loop in_list' (ms' :: out_list)
                  )
                  else
                    loop in_list' out_list
              | `Memory ->
                  (* Feed the memory block in pieces of at most blen
                     bytes *)
                  let (m,start) = ms#as_memory in
                  let k = ref 0 in
                  let ol = ref out_list in
                  while !k < ms_len do
                    let n = min blen (ms_len - !k) in
                    Netsys_mem.blit_memory_to_string m (start + !k) s 0 n;
                    trafo#put_substring s 0 n;
                    k := !k + n;
                    if trafo#available_output > 0 then (
                      let o = trafo#get_string in
                      let ms' = Xdr_mstring.string_to_mstring o in
                      ol := ms' :: !ol;
                    )
                  done;
                  loop in_list' !ol
          )
      | [] ->
          trafo # finish;
          let out_list' =
            if trafo#available_output > 0 then (
              let o = trafo#get_string in
              let ms' = Xdr_mstring.string_to_mstring o in
              ms' :: out_list
            )
            else
              out_list in
          List.rev out_list' in
  loop ms_list []


let hash_mstrings (hash:Cryptokit.hash) ms_list =
  (* Like Cryptokit's hash_string, but for "mstring list" *)
  let blen = 1024 in
  let s = String.create blen in    (* buffer for memory-based mstrings *)

  let rec loop in_list =
    match in_list with
      | ms :: in_list' ->
          let ms_len = ms#length in
          ( match ms#preferred with
              | `String ->
                  let (s,start) = ms#as_string in
                  hash#add_substring s start ms_len;
                  loop in_list'
              | `Memory ->
                  let (m,start) = ms#as_memory in
                  let k = ref 0 in
                  while !k < ms_len do
                    let n = min blen (ms_len - !k) in
                    Netsys_mem.blit_memory_to_string m (start + !k) s 0 n;
                    hash#add_substring s 0 n;
                    k := !k + n;
                  done;
                  loop in_list'
          )
      | [] ->
          hash#result in
  loop ms_list


let hmac_sha1_mstrings key ms_list =
  let h = Cryptokit.MAC.hmac_sha1 key in
  hash_mstrings h ms_list


(* Encryption for GSS-API *)

module AES_CTS = struct
  (* FIXME: avoid copying strings all the time *)

  let c = 128 (* bits *)

  let m = 1 (* byte *)

  let encrypt key s =
    (* AES with CTS as defined in RFC 3962, section 5. It is a bit unclear
       why the RFC uses CTS because the upper layer already ensures that
       s consists of a whole number of cipher blocks
     *)
    let l = String.length s in
    if l <= 16 then (
      (* Corner case: exactly one AES block of 128 bits or less *)
      (* NOTE(review): with Padding.length an input of exactly 16 bytes
         presumably results in 32 bytes of output, which [decrypt] would
         not invert - confirm against the Cryptokit documentation.
       *)
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          ~pad:Cryptokit.Padding.length  (* any padding is ok here *)
          key
          Cryptokit.Cipher.Encrypt in
      Cryptokit.transform_string cipher s
    )
    else (
      (* Cipher-text stealing, also see
         http://en.wikipedia.org/wiki/Ciphertext_stealing
         http://www.wordiq.com/definition/Ciphertext_stealing
       *)
      (* Cryptokit's padding feature is unusable here *)
      let m = l mod 16 in
      let s_padded =
        if m = 0 then s else s ^ String.make (16-m) '\000' in
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.CBC
          key
          Cryptokit.Cipher.Encrypt in
      let u = Cryptokit.transform_string cipher s_padded in
      let ulen = String.length u in
      assert(ulen >= 32 && ulen mod 16 = 0);
      (* Swap the last two blocks, and cut the result to the length of
         the input *)
      let v = String.sub u (ulen-16) 16 in
      String.blit u (ulen-32) u (ulen-16) 16;
      String.blit v 0 u (ulen-32) 16;
      String.sub u 0 l
    )


  let encrypt_mstrings key ms_list =
    (* Exactly the same, but we get input as "mstring list" and return
       output in the same way
     *)
    let l = Xdr_mstring.length_mstrings ms_list in
    if l <= 16 then (
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          ~pad:Cryptokit.Padding.length  (* any padding is ok here *)
          key
          Cryptokit.Cipher.Encrypt in
      transform_mstrings cipher ms_list
    )
    else (
      let m = l mod 16 in
      let ms_padded =
        if m=0 then
          ms_list
        else
          ms_list @
            [ Xdr_mstring.string_to_mstring (String.make (16-m) '\000') ] in
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.CBC
          key
          Cryptokit.Cipher.Encrypt in
      let u = transform_mstrings cipher ms_padded in
      let ulen = Xdr_mstring.length_mstrings u in
      assert(ulen >= 32 && ulen mod 16 = 0);
      (* u0: all but the last two blocks; u1, u2: the last two blocks,
         which are swapped *)
      let u0 = Xdr_mstring.shared_sub_mstrings u 0 (ulen-32) in
      let u1 = Xdr_mstring.shared_sub_mstrings u (ulen-32) 16 in
      let u2 = Xdr_mstring.shared_sub_mstrings u (ulen-16) 16 in
      let u' = u0 @ u2 @ u1 in
      Xdr_mstring.shared_sub_mstrings u' 0 l
    )


  let decrypt key s =
    let l = String.length s in
    if l <= 16 then (
      if l <> 16 then
        invalid_arg "Netmech_scram.AES256_CTS: bad length of plaintext";
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          key
          Cryptokit.Cipher.Decrypt in
      Cryptokit.transform_string cipher s
        (* This string is still padded! *)
    )
    else (
      let k_last = ((l - 1) / 16) * 16 in     (* start of the last block *)
      let k_last_len = l - k_last in          (* its length, 1..16 *)
      let k_second_to_last = k_last - 16 in
      let dn_cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          key
          Cryptokit.Cipher.Decrypt in
      let c_2nd_to_last = String.sub s k_second_to_last 16 in
      let dn = Cryptokit.transform_string dn_cipher c_2nd_to_last in
      (* Complete the last (partial) block with the tail of dn *)
      let cn =
        (String.sub s k_last k_last_len) ^
          (String.sub dn k_last_len (16 - k_last_len)) in
      (* u is the ciphertext with whole blocks, and the last two blocks
         swapped back *)
      let u = String.create (k_last+16) in
      String.blit s 0 u 0 k_second_to_last;
      String.blit cn 0 u k_second_to_last 16;
      String.blit c_2nd_to_last 0 u k_last 16;
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.CBC
          key
          Cryptokit.Cipher.Decrypt in
      let v = Cryptokit.transform_string cipher u in
      String.sub v 0 l
    )


  let decrypt_mstrings key ms_list =
    let l = Xdr_mstring.length_mstrings ms_list in
    if l <= 16 then (
      if l <> 16 then
        invalid_arg "Netmech_scram.AES256_CTS: bad length of plaintext";
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          key
          Cryptokit.Cipher.Decrypt in
      transform_mstrings cipher ms_list
        (* This string is still padded! *)
    )
    else (
      let k_last = ((l - 1) / 16) * 16 in
      let k_last_len = l - k_last in
      let k_second_to_last = k_last - 16 in
      let dn_cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.ECB
          key
          Cryptokit.Cipher.Decrypt in
      let c_2nd_to_last =
        Xdr_mstring.shared_sub_mstrings ms_list k_second_to_last 16 in
      let dn =
        transform_mstrings dn_cipher c_2nd_to_last in
      let cn0 =
        Xdr_mstring.shared_sub_mstrings ms_list k_last k_last_len in
      let cn1 =
        Xdr_mstring.shared_sub_mstrings dn k_last_len (16-k_last_len) in
      let cn = cn0 @ cn1 in
      let s0 =
        Xdr_mstring.shared_sub_mstrings ms_list 0 k_second_to_last in
      let u = s0 @ cn @ c_2nd_to_last in
      let cipher =
        Cryptokit.Cipher.aes
          ~mode:Cryptokit.Cipher.CBC
          key
          Cryptokit.Cipher.Decrypt in
      let v = transform_mstrings cipher u in
      Xdr_mstring.shared_sub_mstrings v 0 l
    )

  (* Test vectors from the RFC (for 128 bit AES): *)

  let k_128 =
    "\x63\x68\x69\x63\x6b\x65\x6e\x20\x74\x65\x72\x69\x79\x61\x6b\x69"

  let v1_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\x20"

  let v1_out =
    "\xc6\x35\x35\x68\xf2\xbf\x8c\xb4\xd8\xa5\x80\x36\x2d\xa7\xff\x7f\x97"

  let v2_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\x20\
     \x47\x65\x6e\x65\x72\x61\x6c\x20\x47\x61\x75\x27\x73\x20"

  let v2_out =
    "\xfc\x00\x78\x3e\x0e\xfd\xb2\xc1\xd4\x45\xd4\xc8\xef\xf7\xed\x22\
     \x97\x68\x72\x68\xd6\xec\xcc\xc0\xc0\x7b\x25\xe2\x5e\xcf\xe5"

  let v3_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\
     \x20\x47\x65\x6e\x65\x72\x61\x6c\x20\x47\x61\x75\x27\x73\x20\x43"

  let v3_out =
    "\x39\x31\x25\x23\xa7\x86\x62\xd5\xbe\x7f\xcb\xcc\x98\xeb\xf5\xa8\
     \x97\x68\x72\x68\xd6\xec\xcc\xc0\xc0\x7b\x25\xe2\x5e\xcf\xe5\x84"

  let v4_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\
     \x20\x47\x65\x6e\x65\x72\x61\x6c\x20\x47\x61\x75\x27\x73\x20\x43\
     \x68\x69\x63\x6b\x65\x6e\x2c\x20\x70\x6c\x65\x61\x73\x65\x2c"

  let v4_out =
    "\x97\x68\x72\x68\xd6\xec\xcc\xc0\xc0\x7b\x25\xe2\x5e\xcf\xe5\x84\
     \xb3\xff\xfd\x94\x0c\x16\xa1\x8c\x1b\x55\x49\xd2\xf8\x38\x02\x9e\
     \x39\x31\x25\x23\xa7\x86\x62\xd5\xbe\x7f\xcb\xcc\x98\xeb\xf5"

  let v5_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\
     \x20\x47\x65\x6e\x65\x72\x61\x6c\x20\x47\x61\x75\x27\x73\x20\x43\
     \x68\x69\x63\x6b\x65\x6e\x2c\x20\x70\x6c\x65\x61\x73\x65\x2c\x20"

  let v5_out =
    "\x97\x68\x72\x68\xd6\xec\xcc\xc0\xc0\x7b\x25\xe2\x5e\xcf\xe5\x84\
     \x9d\xad\x8b\xbb\x96\xc4\xcd\xc0\x3b\xc1\x03\xe1\xa1\x94\xbb\xd8\
     \x39\x31\x25\x23\xa7\x86\x62\xd5\xbe\x7f\xcb\xcc\x98\xeb\xf5\xa8"

  let v6_in =
    "\x49\x20\x77\x6f\x75\x6c\x64\x20\x6c\x69\x6b\x65\x20\x74\x68\x65\
     \x20\x47\x65\x6e\x65\x72\x61\x6c\x20\x47\x61\x75\x27\x73\x20\x43\
     \x68\x69\x63\x6b\x65\x6e\x2c\x20\x70\x6c\x65\x61\x73\x65\x2c\x20\
     \x61\x6e\x64\x20\x77\x6f\x6e\x74\x6f\x6e\x20\x73\x6f\x75\x70\x2e"

  let v6_out =
    "\x97\x68\x72\x68\xd6\xec\xcc\xc0\xc0\x7b\x25\xe2\x5e\xcf\xe5\x84\
     \x39\x31\x25\x23\xa7\x86\x62\xd5\xbe\x7f\xcb\xcc\x98\xeb\xf5\xa8\
     \x48\x07\xef\xe8\x36\xee\x89\xa5\x26\x73\x0d\xbc\x2f\x7b\xc8\x40\
     \x9d\xad\x8b\xbb\x96\xc4\xcd\xc0\x3b\xc1\x03\xe1\xa1\x94\xbb\xd8"

  let tests =
    [ k_128, v1_in, v1_out;
      k_128, v2_in, v2_out;
      k_128, v3_in, v3_out;
      k_128, v4_in, v4_out;
      k_128, v5_in, v5_out;
      k_128, v6_in, v6_out;
    ]

  (* Whether the string functions reproduce all test vectors in both
     directions *)
  let run_tests() =
    List.for_all
      (fun (k, v_in, v_out) ->
         encrypt k v_in = v_out &&
           decrypt k v_out = v_in
      )
      tests

  (* The same for the mstring functions. Progress is written to stderr. *)
  let run_mtests() =
    let j = ref 1 in
    List.for_all
      (fun (k, v_in, v_out) ->
         prerr_endline("Test: " ^ string_of_int !j);
         let v_in_ms = Xdr_mstring.string_to_mstring v_in in
         let v_out_ms = Xdr_mstring.string_to_mstring v_out in
         let e =
           Xdr_mstring.concat_mstrings (encrypt_mstrings k [v_in_ms]) in
         prerr_endline "  enc ok";
         let d =
           Xdr_mstring.concat_mstrings (decrypt_mstrings k [v_out_ms]) in
         prerr_endline "  dec ok";
         incr j;
         e = v_out && d = v_in
      )
      tests
end


module Cryptosystem = struct
  (* RFC 3961 section 5.3 *)

  module C = AES_CTS   (* Cipher *)

  module I = struct    (* Integrity *)
    let hmac = hmac   (* hmac-sha1 *)
    let hmac_mstrings = hmac_sha1_mstrings
    let h = 12        (* the checksums are cut to this number of bytes *)
  end

  exception Integrity_error

  (* Derives the three specific keys from [protocol_key], which must have
     128 or 256 bits. Raises Invalid_argument otherwise.
   *)
  let derive_keys protocol_key usage =
    let k = 8 * String.length protocol_key in
    if k <> 128 && k <> 256 then
      invalid_arg "Netmech_scram.Cryptosystem.derive_keys";
    let derive kt =
      Netauth.derive_key_rfc3961_simplified
        ~encrypt:(C.encrypt protocol_key)
        ~random_to_key:(fun s -> s)
        ~block_size:C.c
        ~k
        ~usage
        ~key_type:kt in
    { kc = derive `Kc;
      ke = derive `Ke;
      ki = derive `Ki;
    }

  (* The "rec" is intended: this function is only there for hiding the
     constant C.m from the compiler (see encrypt_and_sign).
   *)
  let rec identity x = x

  (* The result is the encryption of confounder ^ message ^ pad (with the
     key ke), followed by the first I.h bytes of the hmac of the same
     plaintext (with the key ki).
   *)
  let encrypt_and_sign s_keys message =
    let c_bytes = C.c/8 in
    let conf = String.make c_bytes '\000' in
    Netsys_rng.fill_random conf;
    let l = String.length message in
    let p = (l + c_bytes) mod (identity C.m) in
    (* Due to a bug in the ARM code generator, avoid "... mod 1" *)
    let pad =
      if p = 0 then "" else String.make (C.m - p) '\000' in
    let p1 = conf ^ message ^ pad in
    let c1 = C.encrypt s_keys.ke p1 in
    let h1 = I.hmac s_keys.ki p1 in
    c1 ^ String.sub h1 0 I.h

  let encrypt_and_sign_mstrings s_keys message =
    let c_bytes = C.c/8 in
    let conf = String.make c_bytes '\000' in
    Netsys_rng.fill_random conf;
    let l = Xdr_mstring.length_mstrings message in
    let p = (l + c_bytes) mod (identity C.m) in
    (* The same workaround for "... mod 1" as in encrypt_and_sign *)
    let pad =
      if p = 0 then "" else String.make (C.m - p) '\000' in
    let p1 =
      ( ( Xdr_mstring.string_to_mstring conf ) :: message ) @
        [ Xdr_mstring.string_to_mstring pad ] in
    let c1 = C.encrypt_mstrings s_keys.ke p1 in
    let h1 = I.hmac_mstrings s_keys.ki p1 in
    c1 @ [ Xdr_mstring.string_to_mstring(String.sub h1 0 I.h) ]

  (* Reverses encrypt_and_sign. Raises Integrity_error if the checksum
     does not match, or if the plaintext is too short for including the
     confounder. Raises Invalid_argument if the ciphertext is even too
     short for the checksum.
   *)
  let decrypt_and_verify s_keys ciphertext =
    let c_bytes = C.c/8 in
    let l = String.length ciphertext in
    if l < I.h then
      invalid_arg "Netmech_scram.Cryptosystem.decrypt_and_verify";
    let c1 = String.sub ciphertext 0 (l - I.h) in
    let h1 = String.sub ciphertext (l - I.h) I.h in
    let p1 = C.decrypt s_keys.ke c1 in
    let h1' = String.sub (I.hmac s_keys.ki p1) 0 I.h in
    if h1 <> h1' then
      raise Integrity_error;
    let q = String.length p1 in
    if q < c_bytes then
      raise Integrity_error;
    String.sub p1 c_bytes (q-c_bytes)
      (* This includes any padding or residue from the lower layer! *)

  let decrypt_and_verify_mstrings s_keys ciphertext =
    let c_bytes = C.c/8 in
    let l = Xdr_mstring.length_mstrings ciphertext in
    if l < I.h then
      invalid_arg "Netmech_scram.Cryptosystem.decrypt_and_verify";
    let c1 = Xdr_mstring.shared_sub_mstrings ciphertext 0 (l - I.h) in
    let h1 =
      Xdr_mstring.concat_mstrings
        (Xdr_mstring.shared_sub_mstrings ciphertext (l - I.h) I.h) in
    let p1 = C.decrypt_mstrings s_keys.ke c1 in
    let h1' = String.sub (I.hmac_mstrings s_keys.ki p1) 0 I.h in
    if h1 <> h1' then
      raise Integrity_error;
    let q = Xdr_mstring.length_mstrings p1 in
    if q < c_bytes then
      raise Integrity_error;
    Xdr_mstring.shared_sub_mstrings p1 c_bytes (q-c_bytes)
      (* This includes any padding or residue from the lower layer! *)

  (* Always 0 for valid [n]. Raises Invalid_argument if n < 16. *)
  let get_ec s_keys n =
    if n < 16 then invalid_arg "Netmech_scram.Cryptosystem.get_ec";
    0

  (* The first I.h bytes of the hmac of [message] with the key kc *)
  let get_mic s_keys message =
    String.sub (I.hmac s_keys.kc message) 0 I.h

  let get_mic_mstrings s_keys message =
    String.sub (I.hmac_mstrings s_keys.kc message) 0 I.h
end