1 /**
2  * Copyright: Copyright Jason White, 2014-2016
3  * License:   $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
4  * Authors:   Thayne McCombs
5  */
6 module io.socket.stream;
7 
8 import io.stream;
9 import std.socket;
10 
/**
 * A wrapper around a socket that provides stream functionality without buffering.
 *
 * The stream owns the wrapped socket: the destructor shuts it down and
 * closes it if it is still open, unless it was taken back with `detach`.
 */
struct UnbufferedSocketStreamBase
{
    /**
     * Create a new Stream from an existing Socket.
     *
     * Params:
     *   socket = The socket to make a stream from. The stream takes
     *            ownership of it (see `detach` to give it back).
     */
    this(Socket socket)
    {
        _socket = socket;
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = UnbufferedSocketStream(pair[0]);
        assert(sock.isOpen);

        immutable ubyte[] data = [1,2,3,4,5];
        ubyte[10] buff;

        sock.write(data);

        assert(pair[1].receive(buff) == 5);
        assert(buff[0..5] == data);
    }

    /**
     * Create a new SocketStream that is connected to `address` as
     * a client. The socket is created in streaming mode.
     *
     * Params:
     *   address = The address to connect to.
     *
     * Throws: SocketException if the socket cannot be created or connected.
     *         A socket that was created but failed to connect is closed
     *         before the exception propagates.
     */
    this(Address address)
    {
        auto sock = new Socket(address.addressFamily, SocketType.STREAM);
        // Release the descriptor immediately if connecting fails, instead of
        // leaving it for the garbage collector to finalize.
        scope (failure) sock.close();
        sock.connect(address);
        _socket = sock;
    }

    unittest
    {
        auto server = new TcpSocket(AddressFamily.INET);
        scope (exit) server.close();
        server.bind(new InternetAddress("localhost", InternetAddress.PORT_ANY));
        server.listen(10);

        auto sock = SocketStream(server.localAddress);
        // Compare the textual host:port form; `==` on two distinct Address
        // objects does not compare their contents.
        assert(sock.remoteAddress.toString() == server.localAddress.toString());
        assert(sock.isOpen);
    }


    /**
     * Copying is disabled, because reference counting should be used instead.
     */
    @disable this(this);

    /**
     * Returns true if a socket is attached and it is alive.
     */
    @property bool isOpen() @safe
    {
        return _socket !is null && _socket.isAlive;
    }

    /**
     * Returns the underlying Socket, or null if it has been detached.
     */
    @property Socket socket() @safe
    {
        return _socket;
    }

    /**
     * Reads data from the socket.
     *
     * Params:
     *   buf = The buffer to read the data into. The length of the buffer
     *         is the maximum amount of data read.
     *
     * Returns: The number of bytes that were read. If the socket is blocking,
     *          this waits until data is available. If the remote side has
     *          closed the connection, 0 is returned.
     *
     * Throws: SocketException on failure.
     */
    size_t read(scope ubyte[] buf) @safe
    in { assert(isOpen); }
    body
    {
        immutable n = _socket.receive(buf);
        socketEnforce(n != Socket.ERROR, "Failed to read from socket");
        return n;
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = UnbufferedSocketStream(pair[0]);
        immutable ubyte[] data = [10,20,30,40];
        pair[1].send(data);
        ubyte[10] buff;
        assert(sock.read(buff) == 4);
        assert(buff[0..4] == data);
        pair[1].close();
        assert(sock.read(buff) == 0);
    }

    /**
     * Writes data to the socket.
     *
     * Params:
     *   data = The data to write to the socket. The length of the slice
     *          indicates how much data should be written.
     *
     * Returns: The number of bytes that were written, which may be less
     *          than `data.length`.
     *
     * Throws: SocketException on failure.
     */
    size_t write(in ubyte[] data) @safe
    in { assert(isOpen); }
    body
    {
        immutable n = _socket.send(data);
        socketEnforce(n != Socket.ERROR, "Failed to write to socket");
        return n;
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = UnbufferedSocketStream(pair[0]);
        immutable ubyte[] data = [5,9,10];
        sock.write(data);
        ubyte[10] buff;
        assert(pair[1].receive(buff) == 3);
        assert(buff[0..3] == data);
        sock.write(data);
        sock.write(data);
        assert(pair[1].receive(buff) == 6);
        assert(buff[0..6] == data ~ data);
    }

    /// ditto
    alias put = write;

    /**
     * If the Socket is open, shut down both directions and close.
     * Otherwise, it does nothing.
     */
    void close() @safe
    {
        if (isOpen)
        {
            _socket.shutdown(SocketShutdown.BOTH);
            _socket.close();
        }
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = UnbufferedSocketStream(pair[0]);
        sock.close();
        assert(!sock.isOpen);
        ubyte[1] buff;
        assert(pair[1].receive(buff) == 0);
    }

    /**
     * Detach the socket from this stream and return it.
     *
     * Afterwards the stream no longer refers to any socket, so `isOpen` is
     * false and destroying the stream does not close the returned socket.
     * The returned socket itself is left untouched.
     */
    Socket detach() @safe
    {
        scope(success) { _socket = null; }
        return _socket;
    }

    /**
     * Shuts down and closes the socket if it is still open (see `close`).
     * Nothing happens if the socket was detached.
     */
    ~this()
    {
        close();
    }

    alias _socket this;

private:
    Socket _socket;
}
209 
unittest
{
    // Compile-time check that the unbuffered base meets io.stream's
    // `isSourceSink` requirements.
    enum bool satisfiesSourceSink = isSourceSink!UnbufferedSocketStreamBase;
    static assert(satisfiesSourceSink);
}
214 
/**
 * A stream that wraps a socket with buffered writes.
 *
 * Written data is collected in an internal buffer and sent when the buffer
 * fills up, when `flush`, `close` or `detach` is called, or (best effort)
 * when the stream is destroyed. Reads go straight to the socket.
 */
struct SocketStreamBase {
    alias _stream this;

    @disable this(this);

    /// Size of the write buffer allocated by the constructor: 8192 bytes (8KB).
    enum size_t defaultBufferSize = 8192;

    /**
     * Forwards all arguments to the `UnbufferedSocketStreamBase`
     * constructor and allocates a write buffer of `defaultBufferSize` bytes.
     */
    this(T...)(auto ref T args)
    {
        import std.functional : forward;
        _stream = UnbufferedSocketStreamBase(forward!args);
        _buffer.length = defaultBufferSize;
    }


    /**
     * Sets the size of the buffer. The default is `defaultBufferSize`.
     * If there is currently data in the buffer, it is flushed first.
     */
    @property void bufferSize(size_t size)
    {
        if (_pos > 0)
        {
            flush();
        }
        _buffer.length = size;
    }

    /**
     * Get the current buffer size. The default is 8192 bytes (8KB).
     */
    @property size_t bufferSize() const pure nothrow @nogc
    {
        return _buffer.length;
    }

    /**
     * Upon destruction, pending writes are sent to the socket on a
     * best-effort basis, provided the socket is still open.
     */
    ~this()
    {
        // flush() may throw, which must not escape a destructor, so talk to
        // the socket directly and give up silently on error. Skip entirely
        // when the socket is closed or detached (it may then be null).
        if (_pos == 0 || !_stream.isOpen)
            return;

        size_t sent = 0;
        while (sent < _pos)
        {
            immutable n = _stream.socket.send(_buffer[sent .. _pos]);
            if (n <= 0)
                break; // error (Socket.ERROR) or no progress
            sent += cast(size_t) n;
        }
    }

    /**
     * Writes all pending data to the socket.
     *
     * Throws: SocketException on failure. Bytes that were not sent stay in
     *         the buffer (and bytes that were sent are removed), so a later
     *         flush neither loses nor duplicates data.
     */
    void flush() @safe
    {
        size_t sent = 0;

        // On every exit path, move the unsent tail to the front of the
        // buffer. The target always precedes the source, so a forward
        // element-by-element copy is safe despite the overlap.
        scope (exit)
        {
            if (sent > 0)
            {
                immutable remaining = _pos - sent;
                foreach (i; 0 .. remaining)
                    _buffer[i] = _buffer[sent + i];
                _pos = remaining;
            }
        }

        // A single send may accept only part of the data, so keep going
        // until everything has been handed to the socket.
        while (sent < _pos)
        {
            immutable n = _stream.write(_buffer[sent .. _pos]);
            if (n == 0)
                throw new SocketException("Failed to flush socket stream buffer");
            sent += n;
        }
    }

    /**
     * Flush any pending writes, then shut down and close the socket.
     * If the socket is not open, nothing is sent or closed.
     *
     * The socket is closed even if flushing throws; the exception then
     * propagates to the caller.
     */
    void close() @safe
    {
        scope (exit) _stream.close();
        if (_stream.isOpen)
            flush();
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = SocketStream(pair[0]);
        immutable ubyte[] data = [7, 8, 9];
        sock.write(data);
        sock.close();
        assert(!sock.isOpen);
        ubyte[10] buff;
        assert(pair[1].receive(buff) == 3);
        assert(buff[0 .. 3] == data);
    }

    /**
     * Flush any pending writes (if the socket is open), then detach the
     * socket from the stream and return it. The returned socket is not
     * closed when the stream is destroyed.
     */
    Socket detach() @safe
    {
        if (_stream.isOpen)
            flush();
        return _stream.detach();
    }

    /**
     * Write data to the stream, but buffer input
     * so that only sufficiently large packets are sent.
     *
     * Data at least as large as the buffer bypasses it and is written
     * directly, after the buffer has been flushed.
     *
     * Returns: The number of bytes accepted. This may be less than
     *          `buf.length` if a direct write to the socket was partial.
     *
     * Throws: SocketException on failure.
     */
    size_t write(in ubyte[] buf) @safe
    {
        immutable satisfied = writePartial(buf);
        if (satisfied == buf.length)
        {
            return satisfied;
        }

        // writePartial only stops short after filling and flushing the
        // buffer, so the buffer is empty here.
        const(ubyte)[] leftOver = buf[satisfied .. $];

        if (leftOver.length >= _buffer.length)
        {
            // leftOver is bigger than _buffer, write directly to socket
            return satisfied + _stream.write(leftOver);
        }
        else
        {
            return satisfied + writePartial(leftOver);
        }
    }

    unittest
    {
        auto pair = socketPair();
        auto sock = SocketStream(pair[0]);
        auto other = pair[1];
        other.blocking = false;
        sock.bufferSize = 10;

        ubyte[] data = [1,2,3,4,5];
        ubyte[20] buff;

        sock.write(data);
        assert(other.receive(buff) == Socket.ERROR);
        assert(wouldHaveBlocked());
        sock.write(data);
        assert(other.receive(buff) == 10);
        assert(buff[0..10] == data ~ data);
        sock.write(data);
        sock.flush();
        buff = 0;
        assert(other.receive(buff) == 5);
        assert(buff[0..5] == data);
    }

    /// Copy as much of `buf` as fits into the buffer, flushing if it fills.
    /// Returns the number of bytes copied.
    private size_t writePartial(in ubyte[] buf) @safe
    {
        import std.algorithm : min;
        immutable satisfiable = min(_buffer.length - _pos, buf.length);
        _buffer[_pos .. _pos + satisfiable] = buf[0 .. satisfiable];
        _pos += satisfiable;

        if (_pos == _buffer.length)
        {
            // The buffer is full: send it so there is room for more data.
            flush();
        }

        return satisfiable;
    }

    /**
     * Reads data from the socket.
     * The stream itself doesn't buffer reading
     * because the OS already buffers when receiving
     * on a streaming socket.
     */
    size_t read(scope ubyte[] buf) @safe
    {
        return _stream.read(buf);
    }


private:
    UnbufferedSocketStreamBase _stream;
    ubyte[] _buffer;
    // Number of pending bytes at the front of _buffer.
    size_t _pos;
}
359 
unittest
{
    // Compile-time check that the buffered base meets io.stream's
    // `isSourceSink` requirements.
    enum bool satisfiesSourceSink = isSourceSink!SocketStreamBase;
    static assert(satisfiesSourceSink);
}

import std.typecons : RefCounted, RefCountedAutoInitialize;

/// Reference-counted handle to the `StreamShim` of a stream base type.
/// With `RefCountedAutoInitialize.no` the payload must be constructed
/// explicitly, e.g. `SocketStream(socket)`.
private alias RefCountedStream(Base) =
    RefCounted!(StreamShim!Base, RefCountedAutoInitialize.no);

/// Reference-counted unbuffered socket stream.
alias UnbufferedSocketStream = RefCountedStream!UnbufferedSocketStreamBase;
/// Reference-counted socket stream with buffered writes.
alias SocketStream = RefCountedStream!SocketStreamBase;

unittest
{
    static assert(isSourceSink!UnbufferedSocketStream);
    static assert(isSourceSink!SocketStream);
}
374 
375 
/**
 * Throw a `SocketOSException` unless `check` holds.
 *
 * Params:
 *   file  = Source file attributed to the exception (defaults to the caller's).
 *   line  = Source line attributed to the exception (defaults to the caller's).
 *   check = Condition that must be true.
 *   msg   = Exception message; only evaluated when `check` is false.
 *
 * Throws: SocketOSException if `check` is false.
 */
void socketEnforce(string file = __FILE__, size_t line = __LINE__)(bool check, lazy string msg = null)
{
    if (check)
        return;

    throw new SocketOSException(msg, file, line);
}
386 
/**
 * Accept a connection on `socket` and wrap it in a buffered `SocketStream`.
 *
 * Params:
 *   socket = The socket to call `accept` on.
 *
 * Returns: A `SocketStream` that owns the accepted connection.
 * Any exception thrown by `Socket.accept` propagates unchanged.
 */
SocketStream acceptStream(Socket socket)
{
    Socket connection = socket.accept();
    return SocketStream(connection);
}