-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathwsproxy.ml
More file actions
156 lines (146 loc) · 5.58 KB
/
wsproxy.ml
File metadata and controls
156 lines (146 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
(*
* Copyright (C) Citrix Systems Inc.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation; version 2.1 only. with the special
* exception on linking described in file LICENSE.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*)
open Wslib
module LwtWsIteratee = Wslib.Websockets.Wsprotocol (Lwt)
open Lwt.Infix
(* Local shorthand for [Lwt_support.with_fd]; it is used in this file as
   [with_fd fd ~callback:(fun fd -> ...)].
   NOTE(review): presumably it releases [fd] once the callback's promise has
   settled -- confirm against Lwt_support. *)
let with_fd = Lwt_support.with_fd
(** [start handler] puts the socket inherited on stdin into listening mode
    and accepts connections on it forever.

    For every accepted connection a background promise receives a single
    message (at most 16384 bytes) together with any file descriptors passed
    alongside it. The first descriptor received gets [SO_KEEPALIVE] set and
    is handed to [handler fd msg], where [msg] is the received payload as a
    string; any additional descriptors are closed. If no descriptor was
    passed, a warning is logged and nothing is proxied.

    Failures of [accept] and failures inside a connection's background
    promise are logged and do not stop the accept loop. The returned promise
    is not expected to resolve, since the loop always recurses. *)
let start handler =
  Logs_lwt.info (fun m -> m "Starting wsproxy") >>= fun () ->
  let fd_sock = Lwt_unix.stdin in
  let () = Lwt_unix.listen fd_sock 5 in
  (* Size of the per-connection receive buffer for the control message. *)
  let buffer_size = 16384 in
  (* Close descriptors we were sent but do not use; best effort only. *)
  let ensure_close = function
    | [] ->
        Lwt.return_unit
    | fds ->
        Logs_lwt.warn (fun m -> m "Closing %d excess fds" (List.length fds))
        >>= fun () ->
        List.iter
          (fun fd -> try Unix.close fd with Unix.Unix_error _ -> ())
          fds ;
        Lwt.return_unit
  in
  (* Handle one accepted control connection: read the message and the
     descriptors passed with it, then hand the first one to [handler]. *)
  let serve fd_sock' =
    let buffer = Bytes.make buffer_size '\000' in
    with_fd fd_sock' ~callback:(fun fd ->
        let io_vectors = Lwt_unix.IO_vectors.create () in
        Lwt_unix.IO_vectors.append_bytes io_vectors buffer 0 buffer_size ;
        Lwt_unix.Versioned.recv_msg_2 ~socket:fd ~io_vectors)
    >>= fun (len, newfds) ->
    match newfds with
    | [] ->
        Logs_lwt.warn (fun m ->
            m "No fd to start a connection: not proxying")
    | ufd :: ufds ->
        ensure_close ufds >>= fun () ->
        with_fd (Lwt_unix.of_unix_file_descr ufd) ~callback:(fun fd ->
            Logs_lwt.debug (fun m -> m "About to start connection")
            >>= fun () ->
            Lwt_unix.setsockopt fd Lwt_unix.SO_KEEPALIVE true ;
            let msg = Bytes.(to_string @@ sub buffer 0 len) in
            handler fd msg)
  in
  let rec loop () =
    Lwt.catch
      (fun () ->
        Lwt_unix.accept fd_sock >>= fun (fd_sock', _) ->
        (* Background thread per connection. Its result is deliberately not
           awaited, so failures are caught and logged here; otherwise they
           would vanish with the discarded promise. *)
        let (_ : unit Lwt.t) =
          Lwt.catch
            (fun () -> serve fd_sock')
            (fun e ->
              Logs_lwt.err (fun m ->
                  m "Connection handler failed: %s" (Printexc.to_string e)))
        in
        Lwt.return_unit)
      (fun e ->
        Logs_lwt.err (fun m -> m "Caught exception: %s" (Printexc.to_string e)))
    (* Recurse outside [Lwt.catch]: recursing inside its body would wrap
       every iteration in one more never-resolving catch, retaining a
       pending promise per accepted connection. *)
    >>= fun () -> loop ()
  in
  with_fd fd_sock ~callback:(fun _ -> loop ())
(** [proxy fd addr protocol] relays traffic between the websocket client on
    [fd] and a connection to [addr] obtained through
    [with_open_connection_fd].

    [protocol] chooses the framing pair: ["hixie76"] selects
    [wsframe_old]/[wsunframe_old]; ["hybi10"] selects [wsframe]/[wsunframe];
    any other value logs a warning and also uses [wsframe]/[wsunframe].
    Data read from the local connection goes through the framing function
    before being written to [fd]; data read from [fd] goes through the
    unframing function before being written to the local connection. *)
let proxy (fd : Lwt_unix.file_descr) addr protocol =
  let open LwtWsIteratee in
  let open Lwt_support in
  (* Wait for a pump to finish and drop whatever value it produced. *)
  let discard_result pump = pump >>= fun _ -> Lwt.return_unit in
  let codec =
    match protocol with
    | "hixie76" ->
        Logs_lwt.debug (fun m -> m "Old-style (hixie76) protocol")
        >>= fun () -> Lwt.return (wsframe_old, wsunframe_old)
    | "hybi10" ->
        Logs_lwt.debug (fun m -> m "New-style (hybi10) protocol")
        >>= fun () -> Lwt.return (wsframe, wsunframe)
    | _ ->
        Logs_lwt.warn (fun m -> m "Unknown protocol, fallback to hybi10")
        >>= fun () -> Lwt.return (wsframe, wsunframe)
  in
  codec >>= fun (frame, unframe) ->
  with_open_connection_fd addr ~callback:(fun localfd ->
      let session_id = Uuidm.to_string (Uuidm.v `V4) in
      Logs_lwt.debug (fun m -> m "Starting proxy session %s" session_id)
      >>= fun () ->
      (* local connection -> websocket client *)
      let to_client =
        let sink = frame (writer (really_write fd) "thread1") in
        discard_result (lwt_fd_enumerator localfd sink)
      in
      (* websocket client -> local connection *)
      let to_local =
        let sink = unframe (writer (really_write localfd) "thread2") in
        discard_result (lwt_fd_enumerator fd sink)
      in
      (* When one direction ends, the other would in general stay pending
         forever. [Lwt.choose] resolves as soon as either pump completes, so
         the session finishes without waiting for the second one. *)
      Lwt.choose [to_client; to_local] >>= fun () ->
      Logs_lwt.debug (fun m -> m "Closing proxy session %s" session_id))
(** Pre-compiled patterns used to classify the target named in a control
    message. *)
module RX = struct
  (** A path of the form [/var/run/xen/vnc-<digits>]. *)
  let socket = Re.Str.(regexp "^/var/run/xen/vnc-[0-9]+$")

  (** A non-empty run of decimal digits. *)
  let port = Re.Str.(regexp "^[0-9]+$")
end
(** [handler sock msg] parses the control message [msg] and starts proxying
    [sock] to the target it names.

    [msg] is split on [':'] and must have the shape [protocol:target] or
    [protocol:<ignored>:target]. A [target] matching [RX.socket] is treated
    as a Unix domain socket path; one matching [RX.port] is treated as a TCP
    port on the loopback address. Anything else, including a numeric port
    too large to fit in an [int], is logged as malformed and not proxied. *)
let handler sock msg =
  Logs_lwt.debug (fun m -> m "Got msg: '%s'" msg) >>= fun () ->
  (* [Re.Str.string_match] only anchors the start of the match, and [$] also
     matches just before a newline. Requiring the match to end at the end of
     the string rejects targets such as "5900\n..." that would otherwise
     pass validation. *)
  let matches_fully rx s =
    Re.Str.string_match rx s 0 && Re.Str.match_end () = String.length s
  in
  let malformed () =
    Logs_lwt.warn (fun m ->
        m "The message '%s' is malformed: not proxying" msg)
  in
  match Re.Str.(split @@ regexp "[:]") msg with
  | [protocol; _; target] | [protocol; target] ->
      if matches_fully RX.socket target then
        proxy sock (Unix.ADDR_UNIX target) protocol
      else if matches_fully RX.port target then
        (* All-digit strings can still overflow [int]; treat those as
           malformed instead of letting [int_of_string] raise. *)
        match int_of_string_opt target with
        | Some port ->
            proxy sock (Unix.ADDR_INET (Unix.inet_addr_loopback, port)) protocol
        | None ->
            malformed ()
      else malformed ()
  | _ ->
      malformed ()
(* Reporter taken from
* https://erratique.ch/software/logs/doc/Logs_lwt/index.html#report_ex
* under ISC License *)
(** [lwt_reporter ()] builds a [Logs] reporter that formats each message
    into an in-memory buffer and then writes the buffered text through
    [Lwt_io]: messages at level [Logs.App] go to [Lwt_io.stdout], all other
    levels to [Lwt_io.stderr]. The [over] callback supplied by [Logs] is
    invoked once the write has finished, whether it succeeded or failed. *)
let lwt_reporter () =
  (* A formatter backed by a fresh buffer, paired with a function that
     returns the buffered text and empties the buffer. *)
  let buffered ~like =
    let buf = Buffer.create 512 in
    let take () =
      let contents = Buffer.contents buf in
      Buffer.reset buf ; contents
    in
    (Fmt.with_buffer ~like buf, take)
  in
  let app, take_app = buffered ~like:Fmt.stdout in
  let dst, take_dst = buffered ~like:Fmt.stderr in
  let formatting_reporter = Logs_fmt.reporter ~app ~dst () in
  let report src level ~over k msgf =
    let flush_then_continue () =
      let emit () =
        match level with
        | Logs.App ->
            Lwt_io.write Lwt_io.stdout (take_app ())
        | Logs.Error | Logs.Warning | Logs.Info | Logs.Debug ->
            Lwt_io.write Lwt_io.stderr (take_dst ())
      in
      let release () = over () ; Lwt.return_unit in
      Lwt.ignore_result (Lwt.finalize emit release) ;
      k ()
    in
    (* The wrapped reporter gets a no-op [over]; the real one is called by
       [release] after the asynchronous write. *)
    formatting_reporter.Logs.report src level
      ~over:(fun () -> ())
      flush_then_continue msgf
  in
  {Logs.report}
(* Program entry point: install the Lwt-aware reporter, log at [Info] level
   for all sources, then run the accept loop. *)
let () =
  let reporter = lwt_reporter () in
  Logs.set_reporter reporter ;
  Logs.set_level ~all:true (Some Logs.Info) ;
  start handler |> Lwt_main.run